diff options
author | Benjamin Peterson <benjamin@python.org> | 2009-03-23 21:50:21 +0000 |
---|---|---|
committer | Benjamin Peterson <benjamin@python.org> | 2009-03-23 21:50:21 +0000 |
commit | 692428e77f467a2f2d4cebfff59fb0b5f9099547 (patch) | |
tree | f9a1c67844f2b92e31ccc5ca164364b4715f2980 /Lib/unittest.py | |
parent | 797eaf305a1e4cbaf2041e9b28b125398e2c235a (diff) | |
download | cpython-git-692428e77f467a2f2d4cebfff59fb0b5f9099547.tar.gz |
implement test skipping and expected failures
patch by myself #1034053
Diffstat (limited to 'Lib/unittest.py')
-rw-r--r-- | Lib/unittest.py | 204 |
1 files changed, 189 insertions, 15 deletions
diff --git a/Lib/unittest.py b/Lib/unittest.py index ccce746415..8263887a6a 100644 --- a/Lib/unittest.py +++ b/Lib/unittest.py @@ -53,6 +53,7 @@ import sys import traceback import os import types +import functools ############################################################################## # Exported classes and functions @@ -84,6 +85,79 @@ def _CmpToKey(mycmp): def _strclass(cls): return "%s.%s" % (cls.__module__, cls.__name__) + +class SkipTest(Exception): + """ + Raise this exception in a test to skip it. + + Usually you can use TestResult.skip() or one of the skipping decorators + instead of raising this directly. + """ + pass + +class _ExpectedFailure(Exception): + """ + Raise this when a test is expected to fail. + + This is an implementation detail. + """ + + def __init__(self, exc_info): + super(_ExpectedFailure, self).__init__() + self.exc_info = exc_info + +class _UnexpectedSuccess(Exception): + """ + The test was supposed to fail, but it didn't! + """ + pass + +def _id(obj): + return obj + +def skip(reason): + """ + Unconditionally skip a test. + """ + def decorator(test_item): + if isinstance(test_item, type) and issubclass(test_item, TestCase): + test_item.__unittest_skip__ = True + test_item.__unittest_skip_why__ = reason + return test_item + @functools.wraps(test_item) + def skip_wrapper(*args, **kwargs): + raise SkipTest(reason) + return skip_wrapper + return decorator + +def skipIf(condition, reason): + """ + Skip a test if the condition is true. + """ + if condition: + return skip(reason) + return _id + +def skipUnless(condition, reason): + """ + Skip a test unless the condition is true. + """ + if not condition: + return skip(reason) + return _id + + +def expectedFailure(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + func(*args, **kwargs) + except Exception: + raise _ExpectedFailure(sys.exc_info()) + raise _UnexpectedSuccess + return wrapper + + __unittest = 1 class TestResult(object): @@ -101,6 +175,9 @@ class TestResult(object): self.failures = [] self.errors = [] self.testsRun = 0 + self.skipped = [] + self.expected_failures = [] + self.unexpected_successes = [] self.shouldStop = False def startTest(self, test): @@ -126,6 +203,19 @@ class TestResult(object): "Called when a test has completed successfully" pass + def addSkip(self, test, reason): + """Called when a test is skipped.""" + self.skipped.append((test, reason)) + + def addExpectedFailure(self, test, err): + """Called when an expected failure/error occured.""" + self.expected_failures.append( + (test, self._exc_info_to_string(err, test))) + + def addUnexpectedSuccess(self, test): + """Called when a test was expected to fail, but succeed.""" + self.unexpected_successes.append(test) + def wasSuccessful(self): "Tells whether or not this result was a success" return len(self.failures) == len(self.errors) == 0 @@ -274,25 +364,36 @@ class TestCase(object): try: try: self.setUp() + except SkipTest as e: + result.addSkip(self, str(e)) + return except Exception: result.addError(self, self._exc_info()) return - ok = False + success = False try: testMethod() - ok = True except self.failureException: result.addFailure(self, self._exc_info()) + except _ExpectedFailure as e: + result.addExpectedFailure(self, e.exc_info) + except _UnexpectedSuccess: + result.addUnexpectedSuccess(self) + except SkipTest as e: + result.addSkip(self, str(e)) except Exception: result.addError(self, self._exc_info()) + else: + success = True try: self.tearDown() except Exception: result.addError(self, self._exc_info()) - ok = False - if ok: result.addSuccess(self) + success = False + if success: + result.addSuccess(self) finally: result.stopTest(self) @@ -312,6 +413,10 @@ class TestCase(object): """ return sys.exc_info() + def skip(self, reason): + """Skip this test.""" + raise SkipTest(reason) + def fail(self, msg=None): """Fail immediately, with the given message.""" raise self.failureException(msg) @@ -419,8 +524,8 @@ class TestSuite(object): __str__ = __repr__ def __eq__(self, other): - if type(self) is not type(other): - return False + if not isinstance(other, self.__class__): + return NotImplemented return self._tests == other._tests def __ne__(self, other): @@ -469,6 +574,37 @@ class TestSuite(object): for test in self._tests: test.debug() +class ClassTestSuite(TestSuite): + """ + Suite of tests derived from a single TestCase class. + """ + + def __init__(self, tests, class_collected_from): + super(ClassTestSuite, self).__init__(tests) + self.collected_from = class_collected_from + + def id(self): + module = getattr(self.collected_from, "__module__", None) + if module is not None: + return "{0}.{1}".format(module, self.collected_from.__name__) + return self.collected_from.__name__ + + def run(self, result): + if getattr(self.collected_from, "__unittest_skip__", False): + # ClassTestSuite result pretends to be a TestCase enough to be + # reported. + result.startTest(self) + try: + result.addSkip(self, self.collected_from.__unittest_skip_why__) + finally: + result.stopTest(self) + else: + result = super(ClassTestSuite, self).run(result) + return result + + shortDescription = id + + class FunctionTestCase(TestCase): """A test case that wraps a test function. @@ -540,6 +676,7 @@ class TestLoader(object): testMethodPrefix = 'test' sortTestMethodsUsing = cmp suiteClass = TestSuite + classSuiteClass = ClassTestSuite def loadTestsFromTestCase(self, testCaseClass): """Return a suite of all tests cases contained in testCaseClass""" @@ -548,7 +685,9 @@ class TestLoader(object): testCaseNames = self.getTestCaseNames(testCaseClass) if not testCaseNames and hasattr(testCaseClass, 'runTest'): testCaseNames = ['runTest'] - return self.suiteClass(map(testCaseClass, testCaseNames)) + suite = self.classSuiteClass(map(testCaseClass, testCaseNames), + testCaseClass) + return suite def loadTestsFromModule(self, module): """Return a suite of all tests cases contained in the given module""" @@ -719,6 +858,30 @@ class _TextTestResult(TestResult): self.stream.write('F') self.stream.flush() + def addSkip(self, test, reason): + TestResult.addSkip(self, test, reason) + if self.showAll: + self.stream.writeln("skipped {0!r}".format(reason)) + elif self.dots: + self.stream.write("s") + self.stream.flush() + + def addExpectedFailure(self, test, err): + TestResult.addExpectedFailure(self, test, err) + if self.showAll: + self.stream.writeln("expected failure") + elif self.dots: + self.stream.write(".") + self.stream.flush() + + def addUnexpectedSuccess(self, test): + TestResult.addUnexpectedSuccess(self, test) + if self.showAll: + self.stream.writeln("unexpected success") + elif self.dots: + self.stream.write(".") + self.stream.flush() + def printErrors(self): if self.dots or self.showAll: self.stream.writeln() @@ -760,17 +923,28 @@ class TextTestRunner(object): self.stream.writeln("Ran %d test%s in %.3fs" % (run, run != 1 and "s" or "", timeTaken)) self.stream.writeln() + results = map(len, (result.expected_failures, + result.unexpected_successes, + result.skipped)) + expected_fails, unexpected_successes, skipped = results + infos = [] if not result.wasSuccessful(): - self.stream.write("FAILED (") + self.stream.write("FAILED") failed, errored = map(len, (result.failures, result.errors)) if failed: - self.stream.write("failures=%d" % failed) + infos.append("failures=%d" % failed) if errored: - if failed: self.stream.write(", ") - self.stream.write("errors=%d" % errored) - self.stream.writeln(")") + infos.append("errors=%d" % errored) else: - self.stream.writeln("OK") + self.stream.write("OK") + if skipped: + infos.append("skipped=%d" % skipped) + if expected_fails: + infos.append("expected failures=%d" % expected_fails) + if unexpected_successes: + infos.append("unexpected successes=%d" % unexpected_successes) + if infos: + self.stream.writeln(" (%s)" % (", ".join(infos),)) return result @@ -824,9 +998,9 @@ Examples: def parseArgs(self, argv): import getopt + long_opts = ['help','verbose','quiet'] try: - options, args = getopt.getopt(argv[1:], 'hHvq', - ['help','verbose','quiet']) + options, args = getopt.getopt(argv[1:], 'hHvq', long_opts) for opt, value in options: if opt in ('-h','-H','--help'): self.usageExit() |