diff options
author | Lisa Roach <lisaroach14@gmail.com> | 2019-05-20 09:19:53 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-05-20 09:19:53 -0700 |
commit | 77b3b7701a34ecf6316469e05b79bb91de2addfa (patch) | |
tree | 305506415c811e5a01e4ee783f3346f0359b17a2 /Lib/unittest/mock.py | |
parent | 0f72147ce2b3d65235b41eddc6a57be40237b5c7 (diff) | |
download | cpython-git-77b3b7701a34ecf6316469e05b79bb91de2addfa.tar.gz |
bpo-26467: Adds AsyncMock for asyncio Mock library support (GH-9296)
Diffstat (limited to 'Lib/unittest/mock.py')
-rw-r--r-- | Lib/unittest/mock.py | 406 |
1 files changed, 389 insertions, 17 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index 47ed06c6f4..166c100376 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -13,6 +13,7 @@ __all__ = ( 'ANY', 'call', 'create_autospec', + 'AsyncMock', 'FILTER_DIR', 'NonCallableMock', 'NonCallableMagicMock', @@ -24,13 +25,13 @@ __all__ = ( __version__ = '1.0' - +import asyncio import io import inspect import pprint import sys import builtins -from types import ModuleType, MethodType +from types import CodeType, ModuleType, MethodType from unittest.util import safe_repr from functools import wraps, partial @@ -43,6 +44,13 @@ FILTER_DIR = True # Without this, the __class__ properties wouldn't be set correctly _safe_super = super +def _is_async_obj(obj): + if getattr(obj, '__code__', None): + return asyncio.iscoroutinefunction(obj) or inspect.isawaitable(obj) + else: + return False + + def _is_instance_mock(obj): # can't use isinstance on Mock objects because they override __class__ # The base class for all mocks is NonCallableMock @@ -355,7 +363,20 @@ class NonCallableMock(Base): # every instance has its own class # so we can create magic methods on the # class without stomping on other mocks - new = type(cls.__name__, (cls,), {'__doc__': cls.__doc__}) + bases = (cls,) + if not issubclass(cls, AsyncMock): + # Check if spec is an async object or function + sig = inspect.signature(NonCallableMock.__init__) + bound_args = sig.bind_partial(cls, *args, **kw).arguments + spec_arg = [ + arg for arg in bound_args.keys() + if arg.startswith('spec') + ] + if spec_arg: + # what if spec_set is different than spec? + if _is_async_obj(bound_args[spec_arg[0]]): + bases = (AsyncMockMixin, cls,) + new = type(cls.__name__, bases, {'__doc__': cls.__doc__}) instance = object.__new__(new) return instance @@ -431,6 +452,11 @@ class NonCallableMock(Base): _eat_self=False): _spec_class = None _spec_signature = None + _spec_asyncs = [] + + for attr in dir(spec): + if asyncio.iscoroutinefunction(getattr(spec, attr, None)): + _spec_asyncs.append(attr) if spec is not None and not _is_list(spec): if isinstance(spec, type): @@ -448,7 +474,7 @@ class NonCallableMock(Base): __dict__['_spec_set'] = spec_set __dict__['_spec_signature'] = _spec_signature __dict__['_mock_methods'] = spec - + __dict__['_spec_asyncs'] = _spec_asyncs def __get_return_value(self): ret = self._mock_return_value @@ -886,7 +912,15 @@ class NonCallableMock(Base): For non-callable mocks the callable variant will be used (rather than any custom subclass).""" + _new_name = kw.get("_new_name") + if _new_name in self.__dict__['_spec_asyncs']: + return AsyncMock(**kw) + _type = type(self) + if issubclass(_type, MagicMock) and _new_name in _async_method_magics: + klass = AsyncMock + if issubclass(_type, AsyncMockMixin): + klass = MagicMock if not issubclass(_type, CallableMixin): if issubclass(_type, NonCallableMagicMock): klass = MagicMock @@ -932,14 +966,12 @@ def _try_iter(obj): return obj - class CallableMixin(Base): def __init__(self, spec=None, side_effect=None, return_value=DEFAULT, wraps=None, name=None, spec_set=None, parent=None, _spec_state=None, _new_name='', _new_parent=None, **kwargs): self.__dict__['_mock_return_value'] = return_value - _safe_super(CallableMixin, self).__init__( spec, wraps, name, spec_set, parent, _spec_state, _new_name, _new_parent, **kwargs @@ -1081,7 +1113,6 @@ class Mock(CallableMixin, NonCallableMock): """ - def _dot_lookup(thing, comp, import_path): try: return getattr(thing, comp) @@ -1279,8 +1310,10 @@ class _patch(object): if isinstance(original, type): # If we're patching out a class and there is a spec inherit = True - - Klass = MagicMock + if spec is None and _is_async_obj(original): + Klass = AsyncMock + else: + Klass = MagicMock _kwargs = {} if new_callable is not None: Klass = new_callable @@ -1292,7 +1325,9 @@ class _patch(object): not_callable = '__call__' not in this_spec else: not_callable = not callable(this_spec) - if not_callable: + if _is_async_obj(this_spec): + Klass = AsyncMock + elif not_callable: Klass = NonCallableMagicMock if spec is not None: @@ -1733,7 +1768,7 @@ _non_defaults = { '__reduce__', '__reduce_ex__', '__getinitargs__', '__getnewargs__', '__getstate__', '__setstate__', '__getformat__', '__setformat__', '__repr__', '__dir__', '__subclasses__', '__format__', - '__getnewargs_ex__', + '__getnewargs_ex__', '__aenter__', '__aexit__', '__anext__', '__aiter__', } @@ -1750,6 +1785,11 @@ _magics = { ' '.join([magic_methods, numerics, inplace, right]).split() } +# Magic methods used for async `with` statements +_async_method_magics = {"__aenter__", "__aexit__", "__anext__"} +# `__aiter__` is a plain function but used with async calls +_async_magics = _async_method_magics | {"__aiter__"} + _all_magics = _magics | _non_defaults _unsupported_magics = { @@ -1779,6 +1819,7 @@ _return_values = { '__float__': 1.0, '__bool__': True, '__index__': 1, + '__aexit__': False, } @@ -1811,10 +1852,19 @@ def _get_iter(self): return iter(ret_val) return __iter__ +def _get_async_iter(self): + def __aiter__(): + ret_val = self.__aiter__._mock_return_value + if ret_val is DEFAULT: + return _AsyncIterator(iter([])) + return _AsyncIterator(iter(ret_val)) + return __aiter__ + _side_effect_methods = { '__eq__': _get_eq, '__ne__': _get_ne, '__iter__': _get_iter, + '__aiter__': _get_async_iter } @@ -1879,8 +1929,33 @@ class NonCallableMagicMock(MagicMixin, NonCallableMock): self._mock_set_magics() +class AsyncMagicMixin: + def __init__(self, *args, **kw): + self._mock_set_async_magics() # make magic work for kwargs in init + _safe_super(AsyncMagicMixin, self).__init__(*args, **kw) + self._mock_set_async_magics() # fix magic broken by upper level init + + def _mock_set_async_magics(self): + these_magics = _async_magics -class MagicMock(MagicMixin, Mock): + if getattr(self, "_mock_methods", None) is not None: + these_magics = _async_magics.intersection(self._mock_methods) + remove_magics = _async_magics - these_magics + + for entry in remove_magics: + if entry in type(self).__dict__: + # remove unneeded magic methods + delattr(self, entry) + + # don't overwrite existing attributes if called a second time + these_magics = these_magics - set(type(self).__dict__) + + _type = type(self) + for entry in these_magics: + setattr(_type, entry, MagicProxy(entry, self)) + + +class MagicMock(MagicMixin, AsyncMagicMixin, Mock): """ MagicMock is a subclass of Mock with default implementations of most of the magic methods. You can use MagicMock without having to @@ -1920,6 +1995,218 @@ class MagicProxy(object): return self.create_mock() +class AsyncMockMixin(Base): + awaited = _delegating_property('awaited') + await_count = _delegating_property('await_count') + await_args = _delegating_property('await_args') + await_args_list = _delegating_property('await_args_list') + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # asyncio.iscoroutinefunction() checks _is_coroutine property to say if an + # object is a coroutine. Without this check it looks to see if it is a + # function/method, which in this case it is not (since it is an + # AsyncMock). + # It is set through __dict__ because when spec_set is True, this + # attribute is likely undefined. + self.__dict__['_is_coroutine'] = asyncio.coroutines._is_coroutine + self.__dict__['_mock_awaited'] = _AwaitEvent(self) + self.__dict__['_mock_await_count'] = 0 + self.__dict__['_mock_await_args'] = None + self.__dict__['_mock_await_args_list'] = _CallList() + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_COROUTINE + self.__dict__['__code__'] = code_mock + + async def _mock_call(_mock_self, *args, **kwargs): + self = _mock_self + try: + result = super()._mock_call(*args, **kwargs) + except (BaseException, StopIteration) as e: + side_effect = self.side_effect + if side_effect is not None and not callable(side_effect): + raise + return await _raise(e) + + _call = self.call_args + + async def proxy(): + try: + if inspect.isawaitable(result): + return await result + else: + return result + finally: + self.await_count += 1 + self.await_args = _call + self.await_args_list.append(_call) + await self.awaited._notify() + + return await proxy() + + def assert_awaited(_mock_self): + """ + Assert that the mock was awaited at least once. + """ + self = _mock_self + if self.await_count == 0: + msg = f"Expected {self._mock_name or 'mock'} to have been awaited." + raise AssertionError(msg) + + def assert_awaited_once(_mock_self): + """ + Assert that the mock was awaited exactly once. + """ + self = _mock_self + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def assert_awaited_with(_mock_self, *args, **kwargs): + """ + Assert that the last await was with the specified arguments. + """ + self = _mock_self + if self.await_args is None: + expected = self._format_mock_call_signature(args, kwargs) + raise AssertionError(f'Expected await: {expected}\nNot awaited') + + def _error_message(): + msg = self._format_mock_failure_message(args, kwargs) + return msg + + expected = self._call_matcher((args, kwargs)) + actual = self._call_matcher(self.await_args) + if expected != actual: + cause = expected if isinstance(expected, Exception) else None + raise AssertionError(_error_message()) from cause + + def assert_awaited_once_with(_mock_self, *args, **kwargs): + """ + Assert that the mock was awaited exactly once and with the specified + arguments. + """ + self = _mock_self + if not self.await_count == 1: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + return self.assert_awaited_with(*args, **kwargs) + + def assert_any_await(_mock_self, *args, **kwargs): + """ + Assert the mock has ever been awaited with the specified arguments. + """ + self = _mock_self + expected = self._call_matcher((args, kwargs)) + actual = [self._call_matcher(c) for c in self.await_args_list] + if expected not in actual: + cause = expected if isinstance(expected, Exception) else None + expected_string = self._format_mock_call_signature(args, kwargs) + raise AssertionError( + '%s await not found' % expected_string + ) from cause + + def assert_has_awaits(_mock_self, calls, any_order=False): + """ + Assert the mock has been awaited with the specified calls. + The :attr:`await_args_list` list is checked for the awaits. + + If `any_order` is False (the default) then the awaits must be + sequential. There can be extra calls before or after the + specified awaits. + + If `any_order` is True then the awaits can be in any order, but + they must all appear in :attr:`await_args_list`. + """ + self = _mock_self + expected = [self._call_matcher(c) for c in calls] + cause = expected if isinstance(expected, Exception) else None + all_awaits = _CallList(self._call_matcher(c) for c in self.await_args_list) + if not any_order: + if expected not in all_awaits: + raise AssertionError( + f'Awaits not found.\nExpected: {_CallList(calls)}\n', + f'Actual: {self.await_args_list}' + ) from cause + return + + all_awaits = list(all_awaits) + + not_found = [] + for kall in expected: + try: + all_awaits.remove(kall) + except ValueError: + not_found.append(kall) + if not_found: + raise AssertionError( + '%r not all found in await list' % (tuple(not_found),) + ) from cause + + def assert_not_awaited(_mock_self): + """ + Assert that the mock was never awaited. + """ + self = _mock_self + if self.await_count != 0: + msg = (f"Expected {self._mock_name or 'mock'} to have been awaited once." + f" Awaited {self.await_count} times.") + raise AssertionError(msg) + + def reset_mock(self, *args, **kwargs): + """ + See :func:`.Mock.reset_mock()` + """ + super().reset_mock(*args, **kwargs) + self.await_count = 0 + self.await_args = None + self.await_args_list = _CallList() + + +class AsyncMock(AsyncMockMixin, AsyncMagicMixin, Mock): + """ + Enhance :class:`Mock` with features allowing to mock + an async function. + + The :class:`AsyncMock` object will behave so the object is + recognized as an async function, and the result of a call is an awaitable: + + >>> mock = AsyncMock() + >>> asyncio.iscoroutinefunction(mock) + True + >>> inspect.isawaitable(mock()) + True + + + The result of ``mock()`` is an async function which will have the outcome + of ``side_effect`` or ``return_value``: + + - if ``side_effect`` is a function, the async function will return the + result of that function, + - if ``side_effect`` is an exception, the async function will raise the + exception, + - if ``side_effect`` is an iterable, the async function will return the + next value of the iterable, however, if the sequence of result is + exhausted, ``StopIteration`` is raised immediately, + - if ``side_effect`` is not defined, the async function will return the + value defined by ``return_value``, hence, by default, the async function + returns a new :class:`AsyncMock` object. + + If the outcome of ``side_effect`` or ``return_value`` is an async function, + the mock async function obtained when the mock object is called will be this + async function itself (and not an async function returning an async + function). + + The test author can also specify a wrapped object with ``wraps``. In this + case, the :class:`Mock` object behavior is the same as with an + :class:`.Mock` object: the wrapped object may have methods + defined as async function functions. + + Based on Martin Richard's asyntest project. + """ + class _ANY(object): "A helper object that compares equal to everything." @@ -2145,7 +2432,6 @@ class _Call(tuple): call = _Call(from_kall=False) - def create_autospec(spec, spec_set=False, instance=False, _parent=None, _name=None, **kwargs): """Create a mock object using another object as a spec. Attributes on the @@ -2171,7 +2457,10 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, spec = type(spec) is_type = isinstance(spec, type) - + if getattr(spec, '__code__', None): + is_async_func = asyncio.iscoroutinefunction(spec) + else: + is_async_func = False _kwargs = {'spec': spec} if spec_set: _kwargs = {'spec_set': spec} @@ -2188,6 +2477,11 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, # descriptors don't have a spec # because we don't know what type they return _kwargs = {} + elif is_async_func: + if instance: + raise RuntimeError("Instance can not be True when create_autospec " + "is mocking an async function") + Klass = AsyncMock elif not _callable(spec): Klass = NonCallableMagicMock elif is_type and instance and not _instance_callable(spec): @@ -2204,9 +2498,26 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, name=_name, **_kwargs) if isinstance(spec, FunctionTypes): + wrapped_mock = mock # should only happen at the top level because we don't # recurse for functions mock = _set_signature(mock, spec) + if is_async_func: + mock._is_coroutine = asyncio.coroutines._is_coroutine + mock.await_count = 0 + mock.await_args = None + mock.await_args_list = _CallList() + + for a in ('assert_awaited', + 'assert_awaited_once', + 'assert_awaited_with', + 'assert_awaited_once_with', + 'assert_any_await', + 'assert_has_awaits', + 'assert_not_awaited'): + def f(*args, **kwargs): + return getattr(wrapped_mock, a)(*args, **kwargs) + setattr(mock, a, f) else: _check_signature(spec, mock, is_type, instance) @@ -2250,9 +2561,13 @@ def create_autospec(spec, spec_set=False, instance=False, _parent=None, skipfirst = _must_skip(spec, entry, is_type) kwargs['_eat_self'] = skipfirst - new = MagicMock(parent=parent, name=entry, _new_name=entry, - _new_parent=parent, - **kwargs) + if asyncio.iscoroutinefunction(original): + child_klass = AsyncMock + else: + child_klass = MagicMock + new = child_klass(parent=parent, name=entry, _new_name=entry, + _new_parent=parent, + **kwargs) mock._mock_children[entry] = new _check_signature(original, new, skipfirst=skipfirst) @@ -2438,3 +2753,60 @@ def seal(mock): continue if m._mock_new_parent is mock: seal(m) + + +async def _raise(exception): + raise exception + + +class _AsyncIterator: + """ + Wraps an iterator in an asynchronous iterator. + """ + def __init__(self, iterator): + self.iterator = iterator + code_mock = NonCallableMock(spec_set=CodeType) + code_mock.co_flags = inspect.CO_ITERABLE_COROUTINE + self.__dict__['__code__'] = code_mock + + def __aiter__(self): + return self + + async def __anext__(self): + try: + return next(self.iterator) + except StopIteration: + pass + raise StopAsyncIteration + + +class _AwaitEvent: + def __init__(self, mock): + self._mock = mock + self._condition = None + + async def _notify(self): + condition = self._get_condition() + try: + await condition.acquire() + condition.notify_all() + finally: + condition.release() + + def _get_condition(self): + """ + Creation of condition is delayed, to minimize the chance of using the + wrong loop. + A user may create a mock with _AwaitEvent before selecting the + execution loop. Requiring a user to delay creation is error-prone and + inflexible. Instead, condition is created when user actually starts to + use the mock. + """ + # No synchronization is needed: + # - asyncio is thread unsafe + # - there are no awaits here, method will be executed without + # switching asyncio context. + if self._condition is None: + self._condition = asyncio.Condition() + + return self._condition |