diff options
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r-- | Lib/unittest.py | 467 |
1 files changed, 446 insertions, 21 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py index 10b68ed6fd..0f2d2357f9 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -14,11 +14,11 @@ Simple usage: class IntegerArithmenticTestCase(unittest.TestCase): def testAdd(self): ## test method names begin 'test*' - self.assertEquals((1 + 2), 3) - self.assertEquals(0 + 1, 1) + self.assertEqual((1 + 2), 3) + self.assertEqual(0 + 1, 1) def testMultiply(self): - self.assertEquals((0 * 10), 0) - self.assertEquals((5 * 8), 40) + self.assertEqual((0 * 10), 0) + self.assertEqual((5 * 8), 40) if __name__ == '__main__': unittest.main() @@ -45,12 +45,15 @@ AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS. ''' -import time +import difflib +import functools +import os +import pprint +import re import sys +import time import traceback -import os import types -import functools ############################################################################## # Exported classes and functions @@ -251,11 +254,13 @@ class TestResult(object): len(self.failures)) -class AssertRaisesContext(object): +class _AssertRaisesContext(object): + """A context manager used to implement TestCase.assertRaises* methods.""" - def __init__(self, expected, test_case): + def __init__(self, expected, test_case, expected_regexp=None): self.expected = expected self.failureException = test_case.failureException + self.expected_regex = expected_regexp def __enter__(self): pass @@ -268,10 +273,20 @@ class AssertRaisesContext(object): exc_name = str(self.expected) raise self.failureException( "{0} not raised".format(exc_name)) - if issubclass(exc_type, self.expected): + if not issubclass(exc_type, self.expected): + # let unexpexted exceptions pass through + return False + if self.expected_regex is None: return True - # Let unexpected exceptions skip through - return False + + expected_regexp = self.expected_regex + if isinstance(expected_regexp, basestring): + expected_regexp = re.compile(expected_regexp) + if not expected_regexp.search(str(exc_value)): + raise self.failureException('"%s" does not match "%s"' % + (expected_regexp.pattern, str(exc_value))) + return True + class TestCase(object): @@ -315,6 +330,31 @@ class TestCase(object): (self.__class__, methodName)) self._testMethodDoc = testMethod.__doc__ + # Map types to custom assertEqual functions that will compare + # instances of said type in more detail to generate a more useful + # error message. + self.__type_equality_funcs = {} + self.addTypeEqualityFunc(dict, self.assertDictEqual) + self.addTypeEqualityFunc(list, self.assertListEqual) + self.addTypeEqualityFunc(tuple, self.assertTupleEqual) + self.addTypeEqualityFunc(set, self.assertSetEqual) + self.addTypeEqualityFunc(frozenset, self.assertSetEqual) + + def addTypeEqualityFunc(self, typeobj, function): + """Add a type specific assertEqual style function to compare a type. + + This method is for use by TestCase subclasses that need to register + their own type equality functions to provide nicer error messages. + + Args: + typeobj: The data type to call this function on when both values + are of the same type in assertEqual(). + function: The callable taking two arguments and an optional + msg= argument that raises self.failureException with a + useful error message when the two arguments are not equal. + """ + self.__type_equality_funcs[typeobj] = function + def setUp(self): "Hook method for setting up the test fixture before exercising it." pass @@ -330,14 +370,22 @@ class TestCase(object): return TestResult() def shortDescription(self): - """Returns a one-line description of the test, or None if no - description has been provided. + """Returns both the test method name and first line of its docstring. + + If no docstring is given, only returns the method name. - The default implementation of this method returns the first line of - the specified test method's docstring. + This method overrides unittest.TestCase.shortDescription(), which + only returns the first line of the docstring, obscuring the name + of the test upon failure. """ - doc = self._testMethodDoc - return doc and doc.split("\n")[0].strip() or None + desc = str(self) + doc_first_line = None + + if self._testMethodDoc: + doc_first_line = self._testMethodDoc.split("\n")[0].strip() + if doc_first_line: + desc = '\n'.join((desc, doc_first_line)) + return desc def id(self): return "%s.%s" % (_strclass(self.__class__), self._testMethodName) @@ -443,18 +491,45 @@ class TestCase(object): with self.failUnlessRaises(some_error_class): do_something() """ - context = AssertRaisesContext(excClass, self) + context = _AssertRaisesContext(excClass, self) if callableObj is None: return context with context: callableObj(*args, **kwargs) + def _getAssertEqualityFunc(self, first, second): + """Get a detailed comparison function for the types of the two args. + + Returns: A callable accepting (first, second, msg=None) that will + raise a failure exception if first != second with a useful human + readable error message for those types. + """ + # + # NOTE(gregory.p.smith): I considered isinstance(first, type(second)) + # and vice versa. I opted for the conservative approach in case + # subclasses are not intended to be compared in detail to their super + # class instances using a type equality func. This means testing + # subtypes won't automagically use the detailed comparison. Callers + # should use their type specific assertSpamEqual method to compare + # subclasses if the detailed comparison is desired and appropriate. + # See the discussion in http://bugs.python.org/issue2578. + # + if type(first) is type(second): + return self.__type_equality_funcs.get(type(first), + self._baseAssertEqual) + return self._baseAssertEqual + + def _baseAssertEqual(self, first, second, msg=None): + """The default assertEqual implementation, not type specific.""" + if not first == second: + raise self.failureException(msg or '%r != %r' % (first, second)) + def failUnlessEqual(self, first, second, msg=None): """Fail if the two objects are unequal as determined by the '==' operator. """ - if not first == second: - raise self.failureException(msg or '%r != %r' % (first, second)) + assertion_func = self._getAssertEqualityFunc(first, second) + assertion_func(first, second, msg=msg) def failIfEqual(self, first, second, msg=None): """Fail if the two objects are equal as determined by the '==' @@ -504,6 +579,356 @@ class TestCase(object): assertFalse = failIf + def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None): + """An equality assertion for ordered sequences (like lists and tuples). + + For the purposes of this function, a valid orderd sequence type is one + which can be indexed, has a length, and has an equality operator. + + Args: + seq1: The first sequence to compare. + seq2: The second sequence to compare. + seq_type: The expected datatype of the sequences, or None if no + datatype should be enforced. + msg: Optional message to use on failure instead of a list of + differences. + """ + if seq_type != None: + seq_type_name = seq_type.__name__ + if not isinstance(seq1, seq_type): + raise self.failureException('First sequence is not a %s: %r' + % (seq_type_name, seq1)) + if not isinstance(seq2, seq_type): + raise self.failureException('Second sequence is not a %s: %r' + % (seq_type_name, seq2)) + else: + seq_type_name = "sequence" + + differing = None + try: + len1 = len(seq1) + except (TypeError, NotImplementedError): + differing = 'First %s has no length. Non-sequence?' % ( + seq_type_name) + + if differing is None: + try: + len2 = len(seq2) + except (TypeError, NotImplementedError): + differing = 'Second %s has no length. Non-sequence?' % ( + seq_type_name) + + if differing is None: + if seq1 == seq2: + return + + for i in xrange(min(len1, len2)): + try: + item1 = seq1[i] + except (TypeError, IndexError, NotImplementedError): + differing = ('Unable to index element %d of first %s\n' % + (i, seq_type_name)) + break + + try: + item2 = seq2[i] + except (TypeError, IndexError, NotImplementedError): + differing = ('Unable to index element %d of second %s\n' % + (i, seq_type_name)) + break + + if item1 != item2: + differing = ('First differing element %d:\n%s\n%s\n' % + (i, item1, item2)) + break + else: + if (len1 == len2 and seq_type is None and + type(seq1) != type(seq2)): + # The sequences are the same, but have differing types. + return + # A catch-all message for handling arbitrary user-defined + # sequences. + differing = '%ss differ:\n' % seq_type_name.capitalize() + if len1 > len2: + differing = ('First %s contains %d additional ' + 'elements.\n' % (seq_type_name, len1 - len2)) + try: + differing += ('First extra element %d:\n%s\n' % + (len2, seq1[len2])) + except (TypeError, IndexError, NotImplementedError): + differing += ('Unable to index element %d ' + 'of first %s\n' % (len2, seq_type_name)) + elif len1 < len2: + differing = ('Second %s contains %d additional ' + 'elements.\n' % (seq_type_name, len2 - len1)) + try: + differing += ('First extra element %d:\n%s\n' % + (len1, seq2[len1])) + except (TypeError, IndexError, NotImplementedError): + differing += ('Unable to index element %d ' + 'of second %s\n' % (len1, seq_type_name)) + if not msg: + msg = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(), + pprint.pformat(seq2).splitlines())) + self.fail(differing + msg) + + def assertListEqual(self, list1, list2, msg=None): + """A list-specific equality assertion. + + Args: + list1: The first list to compare. + list2: The second list to compare. + msg: Optional message to use on failure instead of a list of + differences. + + """ + self.assertSequenceEqual(list1, list2, msg, seq_type=list) + + def assertTupleEqual(self, tuple1, tuple2, msg=None): + """A tuple-specific equality assertion. + + Args: + tuple1: The first tuple to compare. + tuple2: The second tuple to compare. + msg: Optional message to use on failure instead of a list of + differences. + """ + self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple) + + def assertSetEqual(self, set1, set2, msg=None): + """A set-specific equality assertion. + + Args: + set1: The first set to compare. + set2: The second set to compare. + msg: Optional message to use on failure instead of a list of + differences. + + For more general containership equality, assertSameElements will work + with things other than sets. This uses ducktyping to support + different types of sets, and is optimized for sets specifically + (parameters must support a difference method). + """ + try: + difference1 = set1.difference(set2) + except TypeError, e: + self.fail('invalid type when attempting set difference: %s' % e) + except AttributeError, e: + self.fail('first argument does not support set difference: %s' % e) + + try: + difference2 = set2.difference(set1) + except TypeError, e: + self.fail('invalid type when attempting set difference: %s' % e) + except AttributeError, e: + self.fail('second argument does not support set difference: %s' % e) + + if not (difference1 or difference2): + return + + if msg is not None: + self.fail(msg) + + lines = [] + if difference1: + lines.append('Items in the first set but not the second:') + for item in difference1: + lines.append(repr(item)) + if difference2: + lines.append('Items in the second set but not the first:') + for item in difference2: + lines.append(repr(item)) + self.fail('\n'.join(lines)) + + def assertIn(self, a, b, msg=None): + """Just like self.assert_(a in b), but with a nicer default message.""" + if msg is None: + msg = '"%s" not found in "%s"' % (a, b) + self.assert_(a in b, msg) + + def assertNotIn(self, a, b, msg=None): + """Just like self.assert_(a not in b), but with a nicer default message.""" + if msg is None: + msg = '"%s" unexpectedly found in "%s"' % (a, b) + self.assert_(a not in b, msg) + + def assertDictEqual(self, d1, d2, msg=None): + self.assert_(isinstance(d1, dict), 'First argument is not a dictionary') + self.assert_(isinstance(d2, dict), 'Second argument is not a dictionary') + + if d1 != d2: + self.fail(msg or ('\n' + '\n'.join(difflib.ndiff( + pprint.pformat(d1).splitlines(), + pprint.pformat(d2).splitlines())))) + + def assertDictContainsSubset(self, expected, actual, msg=None): + """Checks whether actual is a superset of expected.""" + missing = [] + mismatched = [] + for key, value in expected.iteritems(): + if key not in actual: + missing.append(key) + elif value != actual[key]: + mismatched.append('%s, expected: %s, actual: %s' % (key, value, + actual[key])) + + if not (missing or mismatched): + return + + missing_msg = mismatched_msg = '' + if missing: + missing_msg = 'Missing: %s' % ','.join(missing) + if mismatched: + mismatched_msg = 'Mismatched values: %s' % ','.join(mismatched) + + if msg: + msg = '%s: %s; %s' % (msg, missing_msg, mismatched_msg) + else: + msg = '%s; %s' % (missing_msg, mismatched_msg) + self.fail(msg) + + def assertSameElements(self, expected_seq, actual_seq, msg=None): + """An unordered sequence specific comparison. + + Raises with an error message listing which elements of expected_seq + are missing from actual_seq and vice versa if any. + """ + try: + expected = set(expected_seq) + actual = set(actual_seq) + missing = list(expected.difference(actual)) + unexpected = list(actual.difference(expected)) + missing.sort() + unexpected.sort() + except TypeError: + # Fall back to slower list-compare if any of the objects are + # not hashable. + expected = list(expected_seq) + actual = list(actual_seq) + expected.sort() + actual.sort() + missing, unexpected = _SortedListDifference(expected, actual) + errors = [] + if missing: + errors.append('Expected, but missing:\n %r\n' % missing) + if unexpected: + errors.append('Unexpected, but present:\n %r\n' % unexpected) + if errors: + self.fail(msg or ''.join(errors)) + + def assertMultiLineEqual(self, first, second, msg=None): + """Assert that two multi-line strings are equal.""" + self.assert_(isinstance(first, types.StringTypes), ( + 'First argument is not a string')) + self.assert_(isinstance(second, types.StringTypes), ( + 'Second argument is not a string')) + + if first != second: + raise self.failureException( + msg or '\n' + ''.join(difflib.ndiff(first.splitlines(True), + second.splitlines(True)))) + + def assertLess(self, a, b, msg=None): + """Just like self.assert_(a < b), but with a nicer default message.""" + if msg is None: + msg = '"%r" unexpectedly not less than "%r"' % (a, b) + self.assert_(a < b, msg) + + def assertLessEqual(self, a, b, msg=None): + """Just like self.assert_(a <= b), but with a nicer default message.""" + if msg is None: + msg = '"%r" unexpectedly not less than or equal to "%r"' % (a, b) + self.assert_(a <= b, msg) + + def assertGreater(self, a, b, msg=None): + """Just like self.assert_(a > b), but with a nicer default message.""" + if msg is None: + msg = '"%r" unexpectedly not greater than "%r"' % (a, b) + self.assert_(a > b, msg) + + def assertGreaterEqual(self, a, b, msg=None): + """Just like self.assert_(a >= b), but with a nicer default message.""" + if msg is None: + msg = '"%r" unexpectedly not greater than or equal to "%r"' % (a, b) + self.assert_(a >= b, msg) + + def assertIsNone(self, obj, msg=None): + """Same as self.assert_(obj is None), with a nicer default message.""" + if msg is None: + msg = '"%s" unexpectedly not None' % obj + self.assert_(obj is None, msg) + + def assertIsNotNone(self, obj, msg='unexpectedly None'): + """Included for symmetry with assertIsNone.""" + self.assert_(obj is not None, msg) + + def assertRaisesRegexp(self, expected_exception, expected_regexp, + callable_obj=None, *args, **kwargs): + """Asserts that the message in a raised exception matches a regexp. + + Args: + expected_exception: Exception class expected to be raised. + expected_regexp: Regexp (re pattern object or string) expected + to be found in error message. + callable_obj: Function to be called. + args: Extra args. + kwargs: Extra kwargs. + """ + context = _AssertRaisesContext(expected_exception, self, expected_regexp) + if callable_obj is None: + return context + with context: + callable_obj(*args, **kwargs) + + def assertRegexpMatches(self, text, expected_regex, msg=None): + if isinstance(expected_regex, basestring): + expected_regex = re.compile(expected_regex) + if not expected_regex.search(text): + msg = msg or "Regexp didn't match" + msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text) + raise self.failureException(msg) + + +def _SortedListDifference(expected, actual): + """Finds elements in only one or the other of two, sorted input lists. + + Returns a two-element tuple of lists. The first list contains those + elements in the "expected" list but not in the "actual" list, and the + second contains those elements in the "actual" list but not in the + "expected" list. Duplicate elements in either input list are ignored. + """ + i = j = 0 + missing = [] + unexpected = [] + while True: + try: + e = expected[i] + a = actual[j] + if e < a: + missing.append(e) + i += 1 + while expected[i] == e: + i += 1 + elif e > a: + unexpected.append(a) + j += 1 + while actual[j] == a: + j += 1 + else: + i += 1 + try: + while expected[i] == e: + i += 1 + finally: + j += 1 + while actual[j] == a: + j += 1 + except IndexError: + missing.extend(expected[i:]) + unexpected.extend(actual[j:]) + break + return missing, unexpected + class TestSuite(object): """A test suite is a composite test consisting of a number of TestCases. |