diff options
Diffstat (limited to 'Lib/unittest/case.py')
-rw-r--r-- | Lib/unittest/case.py | 70 |
1 files changed, 2 insertions, 68 deletions
diff --git a/Lib/unittest/case.py b/Lib/unittest/case.py index d0ee561a3a..f8bc865ee8 100644 --- a/Lib/unittest/case.py +++ b/Lib/unittest/case.py @@ -3,7 +3,6 @@ import sys import functools import difflib -import logging import pprint import re import warnings @@ -297,73 +296,6 @@ class _AssertWarnsContext(_AssertRaisesBaseContext): -_LoggingWatcher = collections.namedtuple("_LoggingWatcher", - ["records", "output"]) - - -class _CapturingHandler(logging.Handler): - """ - A logging handler capturing all (raw and formatted) logging output. - """ - - def __init__(self): - logging.Handler.__init__(self) - self.watcher = _LoggingWatcher([], []) - - def flush(self): - pass - - def emit(self, record): - self.watcher.records.append(record) - msg = self.format(record) - self.watcher.output.append(msg) - - - -class _AssertLogsContext(_BaseTestCaseContext): - """A context manager used to implement TestCase.assertLogs().""" - - LOGGING_FORMAT = "%(levelname)s:%(name)s:%(message)s" - - def __init__(self, test_case, logger_name, level): - _BaseTestCaseContext.__init__(self, test_case) - self.logger_name = logger_name - if level: - self.level = logging._nameToLevel.get(level, level) - else: - self.level = logging.INFO - self.msg = None - - def __enter__(self): - if isinstance(self.logger_name, logging.Logger): - logger = self.logger = self.logger_name - else: - logger = self.logger = logging.getLogger(self.logger_name) - formatter = logging.Formatter(self.LOGGING_FORMAT) - handler = _CapturingHandler() - handler.setFormatter(formatter) - self.watcher = handler.watcher - self.old_handlers = logger.handlers[:] - self.old_level = logger.level - self.old_propagate = logger.propagate - logger.handlers = [handler] - logger.setLevel(self.level) - logger.propagate = False - return handler.watcher - - def __exit__(self, exc_type, exc_value, tb): - self.logger.handlers = self.old_handlers - self.logger.propagate = self.old_propagate - self.logger.setLevel(self.old_level) - if exc_type is not None: - # let unexpected exceptions pass through - return False - if len(self.watcher.records) == 0: - self._raiseFailure( - "no logs of level {} or higher triggered on {}" - .format(logging.getLevelName(self.level), self.logger.name)) - - class _OrderedChainMap(collections.ChainMap): def __iter__(self): seen = set() @@ -854,6 +786,8 @@ class TestCase(object): self.assertEqual(cm.output, ['INFO:foo:first message', 'ERROR:foo.bar:second message']) """ + # Lazy import to avoid importing logging if it is not needed. + from ._log import _AssertLogsContext return _AssertLogsContext(self, logger, level) def _getAssertEqualityFunc(self, first, second): |