summaryrefslogtreecommitdiff
path: root/Lib/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r--Lib/unittest.py467
1 files changed, 446 insertions, 21 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py
index 10b68ed6fd..0f2d2357f9 100644
--- a/Lib/unittest.py
+++ b/Lib/unittest.py
@@ -14,11 +14,11 @@ Simple usage:
class IntegerArithmenticTestCase(unittest.TestCase):
def testAdd(self): ## test method names begin 'test*'
- self.assertEquals((1 + 2), 3)
- self.assertEquals(0 + 1, 1)
+ self.assertEqual((1 + 2), 3)
+ self.assertEqual(0 + 1, 1)
def testMultiply(self):
- self.assertEquals((0 * 10), 0)
- self.assertEquals((5 * 8), 40)
+ self.assertEqual((0 * 10), 0)
+ self.assertEqual((5 * 8), 40)
if __name__ == '__main__':
unittest.main()
@@ -45,12 +45,15 @@ AND THERE IS NO OBLIGATION WHATSOEVER TO PROVIDE MAINTENANCE,
SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.
'''
-import time
+import difflib
+import functools
+import os
+import pprint
+import re
import sys
+import time
import traceback
-import os
import types
-import functools
##############################################################################
# Exported classes and functions
@@ -251,11 +254,13 @@ class TestResult(object):
len(self.failures))
-class AssertRaisesContext(object):
+class _AssertRaisesContext(object):
+ """A context manager used to implement TestCase.assertRaises* methods."""
- def __init__(self, expected, test_case):
+ def __init__(self, expected, test_case, expected_regexp=None):
self.expected = expected
self.failureException = test_case.failureException
+ self.expected_regex = expected_regexp
def __enter__(self):
pass
@@ -268,10 +273,20 @@ class AssertRaisesContext(object):
exc_name = str(self.expected)
raise self.failureException(
"{0} not raised".format(exc_name))
- if issubclass(exc_type, self.expected):
+ if not issubclass(exc_type, self.expected):
+ # let unexpexted exceptions pass through
+ return False
+ if self.expected_regex is None:
return True
- # Let unexpected exceptions skip through
- return False
+
+ expected_regexp = self.expected_regex
+ if isinstance(expected_regexp, basestring):
+ expected_regexp = re.compile(expected_regexp)
+ if not expected_regexp.search(str(exc_value)):
+ raise self.failureException('"%s" does not match "%s"' %
+ (expected_regexp.pattern, str(exc_value)))
+ return True
+
class TestCase(object):
@@ -315,6 +330,31 @@ class TestCase(object):
(self.__class__, methodName))
self._testMethodDoc = testMethod.__doc__
+ # Map types to custom assertEqual functions that will compare
+ # instances of said type in more detail to generate a more useful
+ # error message.
+ self.__type_equality_funcs = {}
+ self.addTypeEqualityFunc(dict, self.assertDictEqual)
+ self.addTypeEqualityFunc(list, self.assertListEqual)
+ self.addTypeEqualityFunc(tuple, self.assertTupleEqual)
+ self.addTypeEqualityFunc(set, self.assertSetEqual)
+ self.addTypeEqualityFunc(frozenset, self.assertSetEqual)
+
+ def addTypeEqualityFunc(self, typeobj, function):
+ """Add a type specific assertEqual style function to compare a type.
+
+ This method is for use by TestCase subclasses that need to register
+ their own type equality functions to provide nicer error messages.
+
+ Args:
+ typeobj: The data type to call this function on when both values
+ are of the same type in assertEqual().
+ function: The callable taking two arguments and an optional
+ msg= argument that raises self.failureException with a
+ useful error message when the two arguments are not equal.
+ """
+ self.__type_equality_funcs[typeobj] = function
+
def setUp(self):
"Hook method for setting up the test fixture before exercising it."
pass
@@ -330,14 +370,22 @@ class TestCase(object):
return TestResult()
def shortDescription(self):
- """Returns a one-line description of the test, or None if no
- description has been provided.
+ """Returns both the test method name and first line of its docstring.
+
+ If no docstring is given, only returns the method name.
- The default implementation of this method returns the first line of
- the specified test method's docstring.
+ This method overrides unittest.TestCase.shortDescription(), which
+ only returns the first line of the docstring, obscuring the name
+ of the test upon failure.
"""
- doc = self._testMethodDoc
- return doc and doc.split("\n")[0].strip() or None
+ desc = str(self)
+ doc_first_line = None
+
+ if self._testMethodDoc:
+ doc_first_line = self._testMethodDoc.split("\n")[0].strip()
+ if doc_first_line:
+ desc = '\n'.join((desc, doc_first_line))
+ return desc
def id(self):
return "%s.%s" % (_strclass(self.__class__), self._testMethodName)
@@ -443,18 +491,45 @@ class TestCase(object):
with self.failUnlessRaises(some_error_class):
do_something()
"""
- context = AssertRaisesContext(excClass, self)
+ context = _AssertRaisesContext(excClass, self)
if callableObj is None:
return context
with context:
callableObj(*args, **kwargs)
+ def _getAssertEqualityFunc(self, first, second):
+ """Get a detailed comparison function for the types of the two args.
+
+ Returns: A callable accepting (first, second, msg=None) that will
+ raise a failure exception if first != second with a useful human
+ readable error message for those types.
+ """
+ #
+ # NOTE(gregory.p.smith): I considered isinstance(first, type(second))
+ # and vice versa. I opted for the conservative approach in case
+ # subclasses are not intended to be compared in detail to their super
+ # class instances using a type equality func. This means testing
+ # subtypes won't automagically use the detailed comparison. Callers
+ # should use their type specific assertSpamEqual method to compare
+ # subclasses if the detailed comparison is desired and appropriate.
+ # See the discussion in http://bugs.python.org/issue2578.
+ #
+ if type(first) is type(second):
+ return self.__type_equality_funcs.get(type(first),
+ self._baseAssertEqual)
+ return self._baseAssertEqual
+
+ def _baseAssertEqual(self, first, second, msg=None):
+ """The default assertEqual implementation, not type specific."""
+ if not first == second:
+ raise self.failureException(msg or '%r != %r' % (first, second))
+
def failUnlessEqual(self, first, second, msg=None):
"""Fail if the two objects are unequal as determined by the '=='
operator.
"""
- if not first == second:
- raise self.failureException(msg or '%r != %r' % (first, second))
+ assertion_func = self._getAssertEqualityFunc(first, second)
+ assertion_func(first, second, msg=msg)
def failIfEqual(self, first, second, msg=None):
"""Fail if the two objects are equal as determined by the '=='
@@ -504,6 +579,356 @@ class TestCase(object):
assertFalse = failIf
+ def assertSequenceEqual(self, seq1, seq2, msg=None, seq_type=None):
+ """An equality assertion for ordered sequences (like lists and tuples).
+
+ For the purposes of this function, a valid orderd sequence type is one
+ which can be indexed, has a length, and has an equality operator.
+
+ Args:
+ seq1: The first sequence to compare.
+ seq2: The second sequence to compare.
+ seq_type: The expected datatype of the sequences, or None if no
+ datatype should be enforced.
+ msg: Optional message to use on failure instead of a list of
+ differences.
+ """
+ if seq_type != None:
+ seq_type_name = seq_type.__name__
+ if not isinstance(seq1, seq_type):
+ raise self.failureException('First sequence is not a %s: %r'
+ % (seq_type_name, seq1))
+ if not isinstance(seq2, seq_type):
+ raise self.failureException('Second sequence is not a %s: %r'
+ % (seq_type_name, seq2))
+ else:
+ seq_type_name = "sequence"
+
+ differing = None
+ try:
+ len1 = len(seq1)
+ except (TypeError, NotImplementedError):
+ differing = 'First %s has no length. Non-sequence?' % (
+ seq_type_name)
+
+ if differing is None:
+ try:
+ len2 = len(seq2)
+ except (TypeError, NotImplementedError):
+ differing = 'Second %s has no length. Non-sequence?' % (
+ seq_type_name)
+
+ if differing is None:
+ if seq1 == seq2:
+ return
+
+ for i in xrange(min(len1, len2)):
+ try:
+ item1 = seq1[i]
+ except (TypeError, IndexError, NotImplementedError):
+ differing = ('Unable to index element %d of first %s\n' %
+ (i, seq_type_name))
+ break
+
+ try:
+ item2 = seq2[i]
+ except (TypeError, IndexError, NotImplementedError):
+ differing = ('Unable to index element %d of second %s\n' %
+ (i, seq_type_name))
+ break
+
+ if item1 != item2:
+ differing = ('First differing element %d:\n%s\n%s\n' %
+ (i, item1, item2))
+ break
+ else:
+ if (len1 == len2 and seq_type is None and
+ type(seq1) != type(seq2)):
+ # The sequences are the same, but have differing types.
+ return
+ # A catch-all message for handling arbitrary user-defined
+ # sequences.
+ differing = '%ss differ:\n' % seq_type_name.capitalize()
+ if len1 > len2:
+ differing = ('First %s contains %d additional '
+ 'elements.\n' % (seq_type_name, len1 - len2))
+ try:
+ differing += ('First extra element %d:\n%s\n' %
+ (len2, seq1[len2]))
+ except (TypeError, IndexError, NotImplementedError):
+ differing += ('Unable to index element %d '
+ 'of first %s\n' % (len2, seq_type_name))
+ elif len1 < len2:
+ differing = ('Second %s contains %d additional '
+ 'elements.\n' % (seq_type_name, len2 - len1))
+ try:
+ differing += ('First extra element %d:\n%s\n' %
+ (len1, seq2[len1]))
+ except (TypeError, IndexError, NotImplementedError):
+ differing += ('Unable to index element %d '
+ 'of second %s\n' % (len1, seq_type_name))
+ if not msg:
+ msg = '\n'.join(difflib.ndiff(pprint.pformat(seq1).splitlines(),
+ pprint.pformat(seq2).splitlines()))
+ self.fail(differing + msg)
+
+ def assertListEqual(self, list1, list2, msg=None):
+ """A list-specific equality assertion.
+
+ Args:
+ list1: The first list to compare.
+ list2: The second list to compare.
+ msg: Optional message to use on failure instead of a list of
+ differences.
+
+ """
+ self.assertSequenceEqual(list1, list2, msg, seq_type=list)
+
+ def assertTupleEqual(self, tuple1, tuple2, msg=None):
+ """A tuple-specific equality assertion.
+
+ Args:
+ tuple1: The first tuple to compare.
+ tuple2: The second tuple to compare.
+ msg: Optional message to use on failure instead of a list of
+ differences.
+ """
+ self.assertSequenceEqual(tuple1, tuple2, msg, seq_type=tuple)
+
+ def assertSetEqual(self, set1, set2, msg=None):
+ """A set-specific equality assertion.
+
+ Args:
+ set1: The first set to compare.
+ set2: The second set to compare.
+ msg: Optional message to use on failure instead of a list of
+ differences.
+
+ For more general containership equality, assertSameElements will work
+ with things other than sets. This uses ducktyping to support
+ different types of sets, and is optimized for sets specifically
+ (parameters must support a difference method).
+ """
+ try:
+ difference1 = set1.difference(set2)
+ except TypeError, e:
+ self.fail('invalid type when attempting set difference: %s' % e)
+ except AttributeError, e:
+ self.fail('first argument does not support set difference: %s' % e)
+
+ try:
+ difference2 = set2.difference(set1)
+ except TypeError, e:
+ self.fail('invalid type when attempting set difference: %s' % e)
+ except AttributeError, e:
+ self.fail('second argument does not support set difference: %s' % e)
+
+ if not (difference1 or difference2):
+ return
+
+ if msg is not None:
+ self.fail(msg)
+
+ lines = []
+ if difference1:
+ lines.append('Items in the first set but not the second:')
+ for item in difference1:
+ lines.append(repr(item))
+ if difference2:
+ lines.append('Items in the second set but not the first:')
+ for item in difference2:
+ lines.append(repr(item))
+ self.fail('\n'.join(lines))
+
+ def assertIn(self, a, b, msg=None):
+ """Just like self.assert_(a in b), but with a nicer default message."""
+ if msg is None:
+ msg = '"%s" not found in "%s"' % (a, b)
+ self.assert_(a in b, msg)
+
+ def assertNotIn(self, a, b, msg=None):
+ """Just like self.assert_(a not in b), but with a nicer default message."""
+ if msg is None:
+ msg = '"%s" unexpectedly found in "%s"' % (a, b)
+ self.assert_(a not in b, msg)
+
+ def assertDictEqual(self, d1, d2, msg=None):
+ self.assert_(isinstance(d1, dict), 'First argument is not a dictionary')
+ self.assert_(isinstance(d2, dict), 'Second argument is not a dictionary')
+
+ if d1 != d2:
+ self.fail(msg or ('\n' + '\n'.join(difflib.ndiff(
+ pprint.pformat(d1).splitlines(),
+ pprint.pformat(d2).splitlines()))))
+
+ def assertDictContainsSubset(self, expected, actual, msg=None):
+ """Checks whether actual is a superset of expected."""
+ missing = []
+ mismatched = []
+ for key, value in expected.iteritems():
+ if key not in actual:
+ missing.append(key)
+ elif value != actual[key]:
+ mismatched.append('%s, expected: %s, actual: %s' % (key, value,
+ actual[key]))
+
+ if not (missing or mismatched):
+ return
+
+ missing_msg = mismatched_msg = ''
+ if missing:
+ missing_msg = 'Missing: %s' % ','.join(missing)
+ if mismatched:
+ mismatched_msg = 'Mismatched values: %s' % ','.join(mismatched)
+
+ if msg:
+ msg = '%s: %s; %s' % (msg, missing_msg, mismatched_msg)
+ else:
+ msg = '%s; %s' % (missing_msg, mismatched_msg)
+ self.fail(msg)
+
+ def assertSameElements(self, expected_seq, actual_seq, msg=None):
+ """An unordered sequence specific comparison.
+
+ Raises with an error message listing which elements of expected_seq
+ are missing from actual_seq and vice versa if any.
+ """
+ try:
+ expected = set(expected_seq)
+ actual = set(actual_seq)
+ missing = list(expected.difference(actual))
+ unexpected = list(actual.difference(expected))
+ missing.sort()
+ unexpected.sort()
+ except TypeError:
+ # Fall back to slower list-compare if any of the objects are
+ # not hashable.
+ expected = list(expected_seq)
+ actual = list(actual_seq)
+ expected.sort()
+ actual.sort()
+ missing, unexpected = _SortedListDifference(expected, actual)
+ errors = []
+ if missing:
+ errors.append('Expected, but missing:\n %r\n' % missing)
+ if unexpected:
+ errors.append('Unexpected, but present:\n %r\n' % unexpected)
+ if errors:
+ self.fail(msg or ''.join(errors))
+
+ def assertMultiLineEqual(self, first, second, msg=None):
+ """Assert that two multi-line strings are equal."""
+ self.assert_(isinstance(first, types.StringTypes), (
+ 'First argument is not a string'))
+ self.assert_(isinstance(second, types.StringTypes), (
+ 'Second argument is not a string'))
+
+ if first != second:
+ raise self.failureException(
+ msg or '\n' + ''.join(difflib.ndiff(first.splitlines(True),
+ second.splitlines(True))))
+
+ def assertLess(self, a, b, msg=None):
+ """Just like self.assert_(a < b), but with a nicer default message."""
+ if msg is None:
+ msg = '"%r" unexpectedly not less than "%r"' % (a, b)
+ self.assert_(a < b, msg)
+
+ def assertLessEqual(self, a, b, msg=None):
+ """Just like self.assert_(a <= b), but with a nicer default message."""
+ if msg is None:
+ msg = '"%r" unexpectedly not less than or equal to "%r"' % (a, b)
+ self.assert_(a <= b, msg)
+
+ def assertGreater(self, a, b, msg=None):
+ """Just like self.assert_(a > b), but with a nicer default message."""
+ if msg is None:
+ msg = '"%r" unexpectedly not greater than "%r"' % (a, b)
+ self.assert_(a > b, msg)
+
+ def assertGreaterEqual(self, a, b, msg=None):
+ """Just like self.assert_(a >= b), but with a nicer default message."""
+ if msg is None:
+ msg = '"%r" unexpectedly not greater than or equal to "%r"' % (a, b)
+ self.assert_(a >= b, msg)
+
+ def assertIsNone(self, obj, msg=None):
+ """Same as self.assert_(obj is None), with a nicer default message."""
+ if msg is None:
+ msg = '"%s" unexpectedly not None' % obj
+ self.assert_(obj is None, msg)
+
+ def assertIsNotNone(self, obj, msg='unexpectedly None'):
+ """Included for symmetry with assertIsNone."""
+ self.assert_(obj is not None, msg)
+
+ def assertRaisesRegexp(self, expected_exception, expected_regexp,
+ callable_obj=None, *args, **kwargs):
+ """Asserts that the message in a raised exception matches a regexp.
+
+ Args:
+ expected_exception: Exception class expected to be raised.
+ expected_regexp: Regexp (re pattern object or string) expected
+ to be found in error message.
+ callable_obj: Function to be called.
+ args: Extra args.
+ kwargs: Extra kwargs.
+ """
+ context = _AssertRaisesContext(expected_exception, self, expected_regexp)
+ if callable_obj is None:
+ return context
+ with context:
+ callable_obj(*args, **kwargs)
+
+ def assertRegexpMatches(self, text, expected_regex, msg=None):
+ if isinstance(expected_regex, basestring):
+ expected_regex = re.compile(expected_regex)
+ if not expected_regex.search(text):
+ msg = msg or "Regexp didn't match"
+ msg = '%s: %r not found in %r' % (msg, expected_regex.pattern, text)
+ raise self.failureException(msg)
+
+
+def _SortedListDifference(expected, actual):
+ """Finds elements in only one or the other of two, sorted input lists.
+
+ Returns a two-element tuple of lists. The first list contains those
+ elements in the "expected" list but not in the "actual" list, and the
+ second contains those elements in the "actual" list but not in the
+ "expected" list. Duplicate elements in either input list are ignored.
+ """
+ i = j = 0
+ missing = []
+ unexpected = []
+ while True:
+ try:
+ e = expected[i]
+ a = actual[j]
+ if e < a:
+ missing.append(e)
+ i += 1
+ while expected[i] == e:
+ i += 1
+ elif e > a:
+ unexpected.append(a)
+ j += 1
+ while actual[j] == a:
+ j += 1
+ else:
+ i += 1
+ try:
+ while expected[i] == e:
+ i += 1
+ finally:
+ j += 1
+ while actual[j] == a:
+ j += 1
+ except IndexError:
+ missing.extend(expected[i:])
+ unexpected.extend(actual[j:])
+ break
+ return missing, unexpected
+
class TestSuite(object):
"""A test suite is a composite test consisting of a number of TestCases.