# -*- coding: utf-8 -*- """ A very simple testing library intended for use by students in an intro course. Includes capabilities for mocking input and setting up checks for the results and/or printed output of arbitrary expressions. optimism.py ## Example usage ```py from optimism import * # Simple example function def f(x, y): "Example function" return x + y + 1 # Simple test for that function test = testFunction(f) # line 9 case = test.case(1, 2) case.checkReturnValue(4) # line 11 # Function that prints multiple lines of printed output def display(message): print("The message is:") print('-' + message + '-') # One test case with two lines of printed output test = testFunction(display) # line 20 case = test.case('hello') case.checkPrintedLines('The message is:', '-hello-') # line 22 # A function that uses input def askNameAge(): name = input("What is your name? ") age = input("How old are you? ") return (name, age) test = testFunction(askNameAge) # line 31 case = test.case() case.provideInputs("Name", "thirty") case.checkReturnValue(('Name', 'thirty')) # line 34 case.checkPrintedLines( # line 35 'What is your name? Name', 'How old are you? thirty' # line 37 ) ``` If we were to run this example file, we should see the following printed output: ```txt ✓ example.py:11 ✓ example.py:22 ✓ example.py:34 ✓ example.py:35 ``` If an expectation is not met, the printed output would look different. For example, if our first expectation had been 3 instead of 4, we'd see: ```txt ✗ example.py:11 Result: 4 was NOT equivalent to the expected value: 3 Called function 'f' with arguments: x = 1 y = 2 ``` Note that line numbers reported for function calls that span multiple lines may be different between Python versions before 3.8 and versions 3.8+. ## Core functionality The main functions you'll need are: - `trace` works like `print`, but shows some extra information, and you can use it as part of a larger expression. Use this for figuring out what's going wrong when your tests don't pass. - `expect` takes two arguments and prints a check-mark if they're equivalent or an x if not. If they aren't equivalent, it prints details about values that are part of the first expression. Use this for expectations-based debugging. - `expectType` works like `expect`, but the second argument is a type, and it checks whether the value of the first argument is an instance of that type or not. - `testFunction` establishes a test manager object, which can be used to create test cases. When you call `testFunction` you specify the function that you want to test. The `.case` method of the resulting `TestManager` object can be used to set up individual test cases. - `testFile` establishes a test manager object just like `testFunction`, but for running an entire file instead of for calling a single function. - `testBlock` establishes a test manager object just like `testFunction`, but for running a block of code (given as a string). - The `TestManager.case` method establishes a test case by specifying what arguments are going to be used with the associated function. It returns a `TestCase` object that can be used to establish expectations. For test managers based on files or blocks, no arguments are needed. - The `TestCase.checkReturnValue` and `TestCase.checkPrintedLines`, and/or `TestCase.checkCustom` methods can be used to run checks for the return value and/or printed output of a test case. - Normally, output printed during tests is hidden, but `showPrintedLines` can be used to show the text that's being captured instead. - `TestCase.provideInputs` sets up inputs for a test case, so that interactive code can be tested without pausing for real user input. - `detailLevel` can be called to control the level of detail printed in the output. It affects all tracing, expectation, and testing output produced until it gets called again. - `showSummary` can be used to summarize the number of checks which passed or failed. - `colors` can be used to enable or disable color codes for printed text. Disable this if you're getting garbled printed output. TODO: Workaround for tracing in interactive console? TODO: Prevent crash if expr_src is unavailable when tracing... ## Changelog - Version 2.0 introduced the `TestManager` and `TestCase` classes, and got rid of automatic tracking for test cases. The old test case functionality was moved over to the `expect` function. This helps make tests more stable and makes meta-reasoning easier. - Version 2.2.0 changed the names of `checkResult` and `checkOutputLines` to `checkReturnValue` and checkPrintedLines` - Version 2.3.0 adds `testFunctionMaybe` for skippable tests if the target function hasn't been defined yet. - Version 2.4.0 adds a more object-oriented structure behind the scenes, without changing any core API functions. It also adds support for variable specifications in block tests. - Version 2.6.2 fixes a bug with dictionary comparison which caused a crash when key sets weren't equal. It also adds unit tests for the `compare` function. - Version 2.6.3 introduces the `checkFileContainsLines` method, and also standardizes the difference-finding code and merges it with equality-checking code, removing `checkEquality` and introducing `findFirstDifference` instead (`compare`) remains but just calls `findFirstDifference` internally. Also fixes a bug w/ printing tracebacks for come checkers (need to standardize that!). Also adds global skip-on-failure and sets that as the default! - Version 2.6.4 immediately changes `checkFileContainsLines` to `checkFileLines` to avoid confusion about whether we're checking against the whole file (we are). - Version 2.6.5 fixes a bug with displaying filenames when a file does not exist and `checkFileLines` is used, and also sets the default length higher for printing first differing lines in `findFirstDifference` since those lines are displayed on their own line anyways. Fixes a bug where list differences were not displayed correctly, and improves the usefulness of first differences displayed for dictionaries w/ different lengths. Also fixes a bug where strings which were equivalent modulo trailing whitespace would not be treated as equivalent. - Version 2.6.6 changes from splitlines to split('\\n') in a few places because the latter is more robust to extra carriage returns. This changes how some error messages look and it means that in some places the newline at the end of a file actually counts as having a blank line there in terms of output (behavior is mostly unchanged). Also escaped carriage returns in displayed strings so they're more visible. From this version we do NOT support files that use just '\\r' as a newline as easily. But `IGNORE_TRAILING_WHITESPACE` will properly get rid of any extra '\\r' before a newline, so when that's on (the default) this shouldn't change much. A test of this behavior was added to the file test example. - Version 2.6.7 changes default SKIP_ON_FAILURE back to 'case', since 'all' makes interactive testing hard, and dedicated test files can call skipChecksAfterFail at the top. Also fixes an issue where comparing printed output correctly is lenient about the presence or absence of a final newline, but comparing file contents didn't do that. This change means that extra blank lines (including lines with whitespace on them) are ignored when comparing strings and IGNORE_TRAILING_WHITESPACE is on, and even when IGNORE_TRAILING_WHITESPACE is off, the presence or absence of a final newline in a file or in printed output will be copied over to a multi-line expectation (since otherwise there's no way to specify the lack of a final newline when giving multiple string arguments to `checkPrintedLines` or `checkFileLines`). """ # TODO: Cache compiled ASTs! __version__ = "2.6.7" import sys import traceback import inspect import ast import copy import io import os import re import types import builtins import cmath import textwrap #---------# # Globals # #---------# ALL_CASES = {} """ All test cases that have been created, organized by test-suite names. By default all cases are added to the 'default' test suite, but this can be changed using `startTestSuite`. Each entry has a test suite name as the key and a list of `TestCase` objects as the value. """ _CURRENT_SUITE_NAME = "default" """ The name of the current test suite, which organizes newly-created test cases within the ALL_CASES variable. Use `startTestSuite` to begin/resume a new test suite, and `currentTestSuite` to retrieve the value. """ COMPLETED_PER_LINE = {} """ A dictionary mapping function names to dictionaries mapping (filename, line-number) pairs to counts. Each count represents the number of functions of that name which have finished execution on the given line of the given file already. This allows us to figure out which expression belongs to which invocation if `get_my_context` is called multiple times from the same line of code. """ DETAIL_LEVEL = 0 """ The current detail level, which controls how verbose our messages are. See `detailLevel`. """ SKIP_ON_FAILURE = "case" """ Controls which checks get skipped when a check fails. If set to `'all'`, ALL checks will be skipped once one fails, until `clearFailure` is called. If set to `'case'` (the default), subsequent checks for the same test case will be skipped when one fails. If set to `'manager'`, then all checks for any case from a case manager will be skipped when any check for any case derived from that manager fails. Any other value will disable the skipping of checks based on failures. """ CHECK_FAILED = False """ Remembers whether we've failed a check yet or not. If True and `SKIP_ON_FAILURE` is set to `'all'`, all checks will be skipped. Use `clearFailure` to reset this and resume checking without changing `SKIP_ON_FAILURE` if you need to. """ COLORS = True """ Whether to print ANSI color control sequences to color the printed output or not. """ MSG_COLORS = { "succeeded": "34", # blue "skipped": "33", # yellow "failed": "1;31", # bright red "reset": "0", # resets color properties } IGNORE_TRAILING_WHITESPACE = True """ Controls equality and inclusion tests on strings, including multiline strings and strings within other data structures, causing them to ignore trailing whitespace. True by default, since trailing whitespace is hard to reason about because it's invisible. Trailing whitespace is any sequence whitespace characters before a newline character (which is the only thing we count as a line break, meaning \\r\\n breaks are only accepted if IGNORE_TRAILING_WHITESPACE is on). Specifically, we use the `rstrip` method after splitting on \\n. Additionally, in multi-line scenarios, if there is a single extra line containing just whitespace, that will be ignored, although that's not the same as applying `rstrip` to the entire string, since if there are multiple extra trailing newline characters, that still counts as a difference. """ _SHOW_OUTPUT = False """ Controls whether or not output printed during tests appears as normal or is suppressed. Control this using the `showPrintedLines` function. """ FLOAT_REL_TOLERANCE = 1e-8 """ The relative tolerance for floating-point similarity (see `cmath.isclose`). """ FLOAT_ABS_TOLERANCE = 1e-8 """ The absolute tolerance for floating-point similarity (see `cmath.isclose`). """ #--------# # Errors # #--------# class TestError(Exception): """ An error with the testing mechanisms, as opposed to an error with the actual code being tested. """ pass #-------------------# # Test Case Classes # #-------------------# class NoResult: """ A special class used to indicate the absence of a result when None is a valid result value. """ pass class TestCase: """ Represents a specific test to run, managing things like specific arguments, inputs or available variables that need to be in place. Derived from a `TestManager` using the `TestManager.case` method. `TestCase` is abstract; subclasses should override a least the `run` and `testDetails` functions. """ def __init__(self, manager): """ A manager must be specified, but that's it. This does extra things like registering the case in the current test suite (see `startTestSuite`) and figuring out the location tag for the case. """ self.manager = manager # Inputs to provide on stdin self.inputs = None # Results of running this case self.results = None # Whether to echo captured printed outputs (overrides global) self.echo = None # Location and tag for test case creation self.location = get_my_location() self.tag = tag_for(self.location) # List of outcomes of checks that were made. Each is a triple # with a True/False indicator for success/failure, a string tag # for the expectation, and a full result message. self.outcomes = [] # Whether or not a check has failed for this case yet. self.any_failed = False # Register as a test case ALL_CASES.setdefault(_CURRENT_SUITE_NAME, []).append(self) def provideInputs(self, *inputLines): """ Sets up fake inputs (each argument must be a string and is used for one line of input) for this test case. When information is read from stdin during the test, including via the `input` function, these values are the result. If you don't call `provideInputs`, then the test will pause and wait for real user input when `input` is called. You must call this before the test actually runs (i.e., before `TestCase.run` or one of the `check` functions is called), otherwise you'll get an error. """ if self.results is not None: raise TestError( "You cannot provide inputs because this test case has" " already been run." ) self.inputs = inputLines def showPrintedLines(self, show=True): """ Overrides the global `showPrintedLines` setting for this test. Use None as the parameter to remove the override. """ self.echo = show def _run(self, payload): """ Given a payload (a zero-argument function), runs the payload while managing things like output capturing and input mocking. Sets the `self.results` field to reflect the results of the run, which will be a dictionary that has the following slots: - "result": The result value from a function call. This key will not be present for tests that don't have a result, like file or code block tests. To achieve this with a custom payload, have the payload return `NoResult`. - "output": The output printed during the test. Will be an empty string if nothing gets printed. - "error": An Exception object representing an error that occurred during the test, or None if no errors happened. - "traceback": If an exception occurred, this will be a string containing the traceback for that exception. Otherwise it will be None. In addition to being added to the results slot, this dictionary is also returned. """ # Set up the `input` function to echo what is typed original_input = builtins.input def echoing_input(prompt): """ A stand-in for the built-in input which echoes the received input to stdout, under the assumption that stdin will NOT be echoed to the output stream because the output stream is not the console any more. """ nonlocal original_input result = original_input(prompt) sys.stdout.write(result + '\n') return result builtins.input = echoing_input # Set up a capturing stream for output outputCapture = CapturingStream() outputCapture.install() if self.echo or (self.echo is None and _SHOW_OUTPUT): outputCapture.echo() # Set up fake input contents if self.inputs is not None: fakeInput = io.StringIO('\n'.join(self.inputs)) original_stdin = sys.stdin sys.stdin = fakeInput # Set up default values before we run things error = None tb = None value = NoResult # Actually run the test try: value = payload() except Exception as e: # Catch any error that occurred error = e tb = traceback.format_exc() finally: # Release stream captures and reset the input function outputCapture.uninstall() builtins.input = original_input if self.inputs is not None: sys.stdin = original_stdin # Grab captured output output = outputCapture.getvalue() # Create self.results w/ output, error, and maybe result value self.results = { "output": output, "error": error, "traceback": tb } if value is not NoResult: self.results["result"] = value # Return new results object return self.results def run(self): """ Runs this test case, capturing printed output and supplying fake input if `TestCase.provideInputs` has been called. Stores the results in `self.results`. This will be called once automatically the first time an expectation method like `TestCase.checkReturnValue` is used, but the cached value will be re-used for subsequent expectations, unless you manually call this method again. This method is overridden by specific test case types. """ raise NotImplementedError( "Cannot run a TestCase; you must create a specific kind of" " test case like a FunctionCase to be able to run it." ) def fetchResults(self): """ Fetches the results of the test, which will run the test if it hasn't already been run, but otherwise will just return the latest cached results. `run` describes the format of the results. """ if self.results is None: self.run() return self.results def _create_success_message( self, tag, details, extra_details=None, include_test_details=True ): """ Returns an expectation success message (a string) for an expectation with the given tag, using the given details and extra details. Unless `include_test_details` is set to False, details of the test expression/block will also be included (but only when the detail level is at least 1). The tag should be a filename:lineno string indicating where the expectation originated. """ # Detail level 1 gives more output for successes if DETAIL_LEVEL < 1: result = f"✓ {tag}" else: # Detail level is at least 1 result = ( f"✓ expectation from {tag} met for test case at {self.tag}" ) detail_msg = indent(details, 2) if not detail_msg.startswith('\n'): detail_msg = '\n' + detail_msg if DETAIL_LEVEL >= 2 and extra_details: extra_detail_msg = indent(extra_details, 2) if not extra_detail_msg.startswith('\n'): extra_detail_msg = '\n' + extra_detail_msg detail_msg += extra_detail_msg # Test details unless suppressed if include_test_details: test_base, test_extra = self.testDetails() detail_msg += '\n' + indent(test_base, 2) if DETAIL_LEVEL >= 2 and test_extra is not None: detail_msg += '\n' + indent(test_extra, 2) result += detail_msg return result def _create_failure_message( self, tag, details, extra_details=None, include_test_details=True ): """ Creates a failure message string for an expectation with the given tag that includes the details and/or extra details depending on the current global detail level. Normally, information about the test that was run is included as well, but you can set `include_test_details` to False to prevent this. """ # Detail level controls initial message if DETAIL_LEVEL < 1: result = f"✗ {tag}" else: result = ( f"✗ expectation from {tag} NOT met for test case at" f" {self.tag}" ) # Assemble our details message detail_msg = '' # Detail level controls printing of detail messages if DETAIL_LEVEL >= 0: detail_msg += '\n' + indent(details, 2) if DETAIL_LEVEL >= 1 and extra_details: detail_msg += '\n' + indent(extra_details, 2) # Test details unless suppressed if include_test_details: test_base, test_extra = self.testDetails() if DETAIL_LEVEL >= 0: detail_msg += '\n' + indent(test_base, 2) if DETAIL_LEVEL >= 1 and test_extra is not None: detail_msg += '\n' + indent(test_extra, 2) return result + detail_msg def _print_skip_message(self, tag, reason): """ Prints a standard message about the check being skipped, using the given tag and a reason (shown only if detail level is 1+). """ # Detail level controls initial message if DETAIL_LEVEL < 1: msg = f"~ {tag} (skipped)" else: msg = ( f"~ check at {tag} for test case at {self.tag} skipped" f" ({reason})" ) print_message(msg, color=msg_color("skipped")) def testDetails(self): """ Returns a pair of strings containing base and extra details describing what was tested by this test case. If the base details capture all available information, the extra details value will be None. This method is abstract and only sub-class implementations actually do anything. """ raise NotImplementedError( "Cannot get test details for a TestCase; you must create a" " specific kind of test case like a FunctionCase to be able" " to get test details." ) def _should_skip(self): """ Returns True if this check should be skipped based on a previous failure. """ return ( (SKIP_ON_FAILURE == "all" and CHECK_FAILED) or (SKIP_ON_FAILURE == "case" and self.any_failed) or (SKIP_ON_FAILURE == "manager" and self.manager.any_failed) ) def _register_outcome(self, passed, tag, message): """ Registers an outcome for this test case. `passed` should be either True or False indicating whether the check passed, `tag` is a string to label the outcome with, and `message` is the message displayed by the check. This appends an entry to `self.outcomes` with the passed boolean, the tag, and the message in a tuple, and it sets `self.any_failed` and `self.manager.any_failed` if the outcome is a failure. """ global CHECK_FAILED self.outcomes.append((passed, tag, message)) if not passed: CHECK_FAILED = True self.any_failed = True self.manager.any_failed = True def checkReturnValue(self, expectedValue): """ Checks the result value for this test case, comparing it against the given expected value and printing a message about success or failure depending on whether they are considered different by the `findFirstDifference` function. If this is the first check performed using this test case, the test case will run; otherwise a cached result will be used. This method returns True if the expectation is met and False if it is not, in addition to printing a message indicating success/failure and recording that message along with the status and tag in `self.outcomes`. If the check is skipped, it returns None and does not add an entry to `self.outcomes`. """ results = self.fetchResults() # Figure out the tag for this expectation tag = tag_for(get_my_location()) # Skip this check if the case has failed already if self._should_skip(): self._print_skip_message(tag, "prior test failed") # Note that we don't add an outcome here, and we return None # instead of True or False return None # Figure out whether we've got an error or an actual result if results["error"] is not None: # An error during testing tb = results["traceback"] tblines = tb.splitlines() if len(tblines) < 12: base_msg = "Failed due to an error:\n" + indent(tb, 2) extra_msg = None else: short_tb = '\n'.join(tblines[:4] + ['...'] + tblines[-4:]) base_msg = "Failed due to an error:\n" + indent(short_tb, 2) extra_msg = "Full traceback is:\n" + indent(tb, 2) msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False elif "result" not in results: # Likely impossible, since we verified the category above # and we're in a condition where no error was logged... msg = self._create_failure_message( tag, ( "This test case does not have a result value. (Did" " you mean to use checkPrintedLines?)" ) ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False else: # We produced a result, so check equality # Check equivalence passed = False firstDiff = findFirstDifference(results["result"], expectedValue) if firstDiff is None: equivalence = "equivalent to" passed = True else: equivalence = "NOT equivalent to" # Get short/long versions of result/expected short_result = ellipsis(repr(results["result"]), 72) full_result = repr(results["result"]) short_expected = ellipsis(repr(expectedValue), 72) full_expected = repr(expectedValue) # Create base/extra messages if ( short_result == full_result and short_expected == full_expected ): base_msg = ( f"Result:\n{indent(short_result, 2)}\nwas" f" {equivalence} the expected value:\n" f"{indent(short_expected, 2)}" ) extra_msg = None if ( firstDiff is not None and differencesAreSubtle(short_result, short_expected) ): base_msg += ( f"\nFirst difference was:\n{indent(firstDiff, 2)}" ) else: base_msg = ( f"Result:\n{indent(short_result, 2)}\nwas" f" {equivalence} the expected value:\n" f"{indent(short_expected, 2)}" ) extra_msg = "" if ( firstDiff is not None and differencesAreSubtle(short_result, short_expected) ): base_msg += ( f"\nFirst difference was:\n{indent(firstDiff, 2)}" ) if short_result != full_result: extra_msg += ( f"Full result:\n{indent(full_result, 2)}\n" ) if short_expected != full_expected: extra_msg += ( f"Full expected value:\n" f"{indent(full_expected, 2)}\n" ) if passed: msg = self._create_success_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("succeeded")) self._register_outcome(True, tag, msg) return True else: msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False def checkPrintedLines(self, *expectedLines): """ Checks that the exact printed output captured during the test matches a sequence of strings each specifying one line of the output. Note that the global `IGNORE_TRAILING_WHITESPACE` affects how this function treats line matches. If this is the first check performed using this test case, the test case will run; otherwise a cached result will be used. This method returns True if the check succeeds and False if it fails, in addition to printing a message indicating success/failure and recording that message along with the status and tag in `self.outcomes`. If the check is skipped, it returns None and does not add an entry to `self.outcomes`. """ # Fetch captured output results = self.fetchResults() output = results["output"] # Figure out the tag for this expectation tag = tag_for(get_my_location()) # Skip this check if the case has failed already if self._should_skip(): self._print_skip_message(tag, "prior test failed") # Note that we don't add an outcome here, and we return None # instead of True or False return None # Figure out whether we've got an error or an actual result if results["error"] is not None: # An error during testing tb = results["traceback"] tblines = tb.splitlines() if len(tblines) < 12: base_msg = "Failed due to an error:\n" + indent(tb, 2) extra_msg = None else: short_tb = '\n'.join(tblines[:4] + ['...'] + tblines[-4:]) base_msg = "Failed due to an error:\n" + indent(short_tb, 2) extra_msg = "Full traceback is:\n" + indent(tb, 2) msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False else: # We produced printed output, so check it # Get lines/single versions expected = '\n'.join(expectedLines) + '\n' # If the output doesn't end with a newline, don't add one to # our expectation either... if not output.endswith('\n'): expected = expected[:-1] # Figure out equivalence category equivalence = None passed = False firstDiff = findFirstDifference(output, expected) if output == expected: equivalence = "exactly the same as" passed = True elif firstDiff is None: equivalence = "equivalent to" passed = True else: equivalence = "NOT the same as" # passed remains False # Get short/long representations of our strings short, long = dual_string_repr(output) short_exp, long_exp = dual_string_repr(expected) # Construct base and extra messages if short == long and short_exp == long_exp: base_msg = ( f"Printed lines:\n{indent(short, 2)}\nwere" f" {equivalence} the expected printed" f" lines:\n{indent(short_exp, 2)}" ) if not passed: base_msg += ( f"\nFirst difference was:\n{indent(firstDiff, 2)}" ) extra_msg = None else: base_msg = ( f"Printed lines:\n{indent(short, 2)}\nwere" f" {equivalence} the expected printed" f" lines:\n{indent(short_exp, 2)}" ) if not passed: base_msg += ( f"\nFirst difference was:\n{indent(firstDiff, 2)}" ) extra_msg = "" if short != long: extra_msg += f"Full printed lines:\n{indent(long, 2)}\n" if short_exp != long_exp: extra_msg += ( f"Full expected printed" f" lines:\n{indent(long_exp, 2)}\n" ) if passed: msg = self._create_success_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("succeeded")) self._register_outcome(True, tag, msg) return True else: msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color="1;31" if COLORS else None) self._register_outcome(False, tag, msg) return False def checkPrintedFragment(self, fragment, copies=1, allowExtra=False): """ Works like checkPrintedLines, except instead of requiring that the printed output exactly match a set of lines, it requires that a certain fragment of text appears somewhere within the printed output (or perhaps that multiple non-overlapping copies appear, if the copies argument is set to a number higher than the default of 1). If allowExtra is set to True, more than the specified number of copies will be ignored, but by default, extra copies are not allowed. The fragment is matched against the entire output as a single string, so it may contain newlines and if it does these will only match newlines in the captured output. If `IGNORE_TRAILING_WHITESPACE` is active (it's on by default), the trailing whitespace in the output will be removed before matching, and trailing whitespace in the fragment will also be removed IF it has a newline after it (trailing whitespace at the end of the string with no final newline will be retained). This function returns True if the check succeeds and False if it fails, and prints a message either way. If the check is skipped, it returns None and does not add an entry to `self.outcomes`. """ # Fetch captured output results = self.fetchResults() output = results["output"] # Figure out the tag for this expectation tag = tag_for(get_my_location()) # Skip this check if the case has failed already if self._should_skip(): self._print_skip_message(tag, "prior test failed") # Note that we don't add an outcome here, and we return None # instead of True or False return None # Figure out whether we've got an error or an actual result if results["error"] is not None: # An error during testing tb = results["traceback"] tblines = tb.splitlines() if len(tblines) < 12: base_msg = "Failed due to an error:\n" + indent(tb, 2) extra_msg = None else: short_tb = '\n'.join(tblines[:4] + ['...'] + tblines[-4:]) base_msg = "Failed due to an error:\n" + indent(short_tb, 2) extra_msg = "Full traceback is:\n" + indent(tb, 2) msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False else: # We produced printed output, so check it if IGNORE_TRAILING_WHITESPACE: matches = re.findall( re.escape(trimWhitespace(fragment, True)), trimWhitespace(output) ) else: matches = re.findall(re.escape(fragment), output) passed = False if copies == 1: copiesPhrase = "" exactly = "" atLeast = "at least " else: copiesPhrase = f"{copies} copies of " exactly = "exactly " atLeast = "at least " fragShort, fragLong = dual_string_repr(fragment) outShort, outLong = dual_string_repr(output) if len(matches) == copies: passed = True base_msg = ( f"Found {exactly}{copiesPhrase}the target" f" fragment in the printed output." f"\nFragment was:\n{indent(fragShort, 2)}" f"\nOutput was:\n{indent(outShort, 2)}" ) elif allowExtra and len(matches) > copies: passed = True base_msg = ( f"Found {atLeast}{copiesPhrase}the target" f" fragment in the printed output (found" f" {len(matches)})." f"\nFragment was:\n{indent(fragShort, 2)}" f"\nOutput was:\n{indent(outShort, 2)}" ) else: passed = False base_msg = ( f"Did not find {copiesPhrase}the target fragment" f" in the printed output (found {len(matches)})." f"\nFragment was:\n{indent(fragShort, 2)}" f"\nOutput was:\n{indent(outShort, 2)}" ) extra_msg = "" if fragLong != fragShort: extra_msg += f"Full fragment was:\n{indent(fragLong, 2)}" if outLong != outShort: if not extra_msg.endswith('\n'): extra_msg += '\n' extra_msg += f"Full output was:\n{indent(outLong, 2)}" if passed: msg = self._create_success_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("succeeded")) self._register_outcome(True, tag, msg) return True else: msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color="1;31" if COLORS else None) self._register_outcome(False, tag, msg) return False def checkFileLines(self, filename, *lines): """ Works like `checkPrintedLines`, but checks for lines in the specified file, rather than checking for printed lines. """ # Figure out the tag for this expectation tag = tag_for(get_my_location()) # Skip this check if the case has failed already if self._should_skip(): self._print_skip_message(tag, "prior test failed") # Note that we don't add an outcome here, and we return None # instead of True or False return None # Fetch the results to actually run the test! expected = '\n'.join(lines) + '\n' results = self.fetchResults() # Figure out whether we've got an error or an actual result if results["error"] is not None: # An error during testing tb = results["traceback"] tblines = tb.splitlines() if len(tblines) < 12: base_msg = "Failed due to an error:\n" + indent(tb, 2) extra_msg = None else: short_tb = '\n'.join(tblines[:4] + ['...'] + tblines[-4:]) base_msg = "Failed due to an error:\n" + indent(short_tb, 2) extra_msg = "Full traceback is:\n" + indent(tb, 2) msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False else: # The test was able to run, so check the file contents # Fetch file contents try: with open(filename, 'r', newline='') as fileInput: fileContents = fileInput.read() except (OSError, FileNotFoundError, PermissionError): # We can't even read the file! msg = self._create_failure_message( tag, f"Expected file '{filename}' cannot be read.", None ) print_message(msg, color=msg_color("failed")) self._register_outcome(False, tag, msg) return False # If the file doesn't end with a newline, don't add one to # our expectation either... if not fileContents.endswith('\n'): expected = expected[:-1] # Get lines/single versions firstDiff = findFirstDifference(fileContents, expected) equivalence = None passed = False if fileContents == expected: equivalence = "exactly the same as" passed = True elif firstDiff is None: equivalence = "equivalent to" passed = True else: # Some other kind of difference equivalence = "NOT the same as" # passed remains False # Get short/long representations of our strings short, long = dual_string_repr(fileContents) short_exp, long_exp = dual_string_repr(expected) # Construct base and extra messages if short == long and short_exp == long_exp: base_msg = ( f"File contents:\n{indent(short, 2)}\nwere" f" {equivalence} the expected file" f" contents:\n{indent(short_exp, 2)}" ) if not passed: base_msg += ( f"\nFirst difference was:\n{indent(firstDiff, 2)}" ) extra_msg = None else: base_msg = ( f"File contents:\n{indent(short, 2)}\nwere" f" {equivalence} the expected file" f" contents:\n{indent(short_exp, 2)}" ) if not passed: base_msg += ( f"\nFirst difference was:\n{indent(firstDiff, 2)}" ) extra_msg = "" if short != long: extra_msg += f"Full file contents:\n{indent(long, 2)}\n" if short_exp != long_exp: extra_msg += ( f"Full expected file" f" contents:\n{indent(long_exp, 2)}\n" ) if passed: msg = self._create_success_message( tag, base_msg, extra_msg ) print_message(msg, color=msg_color("succeeded")) self._register_outcome(True, tag, msg) return True else: msg = self._create_failure_message( tag, base_msg, extra_msg ) print_message(msg, color="1;31" if COLORS else None) self._register_outcome(False, tag, msg) return False def checkCustom(self, checker, *args, **kwargs): """ Sets up a custom check using a testing function. The provided function will be given one argument, plus any additional arguments given to this function. The first and/or only argument to the checker function will be a dictionary with the following keys: - "case": The test case object on which `checkCustom` was called. This could be used to do things like access arguments passed to the function being tested for a `FunctionCase` for example. - "output": Output printed by the test case, as a string. - "result": the result value (for function tests only, otherwise this key will not be present). - "error": the error that occurred (or None if no error occurred). - "traceback": the traceback (a string, or None if there was no error). The testing function must return True to indicate success and False for failure. If it returns something other than True or False, it will be counted as a failure, that value will be shown as part of the test result if the `DETAIL_LEVEL` is 1 or higher, and this method will return False. If this check is skipped (e.g., because of a previous failure), this method returns None and does not add an entry to `self.outcomes`; the custom checker function won't be called in that case. """ results = self.fetchResults() # Add a 'case' entry checker_input = copy.copy(results) checker_input["case"] = self # Figure out the tag for this expectation tag = tag_for(get_my_location()) # Skip this check if the case has failed already if self._should_skip(): self._print_skip_message(tag, "prior test failed") # Note that we don't add an outcome here, and we return None # instead of True or False return None # Only run the checker if we're not skipping the test test_result = checker(checker_input, *args, **kwargs) if test_result is True: msg = self._create_success_message(tag, "Custom check passed.") print_message(msg, color=msg_color("succeeded")) self._register_outcome(True, tag, msg) return True elif test_result is False: msg = self._create_failure_message(tag, "Custom check failed") print_message(msg, color="1;31" if COLORS else None) self._register_outcome(False, tag, msg) return False else: msg = self._create_failure_message( tag, "Custom check failed:\n" + indent(str(test_result), 2), ) print_message(msg, color="1;31" if COLORS else None) self._register_outcome(False, tag, msg) return False class FileCase(TestCase): """ Runs a particular file when executed. Its manager should be a `FileManager`. """ # __init__ is inherited def run(self): """ Reads the target file, parses its code, and runs that code in an empty environment (except that `__name__` is set to `'__main__'`, to make the file behave as if it were run as the main file). """ def payload(): "Payload function to run a file." # Read the file with open(self.manager.target, 'r') as fin: src = fin.read() # Parse the file node = ast.parse( src, filename=self.manager.target, mode='exec' ) # Compile the results code = compile(node, self.manager.target, 'exec') # Run the code, setting __name__ to __main__ (this is # why we don't just import the file) env = {"__name__": "__main__"} exec(code, env) return self._run(payload) def testDetails(self): """ Returns a pair of strings containing base and extra details describing what was tested by this test case. If the base details capture all available information, the extra details value will be None. """ return ( f"Ran file '{self.manager.target}'", None # no further details to report ) class FunctionCase(TestCase): """ Calls a particular function with specific arguments when run. """ def __init__(self, manager, args=None, kwargs=None): """ The arguments and/or keyword arguments to be used for the case are provided after the manager (as a list and a dictionary, NOT as actual arguments). If omitted, the function will be called with no arguments. """ super().__init__(manager) self.args = args or () self.kwargs = kwargs or {} def run(self): """ Runs the target function with the arguments specified for this case. The 'result' slot of the `self.results` dictionary that it creates holds the return value of the function. """ def payload(): "Payload for running a function with specific arguments." return self.manager.target(*self.args, **self.kwargs) return self._run(payload) def testDetails(self): """ Returns a pair of strings containing base and extra details describing what was tested by this test case. If the base details capture all available information, the extra details value will be None. """ # Show function name + args, possibly with some abbreviation fn = self.manager.target msg = f"Called function '{fn.__name__}'" args = self.args if self.args is not None else [] kwargs = self.kwargs if self.args is not None else {} all_args = len(args) + len(kwargs) argnames = fn.__code__.co_varnames[:all_args] if len(args) > len(argnames): msg += " with too many arguments (!):" elif all_args > 0: msg += " with arguments:" # TODO: Proper handling of *args and **kwargs entries! # Create lists of full and maybe-abbreviated argument # strings argstrings = [] short_argstrings = [] for i, arg in enumerate(args): if i < len(argnames): name = argnames[i] else: name = f"extra argument #{i - len(argnames) + 1}" short_name = ellipsis(name, 20) argstrings.append(f"{name} = {repr(arg)}") short_argstrings.append( f"{short_name} = {ellipsis(repr(arg), 60)}" ) # Order kwargs by original kwargs order and then by natural # order of kwargs dictionary keyset = set(kwargs) ordered = list(filter(lambda x: x in keyset, argnames)) rest = [k for k in kwargs if k not in ordered] for k in ordered + rest: argstrings.append(f"{k} = {repr(kwargs[k])}") short_name = ellipsis(k, 20) short_argstrings.append( f"{short_name} = {ellipsis(repr(kwargs[k]), 60)}" ) full_args = ' ' + '\n '.join(argstrings) # In case there are too many arguments if len(short_argstrings) < 20: short_args = ' ' + '\n '.join(short_argstrings) else: short_args = ( ' ' + '\n '.join(short_argstrings[:19]) + f"...plus {len(argstrings) - 19} more arguments..." ) if short_args == full_args: return ( msg + '\n' + short_args, None ) else: return ( msg + '\n' + short_args, "Full arguments were:\n" + full_args ) class BlockCase(TestCase): """ Executes a block of code (provided as text) when run. Per-case variables may be defined for the execution environment, which otherwise just has builtins. """ def __init__(self, manager, assignments=None): """ A dictionary of variable name : value assignments may be provided and these will be inserted into the execution environment for the code block. If omitted, no extra variables will be defined. """ super().__init__(manager) self.assignments = assignments or {} def run(self): """ Compiles and runs the target code block in an environment which is empty except for the assignments specified in this case (and builtins). """ def payload(): "Payload for running a code block specific variables active." env = dict(self.assignments) exec(self.manager.target, env) return self._run(payload) def testDetails(self): """ Returns a pair of strings containing base and extra details describing what was tested by this test case. If the base details capture all available information, the extra details value will be None. """ block = self.manager.target short = limited_repr(block) if block == short: # Short enough to show whole block return ( "Ran code:\n" + indent(block, 2), None ) else: # Too long to show whole block in short view... return ( "Ran code:\n" + indent(short, 2), "Full code was:\n" + indent(block, 2) ) class SkipCase(TestCase): """ A type of test case which actually doesn't run checks, but instead prints a message that the check was skipped. """ # __init__ is inherited def run(self): """ Since there is no real test, our results are fake. The keys "error" and "traceback" have None as their value, and "output" also has None. We add a key "skipped" with value True. """ self.results = { "output": None, "error": None, "traceback": None, "skipped": True } return self.results def testDetails(self): """ Provides a pair of topic/details strings about this test. """ return (f"Skipped check of '{self.manager.target}'", None) def checkReturnValue(self, _, **__): """ Skips the check. """ self._print_skip_message( tag_for(get_my_location()), "testing target not available" ) def checkPrintedLines(self, *_, **__): """ Skips the check. """ self._print_skip_message( tag_for(get_my_location()), "testing target not available" ) def checkPrintedFragment(self, *_, **__): """ Skips the check. """ self._print_skip_message( tag_for(get_my_location()), "testing target not available" ) def checkFileLines(self, *_, **__): """ Skips the check. """ self._print_skip_message( tag_for(get_my_location()), "testing target not available" ) def checkCustom(self, _, **__): """ Skips the check. """ self._print_skip_message( tag_for(get_my_location()), "testing target not available" ) #----------------------# # Test Manager Classes # #----------------------# class TestManager: """ Abstract base class for managing tests for a certain function, file, or block of code. Create these using the `testFunction`, `testFile`, and/or `testBlock` factory functions. The `TestManager.case` function can be used to derive `TestCase` objects which can then be used to set up checks. """ case_type = TestCase """ The case type determines what kind of test case will be constructed when calling the `TestManager.case` method. Subclasses override this. """ def __init__(self, target): """ A testing target (a filename string, function object, code string, or test label string) must be provided. """ self.target = target # Keeps track of whether any cases derived from this manager have # failed so far self.any_failed = False def case(self): """ Returns a `TestCase` object that will test the target file/function/block. Some manager types allow arguments to this function. """ return self.case_type(self) class FileManager(TestManager): """ Manages test cases for running an entire file. Unlike other managers, cases for a file cannot have parameters. Calling `TestCase.provideInputs` on a case to provide inputs still means that having multiple cases can be useful, however. """ case_type = FileCase def __init__(self, filename): """ A FileManager needs a filename string that specifies which file we'll run when we run a test case. """ if not isinstance(filename, str): raise TypeError( f"For a file test manager, the target must be a file" f" name string. (You provided a/an {type(filename)}.)" ) super().__init__(filename) # case is inherited as-is class FunctionManager(TestManager): """ Manages test cases for running a specific function. Arguments to the `TestManager.case` function are passed to the function being tested for that case. """ case_type = FunctionCase def __init__(self, function): """ A FunctionManager needs a function object as the target. Each case will call that function with arguments provided when the case is created. """ if not isinstance(function, types.FunctionType): raise TypeError( f"For a function test manager, the target must be a" f" function. (You provided a/an {type(function)}.)" ) super().__init__(function) def case(self, *args, **kwargs): """ Arguments supplied here are used when calling the function which is what happens when the case is run. Returns a `FunctionCase` object. """ return self.case_type(self, args, kwargs) class BlockManager(TestManager): """ Manages test cases for running a block of code (from a string). Keyword arguments to the `TestManager.case` function are defined as variables before the block is executed in that case. """ case_type = BlockCase def __init__(self, code): """ A BlockManager needs a code string as the target. """ if not isinstance(code, str): raise TypeError( f"For a 'block' test manager, the target must be a" f" string. (You provided a/an {type(code)}.)" ) super().__init__(code) def case(self, **assignments): """ Keyword argument supplied here will be defined as variables in the environment used to run the code block. Returns a `BlockCase` object. """ return self.case_type(self, assignments) class SkipManager(TestManager): """ Manages fake test cases for a file, function, or code block that needs to be skipped (perhaps for a function that doesn't yet exist, for example). Cases derived are `SkipCase` objects which just print skip messages for any checks requested. """ case_type = SkipCase def __init__(self, label): """ Needs a label string to identify which tests are being skipped. """ if not isinstance(label, str): raise TypeError( f"For a skip test manager, the target must be a string." f" (You provided a/an {type(label)}.)" ) super().__init__(label) def case(self, *_, **__): """ Accepts (and ignores) any extra arguments. """ return super().case() #----------------# # Test factories # #----------------# def testFunction(fn): """ Creates a test-manager for the given function. """ if not isinstance(fn, types.FunctionType): raise TypeError( "Test target must be a function (use testFile or testBlock" " instead to test a file or block of code)." ) return FunctionManager(fn) def testFunctionMaybe(module, fname): """ This function creates a test-manager for a named function from a specific module, but displays an alternate message and returns a dummy manager if that module doesn't define any variable with the target name. Useful for defining tests for functions that will be skipped if the functions aren't done yet. """ # Message if we can't find the function if not hasattr(module, fname): print_message( f"Did not find '{fname}' in module '{module.__name__}'...", color=msg_color("skipped") ) return SkipManager(f"{module.__name__}.{fname}") else: target = getattr(module, fname) if not isinstance(target, types.FunctionType): print_message( ( f"'{fname}' in module '{module.__name__}' is not a" f" function..." ), color=msg_color("skipped") ) return SkipManager(f"{module.__name__}.{fname}") else: return FunctionManager(target) def testFile(filename): """ Creates a test-manager for running the named file. """ if not isinstance(filename, str): raise TypeError( "Test target must be a file name (use testFunction instead" " to test a function)." ) if not os.path.exists(filename): raise FileNotFoundError( f"We cannot create a test for running '{filename}' because" f" that file does not exist." ) return FileManager(filename) def testBlock(code): """ Creates a test-manager for running a block of code. """ if not isinstance(code, str): raise TypeError( "Test target must be a code string (use testFunction instead" " to test a function)." ) # TODO: This check is good, but avoiding multiple parsing passes # might be nice for larger code blocks... try: _ = ast.parse(code) except Exception: raise ValueError( "The code block you provided could not be parsed as Python" " code." ) return BlockManager(code) #----------------# # Output capture # #----------------# class CapturingStream(io.StringIO): """ An output capture object which is an `io.StringIO` underneath, but which has an option to also write incoming text to normal `sys.stdout`. Call the install function to begin capture. """ def __init__(self, *args, **kwargs): """ Passes arguments through to `io.StringIO`'s constructor. """ self.original_stdout = None self.tee = False super().__init__(*args, **kwargs) def echo(self, doit=True): """ Turn on echoing to stdout along with capture, or turn it off if False is given. """ self.tee = doit def install(self): """ Replaces `sys.stdout` to begin capturing printed output. Remembers the old `sys.stdout` value so that `uninstall` can work. Note that if someone else changes `sys.stdout` after this is installed, uninstall will set `sys.stdout` back to what it was when `install` was called, which could cause issues. For example, if we have two capturing streams A and B, and we call: ```py A.install() B.install() A.uninstall() B.uninstall() ``` The original `sys.stdout` will not be restored. In general, you must uninstall capturing streams in the reverse order that you installed them. """ self.original_stdout = sys.stdout sys.stdout = self def uninstall(self): """ Returns `sys.stdout` to what it was before `install` was called, or does nothing if `install` was never called. """ if self.original_stdout is not None: sys.stdout = self.original_stdout def reset(self): """ Resets the captured output. """ self.seek(0) self.truncate(0) def writelines(self, lines): """ Override writelines to work through write. """ for line in lines: self.write(line) def write(self, stuff): """ Accepts a string and writes to our capture buffer (and to original stdout if `echo` has been called). Returns the number of characters written. """ if self.tee and self.original_stdout is not None: self.original_stdout.write(stuff) super().write(stuff) def showPrintedLines(show=True): """ Changes the testing mechanisms so that printed output produced during tests is shown as normal in addition to being captured. Call it with False as an argument to disable this. """ global _SHOW_OUTPUT _SHOW_OUTPUT = show #---------------------# # Debugging functions # #---------------------# def differencesAreSubtle(val, ref): """ Judges whether differences between two strings are 'subtle' in which case the first difference details will be displayed. Returns true if either value is longer than a typical floating-point number, or if the representations are the same once all whitespace is stripped out. """ # If either has non-trivial length, we'll include the first # difference report. 18 is the length of a floating-point number # with two digits before the decimal point and max digits afterwards if len(val) > 18 or len(ref) > 18: return True valNoWS = re.sub(r'\s', '', val) refNoWS = re.sub(r'\s', '', ref) # If they're the same modulo whitespace, then it's probably useful # to report first difference, otherwise we won't return valNoWS == refNoWS def expect(expr, value): """ Establishes an immediate expectation that the values of the two arguments should be equivalent. The expression provided will be picked out of the source code of the module calling `expect` (see `get_my_context`). The expression and sub-values will be displayed if the expectation is not met, and either way a message indicating success or failure will be printed. Use `detailLevel` to control how detailed the messages are. For `expect` to work properly, the following rules must be followed: 1. When multiple calls to `expect` appear on a single line of the source code (something you should probably avoid anyway), none of the calls should execute more times than another when that line is executed (it's difficult to violate this, but examples include the use of `expect` multiple times on one line within generator or if/else expressions) 2. None of the following components of the expression passed to `expect` should have side effects when evaluated: - Attribute accesses - Subscripts (including expressions inside brackets) - Variable lookups (Note that those things don't normally have side effects!) This function returns True if the expectation is met and False otherwise. It returns None if the check is skipped, which will happen when `SKIP_ON_FAILURE` is `'all'` and a previous check failed. If the values are not equivalent, this will count as a failed check and other checks may be skipped. """ global CHECK_FAILED context = get_my_context(expect) tag = tag_for(context) # Skip this expectation if necessary if SKIP_ON_FAILURE == 'all' and CHECK_FAILED: if DETAIL_LEVEL < 1: msg = f"~ {tag} (skipped)" else: msg = ( f"~ direct expectation at {tag} for skipped because a" f" prior check failed" ) print_message(msg, color=msg_color("skipped")) return None short_result = ellipsis(repr(expr), 78) short_expected = ellipsis(repr(value), 78) full_result = repr(expr) full_expected = repr(value) firstDiff = findFirstDifference(expr, value) if firstDiff is None: message = f"✓ {tag}" equivalent = "equivalent to" msg_cat = "succeeded" same = True else: message = f"✗ {tag}" equivalent = "NOT equivalent to" msg_cat = "failed" same = False if DETAIL_LEVEL >= 1 or not same: message += f""" Result: {indent(short_result, 4)} was {equivalent} the expected value: {indent(short_expected, 4)}""" if ( firstDiff is not None and differencesAreSubtle(short_result, short_expected) ): message += f"\n First difference was:\n{indent(firstDiff, 4)}" # Report full values if detail level is turned up and the short # values were abbreviations if DETAIL_LEVEL >= 1: if short_result != full_result: message += f"\n Full result:\n{indent(full_result, 4)}" if short_expected != full_expected: message += ( f"\n Full expected value:\n{indent(full_expected, 4)}" ) # Report info about the test expression base, extra = expr_details(context) if same and DETAIL_LEVEL >= 1 or not same and DETAIL_LEVEL >= 0: message += '\n' + indent(base, 2) if DETAIL_LEVEL >= 1 and extra: message += '\n' + indent(extra, 2) # Register a check failure if the expectation was not met if not same: CHECK_FAILED = True # Print our message and return our result print_message(message, color=msg_color(msg_cat)) return same def expectType(expr, typ): """ Works like `expect`, but establishes an expectation for the type of the result of the expression, not for the exact value. The same rules must be followed as for `expect` for this to work properly. If the type of the expression's result is an instance of the target type, the expectation counts as met. """ global CHECK_FAILED context = get_my_context(expectType) tag = tag_for(context) # Skip this expectation if necessary if SKIP_ON_FAILURE == 'all' and CHECK_FAILED: if DETAIL_LEVEL < 1: msg = f"~ {tag} (skipped)" else: msg = ( f"~ direct expectation at {tag} for skipped because a" f" prior check failed" ) print_message(msg, color=msg_color("skipped")) return None if type(expr) == typ: message = f"✓ {tag}" desc = "the expected type" msg_cat = "succeeded" same = True elif isinstance(expr, typ): message = f"✓ {tag}" desc = f"a kind of {typ}" msg_cat = "succeeded" same = True else: message = f"✗ {tag}" desc = f"NOT a kind of {typ}" msg_cat = "failed" same = False if not same: CHECK_FAILED = True # Report on the type if the detail level warrants it if same and DETAIL_LEVEL >= 1 or not same and DETAIL_LEVEL >= 0: message += f"\n The result type ({type(expr)}) was {desc}." # Report info about the test expression base, extra = expr_details(context) if same and DETAIL_LEVEL >= 1 or not same and DETAIL_LEVEL >= 0: message += '\n' + indent(base, 2) if DETAIL_LEVEL >= 1 and extra: message += '\n' + indent(extra, 2) # Print our message and return our result print_message(message, color=msg_color(msg_cat)) return same #------------------# # Message Handling # #------------------# def indent(msg, level=2): """ Indents every line of the given message (a string). """ indent = ' ' * level return indent + ('\n' + indent).join(msg.splitlines()) def ellipsis(string, maxlen=40): """ Returns the provided string as-is, or if it's longer than the given maximum length, returns the string, truncated, with '...' at the end, which will, including the ellipsis, be exactly the given maximum length. The maximum length must be 4 or more. """ if len(string) > maxlen: return string[:maxlen - 3] + "..." else: return string def dual_string_repr(string): """ Returns a pair containing full and truncated representations of the given string. The formatting of even the full representation depends on whether it's a multi-line string or not and how long it is. """ lines = string.split('\n') if len(repr(string)) < 80 and len(lines) == 1: full = repr(string) short = repr(string) else: full = '"""\\\n' + string.replace('\r', '\\r') + '"""' if len(string) < 240 and len(lines) <= 7: short = full elif len(lines) > 7: head = '\n'.join(lines[:7]) short = ( '"""\\\n' + ellipsis(head.replace('\r', '\\r'), 240) + '"""' ) else: short = ( '"""\\\n' + ellipsis(string.replace('\r', '\\r'), 240) + '"""' ) return (full, short) def limited_repr(string): """ Given a string that might include multiple lines and/or lots of characters (regardless of lines), returns version cut off by ellipses either after 5 or so lines, or after 240 characters. Returns the full string if it's both less than 240 characters and less than 5 lines. """ # Split by lines lines = string.split('\n') # Already short enough if len(string) < 240 and len(lines) < 5: return string # Try up to 5 lines, cutting them off until we've got a # short-enough head string for n in range(min(5, len(lines)), 0, -1): head = '\n'.join(lines[:n]) if n < len(lines): head += '\n...' if len(head) < 240: break else: # If we didn't break, just use first 240 characters # of the string head = string[:240] + '...' # If we cut things too short (e.g., because of initial # empty lines) use first 240 characters of the string if len(head) < 12: head = string[:240] + '...' return head def msg_color(category): """ Returns an ANSI color code for the given category of message (one of "succeeded", "failed", "skipped", or "reset"), or returns None if COLORS is disabled or an invalid category is provided. """ if not COLORS: return None else: return MSG_COLORS.get(category) def print_message(msg, color=None): """ Prints a test result message to sys.stderr, but also flushes stdout and stderr both beforehand and afterwards to improve message ordering. If a color is given, it should be an ANSI terminal color code string (just the digits, for example '34' for blue or '1;31' for bright red). """ sys.stdout.flush() sys.stderr.flush() # Make the whole message blue if color: print(f"\x1b[{color}m", end="", file=sys.stderr) suffix = "\x1b[0m" else: suffix = "" print(msg + suffix, file=sys.stderr) sys.stdout.flush() sys.stderr.flush() def expr_details(context): """ Returns a pair of strings containing base and extra details for an expression as represented by a dictionary returned from `get_my_context`. The extra message may be an empty string if the base message contains all relevant information. """ # Expression that was evaluated expr = context.get("expr_src", "???") short_expr = ellipsis(expr, 78) # Results msg = "" extra_msg = "" # Base message msg += f"Test expression was:\n{indent(short_expr, 2)}" # Figure out values to display vdict = context.get("values", {}) if context.get("relevant") is not None: show = sorted( context["relevant"], key=lambda fragment: (expr.index(fragment), len(fragment)) ) else: show = sorted( vdict.keys(), key=lambda fragment: (expr.index(fragment), len(fragment)) ) if len(show) > 0: msg += "\nValues were:" longs = [] for key in show: if key in vdict: val = repr(vdict[key]) else: val = "???" entry = f" {key} = {val}" fits = ellipsis(entry) msg += '\n' + fits if fits != entry: longs.append(entry) # Extra message if short_expr != expr: if extra_msg != "" and not extra_msg.endswith('\n'): extra_msg += '\n' extra_msg += f"Full expression:\n{indent(expr, 2)}" extra_values = sorted( [ key for key in vdict.keys() if key not in context.get("relevant", []) ], key=lambda fragment: (expr.index(fragment), len(fragment)) ) if context.get("relevant") is not None and extra_values: if extra_msg != "" and not extra_msg.endswith('\n'): extra_msg += '\n' extra_msg += "Extra values:" for ev in extra_values: if ev in vdict: val = repr(vdict[ev]) else: val = "???" entry = f" {ev} = {val}" fits = ellipsis(entry, 78) extra_msg += '\n' + fits if fits != entry: longs.append(entry) if longs: if extra_msg != "" and not extra_msg.endswith('\n'): extra_msg += '\n' extra_msg += "Full values:" for entry in longs: extra_msg += '\n' + entry return msg, extra_msg #------------# # Comparison # #------------# def findFirstDifference(val, ref, comparing=None): """ Returns a string describing the first point of difference between `val` and `ref`, or None if the two values are equivalent. If IGNORE_TRAILING_WHITESPACE is True, trailing whitespace will be trimmed from each string before looking for differences. A small amount of difference is ignored between floating point numbers, including those found in complex structures. Works for recursive data structures; the `comparing` argument serves as a memo to avoid infinite recursion, and the `within` argument indicates where in a complex structure we are; both should normally be left as their defaults. """ if comparing is None: comparing = set() cmpkey = (id(val), id(ref)) if cmpkey in comparing: # Either they differ somewhere else, or they're functionally # identical # TODO: Does this really ward off all infinite recursion on # finite structures? return None comparing.add(cmpkey) try: simple = val == ref except RecursionError: simple = False if simple: return None else: # let's hunt for differences if ( isinstance(val, (int, float, complex)) and isinstance(ref, (int, float, complex)) ): # what if they're both numbers? if cmath.isclose( val, ref, rel_tol=FLOAT_REL_TOLERANCE, abs_tol=FLOAT_ABS_TOLERANCE ): return None else: if isinstance(val, complex) and isinstance(ref, complex): return f"complex numbers {val} and {ref} are different" elif isinstance(val, complex) or isinstance(ref, complex): return f"numbers {val} and {ref} are different" elif val > 0 and ref < 0: return f"numbers {val} and {ref} have different signs" else: return f"numbers {val} and {ref} are different" elif type(val) != type(ref): # different types; not both numbers svr = ellipsis(repr(val), 8) srr = ellipsis(repr(ref), 8) return ( f"values {svr} and {srr} have different types" f" ({type(val)} and {type(ref)})" ) elif isinstance(val, str): # both strings if '\n' in val or '\n' in ref: # multi-line strings; look for first different line # Note: we *don't* use splitlines here because it will # give multiple line breaks in a \r\r\n situation like # those caused by csv.DictWriter on windows when opening # a file without newlines=''. We'd like to instead ignore # '\r' as a line break (we're not going to work on early # Macs) and strip it if IGNORE_TRAILING_WHITESPACE is on. valLines = val.split('\n') refLines = ref.split('\n') # First line # where they differ (1-indexed) firstDiff = None # Compute point of first difference i = None for i in range(min(len(valLines), len(refLines))): valLine = valLines[i] refLine = refLines[i] if IGNORE_TRAILING_WHITESPACE: valLine = valLine.rstrip() refLine = refLine.rstrip() if valLine != refLine: firstDiff = i + 1 break else: if i is not None: # if one has more lines if len(valLines) != len(refLines): # In this case, one of the two is longer... # If IGNORE_TRAILING_WHITESPACE is on, and # the longer one just has a blank extra line # (possibly with some whitespace on it), then # the difference is just in the presence or # absence of a final newline, which we also # count as a "trailing whitespace" difference # and ignore. Note that multiple final '\n' # characters will be counted as a difference, # since they result in multiple final # lines... if ( IGNORE_TRAILING_WHITESPACE and ( ( len(valLines) == len(refLines) + 1 and valLines[i + 1].strip() == '' ) or ( len(valLines) + 1 == len(refLines) and refLines[i + 1].strip() == '' ) ) ): return None else: # If we're attending trailing whitespace, # or if there are multiple extra lines or # the single extra line is not blank, # then that's where our first difference # is. firstDiff = i + 2 else: # There is no difference once we trim # trailing whitespace... return None else: # Note: this is a line number, NOT a line index firstDiff = 1 got = "nothing (string had fewer lines than expected)" expected = "nothing (string had more lines than expected)" i = firstDiff - 1 if i < len(valLines): got = repr(valLines[i]) if i < len(refLines): expected = repr(refLines[i]) limit = 60 shortGot = ellipsis(got, limit) shortExpected = ellipsis(expected, limit) while ( shortGot == shortExpected and limit < len(got) or limit < len(expected) and limit < 200 ): limit += 10 shortGot = ellipsis(got, limit) shortExpected = ellipsis(expected, limit) return ( f"strings differ on line {firstDiff} where we got:" f"\n {shortGot}\nbut we expected:" f"\n {shortExpected}" ) else: # Single-line strings: find character pos of difference if IGNORE_TRAILING_WHITESPACE: val = val.rstrip() ref = ref.rstrip() if val == ref: return None # Find character position of first difference pos = None i = None for i in range(min(len(val), len(ref))): if val[i] != ref[i]: pos = i break else: if i is not None: pos = i + 1 else: pos = 0 # one string is empty vchar = None rchar = None if pos < len(val): vchar = val[pos] if pos < len(ref): rchar = ref[pos] if vchar is None: missing = ellipsis(repr(ref[pos:]), 20) return ( f"expected text missing from end of string:" f" {missing}" ) return f"strings {svr} and {srr} differ at position {pos}" elif rchar is None: extra = ellipsis(repr(val[pos:]), 20) return ( f"extra text at end of string:" f" {extra}" ) else: if pos > 6: got = ellipsis(repr(val[pos:]), 14) expected = ellipsis(repr(ref[pos:]), 14) return ( f"strings differ from position {pos}: got {got}" f" but expected {expected}" ) else: got = ellipsis(repr(val), 14) expected = ellipsis(repr(ref), 14) return ( f"strings are different: got {got}" f" but expected {expected}" ) elif isinstance(val, (list, tuple)): # both lists or tuples svr = ellipsis(repr(val), 10) srr = ellipsis(repr(ref), 10) typ = type(val).__name__ if len(val) != len(ref): return ( f"{typ}s {svr} and {srr} have different lengths" f" ({len(val)} and {len(ref)})" ) else: for i in range(len(val)): diff = findFirstDifference(val[i], ref[i], comparing) if diff is not None: return f"in slot {i} of {typ}, " + diff return None # no differences in any slot elif isinstance(val, (set)): # both sets svr = ellipsis(repr(val), 10) srr = ellipsis(repr(ref), 10) onlyVal = (val - ref) onlyRef = (ref - val) # Sort so we can match up different-but-equivalent # floating-point items... try: sonlyVal = sorted(onlyVal) sonlyRef = sorted(onlyRef) diff = findFirstDifference( sonlyVal, sonlyRef, comparing ) except TypeError: # not sortable, so not just floating-point diffs diff = "some" if diff is None: return None else: nMissing = len(onlyRef) nExtra = len(onlyVal) if nExtra == 0: firstMissing = ellipsis(repr(list(onlyRef)[0]), 12) result = f"in a set, missing element {firstMissing}" if nMissing > 1: result += f" ({nMissing} missing elements in total)" return result elif nMissing == 0: firstExtra = ellipsis(repr(list(onlyVal)[0]), 12) return f"in a set, extra element {firstExtra}" if nExtra > 1: result += f" ({nExtra} extra elements in total)" return result else: firstMissing = ellipsis(repr(list(onlyRef)[0]), 8) firstExtra = ellipsis(repr(list(onlyVal)[0]), 8) result = ( "in a set, elements are different (extra" f" element {firstExtra} and missing element" f" {firstMissing}" ) if nMissing > 1 and nExtra > 1: result += ( f" ({nExtra} total extra elements and" f" {nMissing} total missing elements" ) elif nMissing == 1: if nExtra > 1: result += ( f" (1 missing and {nExtra} total extra" f" elements)" ) else: # nExtra must be 1 if nMissing > 1: result += ( f" (1 extra and {nExtra} total missing" f" elements)" ) return result elif isinstance(val, dict): # both dicts svr = ellipsis(repr(val), 14) srr = ellipsis(repr(ref), 14) if len(val) != len(ref): if len(val) < len(ref): ldiff = len(ref) - len(val) firstMissing = ellipsis( repr(list(set(ref.keys()) - set(val.keys()))[0]), 30 ) return ( f"dictionary is missing key {firstMissing} (has" f" {ldiff} fewer key{'s' if ldiff > 1 else ''}" f" than expected)" ) else: ldiff = len(val) - len(ref) firstExtra = ellipsis( repr(list(set(val.keys()) - set(ref.keys()))[0]), 30 ) return ( f"dictionary has extra key {firstExtra} (has" f" {ldiff} more key{'s' if ldiff > 1 else ''}" f" than expected)" ) return ( f"dictionaries {svr} and {srr} have different sizes" f" ({len(val)} and {len(ref)})" ) vkeys = set(val.keys()) rkeys = set(ref.keys()) try: onlyVal = sorted(vkeys - rkeys) onlyRef = sorted(rkeys - vkeys) keyCorrespondence = {} except TypeError: # unsortable... keyCorrespondence = None # Check for floating-point equivalence of keys if sets are # sortable... if keyCorrespondence is not None: if findFirstDifference(onlyVal, onlyRef, comparing) is None: keyCorrespondence = { onlyVal[i]: onlyRef[i] for i in range(len(onlyVal)) } # Add pass-through mappings for matching keys for k in vkeys & rkeys: keyCorrespondence[k] = k else: # No actual mapping is available... keyCorrespondence = None # We couldn't find a correspondence between keys, so we # return a key-based difference if keyCorrespondence is None: onlyVal = vkeys - rkeys onlyRef = rkeys - vkeys nExtra = len(onlyVal) nMissing = len(onlyRef) if nExtra == 0: firstMissing = ellipsis(repr(list(onlyRef)[0]), 10) result = f"dictionary is missing key {firstMissing}" if nMissing > 1: result += f" ({nMissing} missing keys in total)" return result elif nMissing == 0: firstExtra = ellipsis(repr(list(onlyVal)[0]), 10) result = f"dictionary has extra key {firstExtra}" if nExtra > 1: result += f" ({nExtra} extra keys in total)" return result else: # neither is 0 firstMissing = ellipsis(repr(list(onlyRef)[0]), 10) firstExtra = ellipsis(repr(list(onlyVal)[0]), 10) result = ( f"dictionary is missing key {firstMissing} and" f" has extra key {firstExtra}" ) if nMissing > 1 and nExtra > 1: result += ( f" ({nMissing} missing and {nExtra} extra" f" keys in total)" ) elif nMissing == 1: if nExtra > 1: result += ( f" (1 missing and {nExtra} extra keys" f" in total)" ) else: # nExtra must be 1 if nMissing > 1: result += ( f" (1 extra and {nMissing} missing keys" f" in total)" ) return result # if we reach here, keyCorrespondence maps val keys to # equivalent (but not necessarily identical) ref keys for vk in keyCorrespondence: rk = keyCorrespondence[vk] vdiff = findFirstDifference(val[vk], ref[rk], comparing) if vdiff is not None: krep = ellipsis(repr(vk), 14) return f"in dictionary slot {krep}, " + vdiff return None else: # not sure what kind of thing this is... if val == ref: return None else: limit = 15 vr = repr(val) rr = repr(ref) svr = ellipsis(vr, limit) srr = ellipsis(rr, limit) while ( svr == srr and (limit < len(vr) or limit < len(rr)) and limit < 100 ): limit += 10 svr = ellipsis(vr, limit) srr = ellipsis(rr, limit) return f" objects {svr} and {srr} are different" def checkContainment(val1, val2): """ Returns True if val1 is 'contained in' to val2, and False otherwise. If IGNORE_TRAILING_WHITESPACE is True, will ignore trailing whitespace in two strings when comparing them for containment. """ if (not isinstance(val1, str)) or (not isinstance(val2, str)): return val1 in val2 # use regular containment test # For two strings, pay attention to IGNORE_TRAILING_WHITESPACE elif IGNORE_TRAILING_WHITESPACE: # remove trailing whitespace from both strings (on all lines) return trimWhitespace(val1) in trimWhitespace(val2) else: return val1 in val2 # use regular containment test def trimWhitespace(st, requireNewline=False): """ Assume st a string. Use .rstrip() to remove trailing whitespace from each line. This has the side effect of replacing complex newlines with just '\\n'. If requireNewline is set to true, only whitespace that comes before a newline will be trimmed, and whitespace which occurs at the end of the string on the last line will be retained if there is no final newline. """ if requireNewline: return re.sub('[ \t\r]*([\r\n])', r'\1', st) else: result = '\n'.join(line.rstrip() for line in st.split('\n')) return result def compare(val, ref): """ Comparison function returning a boolean which uses findFirstDifference under the hood. """ return findFirstDifference(val, ref) is None def test_compare(): "Tests the compare function." # TODO: test findFirstDifference instead & confirm correct # messages!!! # Integers assert compare(1, 1) assert compare(1, 2) is False assert compare(2, 1 + 1) # Floating point numbers assert compare(1.1, 1.1) assert compare(1.1, 1.2) is False assert compare(1.1, 1.1000000001) assert compare(1.1, 1.1001) is False # Complex numbers assert compare(1.1 + 2.3j, 1.1 + 2.3j) assert compare(1.1 + 2.3j, 1.1 + 2.4j) is False # Strings assert compare('abc', 1.1001) is False assert compare('abc', 'abc') assert compare('abc', 'def') is False # Lists assert compare([1, 2, 3], [1, 2, 3]) assert compare([1, 2, 3], [1, 2, 4]) is False assert compare([1, 2, 3], [1, 2, 3.0000000001]) # Tuples assert compare((1, 2, 3), (1, 2, 3)) assert compare((1, 2, 3), (1, 2, 4)) is False assert compare((1, 2, 3), (1, 2, 3.0000000001)) # Nested lists + tuples assert compare( ['a', 'b', 'cdefg', [1, 2, [3]], (4, 5)], ['a', 'b', 'cdefg', [1, 2, [3]], (4, 5)] ) assert compare( ['a', 'b', 'cdefg', [1, 2, [3]], (4, 5)], ['a', 'b', 'cdefg', [1, 2, [3]], (4, '5')] ) is False assert compare( ['a', 'b', 'cdefg', [1, 2, [3]], (4, 5)], ['a', 'b', 'cdefg', [1, 2, [3]], [4, 5]] ) is False # Sets assert compare({1, 2}, {1, 2}) assert compare({1, 2}, {1}) is False assert compare({1}, {1, 2}) is False assert compare({1, 2}, {'1', 2}) is False assert compare({'a', 'b', 'c'}, {'a', 'b', 'c'}) assert compare({'a', 'b', 'c'}, {'a', 'b', 'C'}) is False # Two tricky cases assert compare({1, 2}, {1.00000001, 2}) assert compare({(1, 2), 3}, {(1.00000001, 2), 3}) # Dictionaries assert compare({1: 2, 3: 4}, {1: 2, 3: 4}) assert compare({1: 2, 3: 4}, {1: 2, 3.00000000001: 4}) assert compare({1: 2, 3: 4}, {1: 2, 3: 4.00000000001}) assert compare({1: 2, 3: 4}, {1: 2, 3.1: 4}) is False assert compare({1: 2, 3: 4}, {1: 2, 3: 4.1}) is False # Nested dictionaries & lists assert compare( {1: {1.1: 2.2}, 2: [2.2, 3.3]}, {1: {1.1: 2.2}, 2: [2.2, 3.3]} ) assert compare( {1: {1.1: 2.2}, 2: [2.2, 3.3]}, {1: {1.2: 2.2}, 2: [2.2, 3.3]} ) is False assert compare( {1: {1.1: 2.2}, 2: [2.2, 3.3]}, {1: {1.1: 2.3}, 2: [2.2, 3.3]} ) is False assert compare( {1: {1.1: 2.2}, 2: [2.2, 3.3]}, {1: {1.1: 2.2}, 2: [2.2, 3.4]} ) is False assert compare( {1: {1.1: 2.2}, 2: [2.2, 3.3]}, {1: {1.1: 2.2}, 2: 2.2} ) is False assert compare( {1: {1.1: 2.2}, 2: [2.2, 3.3]}, {1: {1.1: 2.2}, 2: (2.2, 3.3)} ) is False # Equivalent infinitely recursive list structures a = [1, 2, 3] a.append(a) a2 = [1, 2, 3] a2.append(a2) b = [1, 2, 3, [1, 2, 3]] b[3].append(b) c = [1, 2, 3, [1, 2, 3, [1, 2, 3]]] c[3][3].append(c[3]) d = [1, 2, 3] d.insert(2, d) assert compare(a, a) assert compare(a, a2) assert compare(a2, a) assert compare(a, b) assert compare(a, c) assert compare(a2, b) assert compare(b, a) assert compare(b, a2) assert compare(c, a) assert compare(b, c) assert compare(a, d) is False assert compare(b, d) is False assert compare(c, d) is False assert compare(d, a) is False assert compare(d, b) is False assert compare(d, c) is False # Equivalent infinitely recursive dicitonaries e = {1: 2} e[2] = e e2 = {1: 2} e2[2] = e2 f = {1: 2} f2 = {1: 2} f[2] = f2 f2[2] = f g = {1: 2, 2: {1: 2.0000000001}} g[2][2] = g h = {1: 2, 2: {1: 3}} h[2][2] = h assert compare(e, e2) assert compare(e2, e) assert compare(f, f2) assert compare(f2, f) assert compare(e, f) assert compare(f, e) assert compare(e, g) assert compare(f, g) assert compare(g, e) assert compare(g, f) assert compare(e, h) is False assert compare(f, h) is False assert compare(g, h) is False assert compare(h, e) is False assert compare(h, f) is False assert compare(h, g) is False # Custom types + objects class T: pass assert compare(T, T) assert compare(T, 1) is False assert compare(T(), T()) is False # Custom type w/ a custom __eq__ class E: def __eq__(self, other): return isinstance(other, E) or other == 1 assert compare(E, E) assert compare(E, 1) is False assert compare(E(), 1) assert compare(E(), 2) is False assert compare(E(), E()) # Custom type w/ custom __hash__ and __eq__ class A: def __init__(self, val): self.val = val def __hash__(self): return 3 # hashes collide def __eq__(self, other): return isinstance(other, A) and self.val == other.val assert compare({A(1), A(2)}, {A(1), A(2)}) assert compare({A(1), A(2)}, {A(1), A(3)}) is False #-----------------------# # Configuration control # #-----------------------# def detailLevel(level): """ Sets the level of detail for printed messages. The detail levels are: * -1: Super-minimal output, with no details beyond success/failure. * 0: Succinct messages indicating success/failure, with minimal details when failure occurs. * 1: More verbose success/failure messages, with details about successes and more details about failures. """ global DETAIL_LEVEL DETAIL_LEVEL = level def attendTrailingWhitespace(on=True): """ Call this function to force `optimism` to pay attention to whitespace at the end of lines when checking expectations. By default, such whitespace is removed both from expected values/output fragments and from captured outputs/results before checking expectations. To turn that functionality on again, you can call this function with False as the argument. """ global IGNORE_TRAILING_WHITESPACE IGNORE_TRAILING_WHITESPACE = not on def skipChecksAfterFail(mode="all"): """ The argument should be either 'case' (the default), 'manager', or None. In 'manager' mode, when one check fails, any other checks of cases derived from that manager, including the case where the check failed, will be skipped. In 'case' mode, once a check fails any further checks of the same case will be skipped, but checks of other cases derived from the same manager will not be. In None mode (or if any other value is provided) no checks will be skipped because of failed checks (but they might be skipped for other reasons). """ global SKIP_ON_FAILURE SKIP_ON_FAILURE = mode def clearFailure(): """ Resets the failure status so that checks will resume when `SKIP_ON_FAILURE` is set to `'all'`. """ global CHECK_FAILED CHECK_FAILED = False #---------------------------------# # Summarization and Case Tracking # #---------------------------------# def showSummary(suiteName=None): """ Shows a summary of the number of checks in the current test suite (see `currentTestSuite`) that have been met or not. You can also give an argument to specify the name of the test suite to summarize. Prints output to sys.stderr. """ if suiteName is None: suiteName = currentTestSuite() if suiteName not in ALL_CASES: raise ValueError(f"Test suite '{suiteName}' does not exist.") # Flush stdout & stderr to improve ordering sys.stdout.flush() sys.stderr.flush() met = [] unmet = [] for case in ALL_CASES[suiteName]: for passed, tag, msg in case.outcomes: if passed: met.append(tag) else: unmet.append(tag) print('---', file=sys.stderr) if len(unmet) == 0: if len(met) == 0: print("No expectations were established.", file=sys.stderr) else: print( f"All {len(met)} expectation(s) were met.", file=sys.stderr ) else: if len(met) == 0: print( f"None of the {len(unmet)} expectation(s) were met!", file=sys.stderr ) else: print( ( f"{len(unmet)} of the {len(met) + len(unmet)}" f" expectation(s) were NOT met:" ), file=sys.stderr ) if COLORS: # bright red print("\x1b[1;31m", end="", file=sys.stderr) for tag in unmet: print(f" ✗ {tag}", file=sys.stderr) if COLORS: # reset print("\x1b[0m", end="", file=sys.stderr) print('---', file=sys.stderr) # Flush stdout & stderr to improve ordering sys.stdout.flush() sys.stderr.flush() def currentTestSuite(): """ Returns the name of the current test suite (a string). """ return _CURRENT_SUITE_NAME def startTestSuite(name): """ Starts a new test suite with the given name, or resumes an old one. Any cases created subsequently will be registered to that suite. """ global _CURRENT_SUITE_NAME if not isinstance(name, str): raise TypeError( f"The test suite name must be a string (got: '{repr(name)}'" f" which is a {type(name)})." ) _CURRENT_SUITE_NAME = name def resetTestSuite(suiteName=None): """ Resets the cases recorded in the current test suite (or the named test suite if an argument is provided). """ if suiteName is None: suiteName = currentTestSuite() ALL_CASES[suiteName] = [] def deleteAllTestSuites(): """ Deletes all test suites, removing all recorded test cases, and setting the current test suite name back to "default". """ global ALL_CASES, _CURRENT_SUITE_NAME _CURRENT_SUITE_NAME = "default" ALL_CASES = {} def listCasesInSuite(suiteName=None): """ Returns a list of test cases (`TestCase` objects) in the current test suite (or the named suite if an argument is provided). """ if suiteName is None: suiteName = currentTestSuite() if suiteName not in ALL_CASES: raise ValueError(f"Test suite '{suiteName}' does not exist.") return ALL_CASES[suiteName][:] def listAllCases(): """ Returns a list of all registered test cases (`TestCase` objects) in any known test suite. Note that if `deleteAllTestSuites` has been called, this will not include any `TestCase` objects created before that point. """ result = [] for suiteName in ALL_CASES: result.extend(ALL_CASES[suiteName]) return result #---------------# # Color control # #---------------# def colors(enable=False): """ Enables or disables colors in printed output. If your output does not support ANSI color codes, the color output will show up as garbage and you can disable this. """ global COLORS COLORS = enable #---------# # Tracing # #---------# def trace(expr): """ Given an expression (actually, of course, just a value), returns the value it was given. But also prints a trace message indicating what the expression was, what value it had, and the line number of that line of code. The file name and overlength results are printed only when the `detailLevel` is set to 1 or higher. """ # Flush stdout & stderr to improve ordering sys.stdout.flush() sys.stderr.flush() ctx = get_my_context(trace) rep = repr(expr) short = ellipsis(repr(expr)) tag = "{line}".format(**ctx) if DETAIL_LEVEL >= 1: tag = "{file}:{line}".format(**ctx) print( f"{tag} {ctx['expr_src']} ⇒ {short}", file=sys.stderr ) if DETAIL_LEVEL >= 1 and short != rep: print(" Full result is:\n " + rep, file=sys.stderr) # Flush stdout & stderr to improve ordering sys.stdout.flush() sys.stderr.flush() return expr #------------------------------# # Reverse evaluation machinery # #------------------------------# def get_src_index(src, lineno, col_offset): """ Turns a line number and column offset into an absolute index into the given source string, assuming length-1 newlines. """ lines = src.splitlines() above = lines[:lineno - 1] return sum(len(line) for line in above) + len(above) + col_offset def test_gsr(): """Tests for get_src_index.""" s = 'a\nb\nc' assert get_src_index(s, 1, 0) == 0 assert get_src_index(s, 2, 0) == 2 assert get_src_index(s, 3, 0) == 4 assert s[get_src_index(s, 1, 0)] == 'a' assert s[get_src_index(s, 2, 0)] == 'b' assert s[get_src_index(s, 3, 0)] == 'c' def find_identifier_end(code, start_index): """ Given a code string and an index in that string which is the start of an identifier, returns the index of the end of that identifier. """ at = start_index + 1 while at < len(code): ch = code[at] if not ch.isalpha() and not ch.isdigit() and ch != '_': break at += 1 return at - 1 def test_find_identifier_end(): """Tests for find_identifier_end.""" assert find_identifier_end("abc.xyz", 0) == 2 assert find_identifier_end("abc.xyz", 1) == 2 assert find_identifier_end("abc.xyz", 2) == 2 assert find_identifier_end("abc.xyz", 4) == 6 assert find_identifier_end("abc.xyz", 5) == 6 assert find_identifier_end("abc.xyz", 6) == 6 assert find_identifier_end("abc_xyz123", 0) == 9 assert find_identifier_end("abc xyz123", 0) == 2 assert find_identifier_end("abc xyz123", 4) == 9 assert find_identifier_end("x", 0) == 0 assert find_identifier_end(" x", 2) == 2 assert find_identifier_end(" xyz1", 2) == 5 s = "def abc(def):\n print(xyz)\n" assert find_identifier_end(s, 0) == 2 assert find_identifier_end(s, 4) == 6 assert find_identifier_end(s, 8) == 10 assert find_identifier_end(s, 16) == 20 assert find_identifier_end(s, 22) == 24 def unquoted_enumerate(src, start_index): """ A generator that yields index, character pairs from the given code string, skipping quotation marks and the strings that they delimit, including triple-quotes and respecting backslash-escapes within strings. """ quote = None at = start_index while at < len(src): char = src[at] # skip escaped characters in quoted strings if quote and char == '\\': # (thank goodness I don't have to worry about r-strings) at += 2 continue # handle quoted strings elif char == '"' or char == "'": if quote == char: quote = None # single end quote at += 1 continue elif src[at:at + 3] in ('"""', "'''"): tq = src[at:at + 3] at += 3 # going to skip these no matter what if tq == quote or tq[0] == quote: # Ending triple-quote, or matching triple-quote at # end of single-quoted string = ending quote + # empty string quote = None continue else: if quote: # triple quote of other kind inside single or # triple quoted string continue else: quote = tq continue elif quote is None: # opening single quote quote = char at += 1 continue else: # single quote inside other quotes at += 1 continue # Non-quote characters in quoted strings elif quote: at += 1 continue else: yield (at, char) at += 1 continue def test_unquoted_enumerate(): """Tests for unquoted_enumerate.""" uqe = unquoted_enumerate assert list(uqe("abc'123'", 0)) == list(zip(range(3), "abc")) assert list(uqe("'abc'123", 0)) == list(zip(range(5, 8), "123")) assert list(uqe("'abc'123''", 0)) == list(zip(range(5, 8), "123")) assert list(uqe("'abc'123''", 1)) == [(1, 'a'), (2, 'b'), (3, 'c')] mls = "'''\na\nb\nc'''\ndef" assert list(uqe(mls, 0)) == list(zip(range(12, 16), "\ndef")) tqs = '"""\'\'\'ab\'\'\'\'""" cd' assert list(uqe(tqs, 0)) == [(15, ' '), (16, 'c'), (17, 'd')] rqs = "a'b'''c\"\"\"'''\"d\"''''\"\"\"e'''\"\"\"f\"\"\"'''" print(f"X: '{rqs[23]}'", file=sys.stderr) assert list(uqe(rqs, 0)) == [(0, 'a'), (6, 'c'), (23, 'e')] assert list(uqe(rqs, 6)) == [(6, 'c'), (23, 'e')] bss = "a'\\'b\\''c" assert list(uqe(bss, 0)) == [(0, 'a'), (8, 'c')] mqs = "'\"a'b\"" assert list(uqe(mqs, 0)) == [(4, 'b')] def find_nth_attribute_period(code, start_index, n): """ Given a string of Python code and a start index within that string, finds the nth period character (counting from first = zero) after that start point, but only considers periods which are used for attribute access, i.e., periods outside of quoted strings and which are not part of ellipses. Returns the index within the string of the period that it found. A period at the start index (if there is one) will be counted. Returns None if there are not enough periods in the code. If the start index is inside a quoted string, things will get weird, and the results will probably be wrong. """ for (at, char) in unquoted_enumerate(code, start_index): if char == '.': if code[at - 1:at] == '.' or code[at + 1:at + 2] == '.': # part of an ellipsis, so ignore it continue else: n -= 1 if n < 0: break # Did we hit the end of the string before counting below 0? if n < 0: return at else: return None def test_find_nth_attribute_period(): """Tests for find_nth_attribute_period.""" assert find_nth_attribute_period("a.b", 0, 0) == 1 assert find_nth_attribute_period("a.b", 0, 1) is None assert find_nth_attribute_period("a.b", 0, 100) is None assert find_nth_attribute_period("a.b.c", 0, 1) == 3 assert find_nth_attribute_period("a.b.cde.f", 0, 1) == 3 assert find_nth_attribute_period("a.b.cde.f", 0, 2) == 7 s = "a.b, c.d, 'e.f', g.h" assert find_nth_attribute_period(s, 0, 0) == 1 assert find_nth_attribute_period(s, 0, 1) == 6 assert find_nth_attribute_period(s, 0, 2) == 18 assert find_nth_attribute_period(s, 0, 3) is None assert find_nth_attribute_period(s, 0, 3) is None assert find_nth_attribute_period(s, 1, 0) == 1 assert find_nth_attribute_period(s, 2, 0) == 6 assert find_nth_attribute_period(s, 6, 0) == 6 assert find_nth_attribute_period(s, 7, 0) == 18 assert find_nth_attribute_period(s, 15, 0) == 18 def find_closing_item(code, start_index, openclose='()'): """ Given a string of Python code, a starting index where there's an open paren, bracket, etc., and a 2-character string containing the opening and closing delimiters of interest (parentheses by default), returns the index of the matching closing delimiter, or None if the opening delimiter is unclosed. Note that the given code must not contain syntax errors, or the behavior will be undefined. Does NOT work with quotation marks (single or double). """ level = 1 open_delim = openclose[0] close_delim = openclose[1] for at, char in unquoted_enumerate(code, start_index + 1): # Non-quoted open delimiters if char == open_delim: level += 1 # Non-quoted close delimiters elif char == close_delim: level -= 1 if level < 1: break # Everything else: ignore it if level == 0: return at else: return None def test_find_closing_item(): """Tests for find_closing_item.""" assert find_closing_item('()', 0, '()') == 1 assert find_closing_item('()', 0) == 1 assert find_closing_item('(())', 0, '()') == 3 assert find_closing_item('(())', 1, '()') == 2 assert find_closing_item('((word))', 0, '()') == 7 assert find_closing_item('((word))', 1, '()') == 6 assert find_closing_item('(("(("))', 0, '()') == 7 assert find_closing_item('(("(("))', 1, '()') == 6 assert find_closing_item('(("))"))', 0, '()') == 7 assert find_closing_item('(("))"))', 1, '()') == 6 assert find_closing_item('(()())', 0, '()') == 5 assert find_closing_item('(()())', 1, '()') == 2 assert find_closing_item('(()())', 3, '()') == 4 assert find_closing_item('(""")(\n""")', 0, '()') == 10 assert find_closing_item("\"abc(\" + ('''def''')", 9, '()') == 19 assert find_closing_item("\"abc(\" + ('''def''')", 0, '()') is None assert find_closing_item("\"abc(\" + ('''def''')", 4, '()') is None assert find_closing_item("(()", 0, '()') is None assert find_closing_item("(()", 1, '()') == 2 assert find_closing_item("()(", 0, '()') == 1 assert find_closing_item("()(", 2, '()') is None assert find_closing_item("[]", 0, '[]') == 1 assert find_closing_item("[]", 0) is None assert find_closing_item("{}", 0, '{}') == 1 assert find_closing_item("aabb", 0, 'ab') == 3 def find_unbracketed_comma(code, start_index): """ Given a string of Python code and a starting index, finds the next comma at or after that index which isn't surrounded by brackets of any kind that start at or after that index and which isn't in a quoted string. Returns the index of the matching comma, or None if there is none. Stops and returns None if it finds an unmatched closing bracket. Note that the given code must not contain syntax errors, or the behavior will be undefined. """ seeking = [] delims = { '(': ')', '[': ']', '{': '}' } closing = delims.values() for at, char in unquoted_enumerate(code, start_index): # Non-quoted open delimiter if char in delims: seeking.append(delims[char]) # Non-quoted matching close delimiter elif len(seeking) > 0 and char == seeking[-1]: seeking.pop() # Non-quoted non-matching close delimiter elif char in closing: return None # A non-quoted comma elif char == ',' and len(seeking) == 0: return at # Everything else: ignore it # Got to the end return None def test_find_unbracketed_comma(): """Tests for find_unbracketed_comma.""" assert find_unbracketed_comma('()', 0) is None assert find_unbracketed_comma('(),', 0) == 2 assert find_unbracketed_comma('((,),)', 0) is None assert find_unbracketed_comma('((,),),', 0) == 6 assert find_unbracketed_comma('((,),),', 1) == 4 assert find_unbracketed_comma(',,,', 1) == 1 assert find_unbracketed_comma('",,",","', 0) == 4 assert find_unbracketed_comma('"""\n,,\n""","""\n,,\n"""', 0) == 10 assert find_unbracketed_comma('"""\n,,\n""","""\n,,\n"""', 4) == 4 assert find_unbracketed_comma('"""\n,,\n"""+"""\n,,\n"""', 0) is None assert find_unbracketed_comma('\n\n,\n', 0) == 2 def get_expr_src(src, call_node): """ Gets the string containing the source code for the expression passed as the first argument to a function call, given the string source of the file that defines the function and the AST node for the function call. """ # Find the child node for the first (and only) argument arg_expr = call_node.args[0] # If get_source_segment is available, use that if hasattr(ast, "get_source_segment"): return textwrap.dedent( ast.get_source_segment(src, arg_expr) ).strip() else: # We're going to have to do this ourself: find the start of the # expression and state-machine to find a matching paren start = get_src_index(src, call_node.lineno, call_node.col_offset) open_paren = src.index('(', start) end = find_closing_item(src, open_paren, '()') # Note: can't be None because that would have been a SyntaxError first_comma = find_unbracketed_comma(src, open_paren + 1) # Could be None if it's a 1-argument function if first_comma is not None: end = min(end, first_comma) return textwrap.dedent(src[open_paren + 1:end]).strip() def get_ref_src(src, node): """ Gets the string containing the source code for a variable reference, attribute, or subscript. """ # Use get_source_segment if it's available if hasattr(ast, "get_source_segment"): return ast.get_source_segment(src, node) else: # We're going to have to do this ourself: find the start of the # expression and state-machine to find its end start = get_src_index(src, node.lineno, node.col_offset) # Figure out the end point if isinstance(node, ast.Attribute): # Find sub-attributes so we can count syntactic periods to # figure out where the name part begins to get the span inner_period_count = 0 for node in ast.walk(node): if isinstance(node, ast.Attribute): inner_period_count += 1 inner_period_count -= 1 # for the node itself dot = find_nth_attribute_period(src, start, inner_period_count) end = find_identifier_end(src, dot + 1) elif isinstance(node, ast.Name): # It's just an identifier so we can find the end end = find_identifier_end(src, start) elif isinstance(node, ast.Subscript): # Find start of sub-expression so we can find opening brace # and then match it to find the end inner = node.slice if isinstance(inner, ast.Slice): pass elif hasattr(ast, "Index") and isinstance(inner, ast.Index): # 3.7 Index has a "value" inner = inner.value elif hasattr(ast, "ExtSlice") and isinstance(inner, ast.ExtSlice): # 3.7 ExtSlice has "dims" inner = inner.dims[0] else: raise TypeError( f"Unexpected subscript slice type {type(inner)} for" f" node:\n{ast.dump(node)}" ) sub_start = get_src_index(src, inner.lineno, inner.col_offset) end = find_closing_item(src, sub_start - 1, "[]") return src[start:end + 1] def deepish_copy(obj, memo=None): """ Returns the deepest possible copy of the given object, using copy.deepcopy wherever possible and making shallower copies elsewhere. Basically a middle-ground between copy.deepcopy and copy.copy. """ if memo is None: memo = {} if id(obj) in memo: return memo[id(obj)] try: result = copy.deepcopy(obj) # not sure about memo dict compatibility memo[id(obj)] = result return result except Exception: if isinstance(obj, list): result = [] memo[id(obj)] = result result.extend(deepish_copy(item, memo) for item in obj) return result elif isinstance(obj, tuple): # Note: no way to pre-populate the memo, but also no way to # construct an infinitely-recursive tuple without having # some mutable structure at some layer... result = (deepish_copy(item, memo) for item in obj) memo[id(obj)] = result return result elif isinstance(obj, dict): result = {} memo[id(obj)] = result result.update( { deepish_copy(key, memo): deepish_copy(value, memo) for key, value in obj.items() } ) return result elif isinstance(obj, set): result = set() memo[id(obj)] = result result |= set(deepish_copy(item, memo) for item in obj) return result else: # Can't go deeper I guess try: result = copy.copy(obj) memo[id(obj)] = result return result except Exception: # Can't even copy (e.g., a module) result = obj memo[id(obj)] = result return result def get_external_calling_frame(): """ Uses the inspect module to get a reference to the stack frame which called into the `optimism` module. Returns None if it can't find an appropriate call frame in the current stack. Remember to del the result after you're done with it, so that garbage doesn't pile up. """ myname = __name__ cf = inspect.currentframe() while ( hasattr(cf, "f_back") and cf.f_globals.get("__name__") == myname ): cf = cf.f_back return cf def get_module(stack_frame): """ Given a stack frame, returns a reference to the module where the code from that frame was defined. Returns None if it can't figure that out. """ other_name = stack_frame.f_globals.get("__name__", None) return sys.modules.get(other_name) def get_filename(stack_frame, speculate_filename=True): """ Given a stack frame, returns the filename of the file in which the code which created that stack frame was defined. Returns None if that information isn't available via a __file__ global, or if speculate_filename is True (the default), uses the value of the frame's f_code.co_filename, which may not always be a real file on disk, or which is weird circumstances could be the name of a file on disk which is *not* where the code came from. """ filename = stack_frame.f_globals.get("__file__") if filename is None and speculate_filename: filename = stack_frame.f_code.co_filename return filename def get_code_line(stack_frame): """ Given a stack frame, returns """ return stack_frame.f_lineno def evaluate_in_context(node, stack_frame): """ Given an AST node which is an expression, returns the value of that expression as evaluated in the context of the given stack frame. Shallow copies of the stack frame's locals and globals are made in an attempt to prevent the code being evaluated from having any impact on the stack frame's values, but of course there's still some possibility of side effects... """ expr = ast.Expression(node) code = compile( expr, stack_frame.f_globals.get("__file__", "__unknown__"), 'eval' ) return eval( code, copy.copy(stack_frame.f_globals), copy.copy(stack_frame.f_locals) ) def walk_ast_in_order(node): """ Yields all of the descendants of the given node (or list of nodes) in execution order. Note that this has its limits, for example, if we run it on the code: ```py x = [A for y in C if D] ``` It will yield the nodes for C, then y, then D, then A, and finally x, but in actual execution the nodes for D and A may be executed multiple times before x is assigned. """ if node is None: pass # empty iterator elif isinstance(node, (list, tuple)): for child in node: yield from walk_ast_in_order(child) else: # must be an ast.something # Note: the node itself will be yielded LAST if isinstance(node, (ast.Module, ast.Interactive, ast.Expression)): yield from walk_ast_in_order(node.body) elif ( hasattr(ast, "FunctionType") and isinstance(node, ast.FunctionType) ): yield from walk_ast_in_order(node.argtypes) yield from walk_ast_in_order(node.returns) elif isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): yield from walk_ast_in_order(node.args) yield from walk_ast_in_order(node.returns) yield from walk_ast_in_order(reversed(node.decorator_list)) yield from walk_ast_in_order(node.body) elif isinstance(node, ast.ClassDef): yield from walk_ast_in_order(node.bases) yield from walk_ast_in_order(node.keywords) yield from walk_ast_in_order(reversed(node.decorator_list)) yield from walk_ast_in_order(node.body) elif isinstance(node, ast.Return): yield from walk_ast_in_order(node.value) elif isinstance(node, ast.Delete): yield from walk_ast_in_order(node.targets) elif isinstance(node, ast.Assign): yield from walk_ast_in_order(node.value) yield from walk_ast_in_order(node.targets) elif isinstance(node, ast.AugAssign): yield from walk_ast_in_order(node.value) yield from walk_ast_in_order(node.target) elif isinstance(node, ast.AnnAssign): yield from walk_ast_in_order(node.value) yield from walk_ast_in_order(node.annotation) yield from walk_ast_in_order(node.target) elif isinstance(node, (ast.For, ast.AsyncFor)): yield from walk_ast_in_order(node.iter) yield from walk_ast_in_order(node.target) yield from walk_ast_in_order(node.body) yield from walk_ast_in_order(node.orelse) elif isinstance(node, (ast.While, ast.If, ast.IfExp)): yield from walk_ast_in_order(node.test) yield from walk_ast_in_order(node.body) yield from walk_ast_in_order(node.orelse) elif isinstance(node, (ast.With, ast.AsyncWith)): yield from walk_ast_in_order(node.items) yield from walk_ast_in_order(node.items) elif isinstance(node, ast.Raise): yield from walk_ast_in_order(node.cause) yield from walk_ast_in_order(node.exc) elif isinstance(node, ast.Try): yield from walk_ast_in_order(node.body) yield from walk_ast_in_order(node.handlers) yield from walk_ast_in_order(node.orelse) yield from walk_ast_in_order(node.finalbody) elif isinstance(node, ast.Assert): yield from walk_ast_in_order(node.test) yield from walk_ast_in_order(node.msg) elif isinstance(node, ast.Expr): yield from walk_ast_in_order(node.value) # Import, ImportFrom, Global, Nonlocal, Pass, Break, and # Continue each have no executable content, so we'll yield them # but not any children elif isinstance(node, ast.BoolOp): yield from walk_ast_in_order(node.values) elif hasattr(ast, "NamedExpr") and isinstance(node, ast.NamedExpr): yield from walk_ast_in_order(node.value) yield from walk_ast_in_order(node.target) elif isinstance(node, ast.BinOp): yield from walk_ast_in_order(node.left) yield from walk_ast_in_order(node.right) elif isinstance(node, ast.UnaryOp): yield from walk_ast_in_order(node.operand) elif isinstance(node, ast.Lambda): yield from walk_ast_in_order(node.args) yield from walk_ast_in_order(node.body) elif isinstance(node, ast.Dict): for i in range(len(node.keys)): yield from walk_ast_in_order(node.keys[i]) yield from walk_ast_in_order(node.values[i]) elif isinstance(node, (ast.Tuple, ast.List, ast.Set)): yield from walk_ast_in_order(node.elts) elif isinstance(node, (ast.ListComp, ast.SetComp, ast.GeneratorExp)): yield from walk_ast_in_order(node.generators) yield from walk_ast_in_order(node.elt) elif isinstance(node, ast.DictComp): yield from walk_ast_in_order(node.generators) yield from walk_ast_in_order(node.key) yield from walk_ast_in_order(node.value) elif isinstance(node, (ast.Await, ast.Yield, ast.YieldFrom)): yield from walk_ast_in_order(node.value) elif isinstance(node, ast.Compare): yield from walk_ast_in_order(node.left) yield from walk_ast_in_order(node.comparators) elif isinstance(node, ast.Call): yield from walk_ast_in_order(node.func) yield from walk_ast_in_order(node.args) yield from walk_ast_in_order(node.keywords) elif isinstance(node, ast.FormattedValue): yield from walk_ast_in_order(node.value) yield from walk_ast_in_order(node.format_spec) elif isinstance(node, ast.JoinedStr): yield from walk_ast_in_order(node.values) elif isinstance(node, (ast.Attribute, ast.Starred)): yield from walk_ast_in_order(node.value) elif isinstance(node, ast.Subscript): yield from walk_ast_in_order(node.value) yield from walk_ast_in_order(node.slice) elif isinstance(node, ast.Slice): yield from walk_ast_in_order(node.lower) yield from walk_ast_in_order(node.upper) yield from walk_ast_in_order(node.step) # Constant and Name nodes don't have executable contents elif isinstance(node, ast.comprehension): yield from walk_ast_in_order(node.iter) yield from walk_ast_in_order(node.ifs) yield from walk_ast_in_order(node.target) elif isinstance(node, ast.ExceptHandler): yield from walk_ast_in_order(node.type) yield from walk_ast_in_order(node.body) elif isinstance(node, ast.arguments): yield from walk_ast_in_order(node.defaults) yield from walk_ast_in_order(node.kw_defaults) if hasattr(node, "posonlyargs"): yield from walk_ast_in_order(node.posonlyargs) yield from walk_ast_in_order(node.args) yield from walk_ast_in_order(node.vararg) yield from walk_ast_in_order(node.kwonlyargs) yield from walk_ast_in_order(node.kwarg) elif isinstance(node, ast.arg): yield from walk_ast_in_order(node.annotation) elif isinstance(node, ast.keyword): yield from walk_ast_in_order(node.value) elif isinstance(node, ast.withitem): yield from walk_ast_in_order(node.context_expr) yield from walk_ast_in_order(node.optional_vars) # alias and typeignore have no executable members # Finally, yield this node itself yield node def find_call_nodes_on_line(node, frame, function, lineno): """ Given an AST node, a stack frame, a function object, and a line number, looks for all function calls which occur on the given line number and which are calls to the given function (as evaluated in the given stack frame). Note that calls to functions defined as part of the given AST cannot be found in this manner, because the objects being called are newly created and one could not possibly pass a reference to one of them into this function. For that reason, if the function argument is a string, any function call whose call part matches the given string will be matched. Normally only Name nodes can match this way, but if ast.unparse is available, the string will also attempt to match (exactly) against the unparsed call expression. Calls that start on the given line number will match, but if there are no such calls, then a call on a preceding line whose expression includes the target line will be looked for and may match. The return value will be a list of ast.Call nodes, and they will be ordered in the same order that those nodes would be executed when the line of code is executed. """ def call_matches(call_node): """ Locally-defined matching predicate. """ nonlocal function call_expr = call_node.func return ( ( isinstance(function, str) and ( ( isinstance(call_expr, ast.Name) and call_expr.id == function ) or ( isinstance(call_expr, ast.Attribute) and call_expr.attr == function ) or ( hasattr(ast, "unparse") and ast.unparse(call_expr) == function ) ) ) or ( not isinstance(function, str) and evaluate_in_context(call_expr, frame) is function ) ) result = [] all_on_line = [] for child in walk_ast_in_order(node): # only consider call nodes on the target line if ( hasattr(child, "lineno") and child.lineno == lineno ): all_on_line.append(child) if isinstance(child, ast.Call) and call_matches(child): result.append(child) # If we didn't find any candidates, look outwards from ast nodes on # the target line to find a Call that encompasses them... if len(result) == 0: for on_line in all_on_line: here = getattr(on_line, "parent", None) while ( here is not None and not isinstance( here, # Call (what we're looking for) plus most nodes that # indicate there couldn't be a call grandparent: ( ast.Call, ast.Module, ast.Interactive, ast.Expression, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Return, ast.Delete, ast.Assign, ast.AugAssign, ast.AnnAssign, ast.For, ast.AsyncFor, ast.While, ast.If, ast.With, ast.AsyncWith, ast.Raise, ast.Try, ast.Assert, ast.Assert, ast.Assert, ast.Assert, ast.Assert, ast.Assert, ast.Assert, ast.Assert, ) ) ): here = getattr(here, "parent", None) # If we found a Call that includes the target line as one # of its children... if isinstance(here, ast.Call) and call_matches(here): result.append(here) return result def assign_parents(root): """ Given an AST node, assigns "parent" attributes to each sub-node indicating their parent AST node. Assigns None as the value of the parent attribute of the root node. """ for node in ast.walk(root): for child in ast.iter_child_nodes(node): child.parent = node root.parent = None def is_inside_call_func(node): """ Given an AST node which has a parent attribute, traverses parents to see if this node is part of the func attribute of a Call node. """ if not hasattr(node, "parent") or node.parent is None: return False if isinstance(node.parent, ast.Call) and node.parent.func is node: return True else: return is_inside_call_func(node.parent) def tag_for(located): """ Given a dictionary which has 'file' and 'line' slots, returns a string to be used as the tag for a test with 'filename:line' as the format. Unless the `DETAIL_LEVEL` is 2 or higher, the filename will be shown without the full path. """ filename = located.get('file', '???') if DETAIL_LEVEL < 2: filename = os.path.basename(filename) line = located.get('line', '?') return f"{filename}:{line}" def get_my_location(speculate_filename=True): """ Fetches the filename and line number of the external module whose call into this module ended up invoking this function. Returns a dictionary with "file" and "line" keys. If speculate_filename is False, then the filename will be set to None in cases where a __file__ global cannot be found, instead of using f_code.co_filename as a backup. In some cases, this is useful because f_code.co_filename may not be a valid file. """ frame = get_external_calling_frame() try: filename = get_filename(frame, speculate_filename) lineno = get_code_line(frame) finally: del frame return { "file": filename, "line": lineno } def get_my_context(function_or_name): """ Returns a dictionary indicating the context of a function call, assuming that this function is called from within a function with the given name (or from within the given function), and that that function is being called from within a different module. The result has the following keys: - file: The filename of the calling module - line: The line number on which the call to the function occurred - src: The source code string of the calling module - expr: An AST node storing the expression passed as the first argument to the function - expr_src: The source code string of the expression passed as the first argument to the function - values: A dictionary mapping source code fragments to their values, for each variable reference in the test expression. These are deepish copies of the values encountered. - relevant: A list of source code fragments which appear in the values dictionary which are judged to be most-relevant to the result of the test. Currently, the relevant list just lists any fragments which aren't found in the func slot of Call nodes, under the assumption that we don't care as much about the values of the functions we're calling. Prints a warning and returns a dictionary with just "file" and "line" entries if the other context info is unavailable. """ if isinstance(function_or_name, types.FunctionType): function_name = function_or_name.__name__ else: function_name = function_or_name frame = get_external_calling_frame() try: filename = get_filename(frame) lineno = get_code_line(frame) if filename is None: src = None else: try: with open(filename, 'r') as fin: src = fin.read() except Exception: # We'll assume here that the source is something like an # interactive shell so we won't warn unless the detail # level is turned up. if DETAIL_LEVEL >= 2: print( "Warning: unable to get calling code's source.", file=sys.stderr ) print( ( "Call is on line {} of module {} from file" " '{}'" ).format( lineno, frame.f_globals.get("__name__"), filename ), file=sys.stderr ) src = None if src is None: return { "file": filename, "line": lineno } src_node = ast.parse(src, filename=filename, mode='exec') assign_parents(src_node) candidates = find_call_nodes_on_line( src_node, frame, function_or_name, lineno ) # What if there are zero candidates? if len(candidates) == 0: print( f"Warning: unable to find call node for {function_name}" f" on line {lineno} of file {filename}.", file=sys.stderr ) return { "file": filename, "line": lineno } # Figure out how many calls to get_my_context have happened # referencing this line before, so that we know which call on # this line we might be prev_this_line = COMPLETED_PER_LINE\ .setdefault(function_name, {})\ .setdefault((filename, lineno), 0) match = candidates[prev_this_line % len(candidates)] # Record this call so the next one will grab the subsequent # candidate COMPLETED_PER_LINE[function_name][(filename, lineno)] += 1 arg_expr = match.args[0] # Add .parent attributes assign_parents(arg_expr) # Source code for the expression expr_src = get_expr_src(src, match) # Prepare our result dictionary result = { "file": filename, "line": lineno, "src": src, "expr": arg_expr, "expr_src": expr_src, "values": {}, "relevant": set() } # Walk expression to find values for each variable for node in ast.walk(arg_expr): # If it's potentially a reference to a variable... if isinstance( node, (ast.Attribute, ast.Subscript, ast.Name) ): key = get_ref_src(src, node) if key not in result["values"]: # Don't re-evaluate multiply-reference expressions # Note: we assume they won't take on multiple # values; if they did, even our first evaluation # would probably be inaccurate. val = deepish_copy(evaluate_in_context(node, frame)) result["values"][key] = val if not is_inside_call_func(node): result["relevant"].add(key) return result finally: del frame