diff options
Diffstat (limited to 'Lib/unittest/mock.py')
-rw-r--r-- | Lib/unittest/mock.py | 39 |
1 files changed, 28 insertions, 11 deletions
diff --git a/Lib/unittest/mock.py b/Lib/unittest/mock.py index a98decc2c9..513cd5c924 100644 --- a/Lib/unittest/mock.py +++ b/Lib/unittest/mock.py @@ -852,7 +852,8 @@ class NonCallableMock(Base): else: name, args, kwargs = _call try: - return name, sig.bind(*args, **kwargs) + bound_call = sig.bind(*args, **kwargs) + return call(name, bound_call.args, bound_call.kwargs) except TypeError as e: return e.with_traceback(None) else: @@ -901,9 +902,9 @@ class NonCallableMock(Base): def _error_message(): msg = self._format_mock_failure_message(args, kwargs) return msg - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) actual = self._call_matcher(self.call_args) - if expected != actual: + if actual != expected: cause = expected if isinstance(expected, Exception) else None raise AssertionError(_error_message()) from cause @@ -963,10 +964,10 @@ class NonCallableMock(Base): The assert passes if the mock has *ever* been called, unlike `assert_called_with` and `assert_called_once_with` that only pass if the call is the most recent one.""" - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) + cause = expected if isinstance(expected, Exception) else None actual = [self._call_matcher(c) for c in self.call_args_list] - if expected not in actual: - cause = expected if isinstance(expected, Exception) else None + if cause or expected not in _AnyComparer(actual): expected_string = self._format_mock_call_signature(args, kwargs) raise AssertionError( '%s call not found' % expected_string @@ -1019,6 +1020,22 @@ class NonCallableMock(Base): return f"\n{prefix}: {safe_repr(self.mock_calls)}." +class _AnyComparer(list): + """A list which checks if it contains a call which may have an + argument of ANY, flipping the components of item and self from + their traditional locations so that ANY is guaranteed to be on + the left.""" + def __contains__(self, item): + for _call in self: + if len(item) != len(_call): + continue + if all([ + expected == actual + for expected, actual in zip(item, _call) + ]): + return True + return False + def _try_iter(obj): if obj is None: @@ -2171,9 +2188,9 @@ class AsyncMockMixin(Base): msg = self._format_mock_failure_message(args, kwargs, action='await') return msg - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) actual = self._call_matcher(self.await_args) - if expected != actual: + if actual != expected: cause = expected if isinstance(expected, Exception) else None raise AssertionError(_error_message()) from cause @@ -2192,10 +2209,10 @@ class AsyncMockMixin(Base): """ Assert the mock has ever been awaited with the specified arguments. """ - expected = self._call_matcher((args, kwargs)) + expected = self._call_matcher(_Call((args, kwargs), two=True)) + cause = expected if isinstance(expected, Exception) else None actual = [self._call_matcher(c) for c in self.await_args_list] - if expected not in actual: - cause = expected if isinstance(expected, Exception) else None + if cause or expected not in _AnyComparer(actual): expected_string = self._format_mock_call_signature(args, kwargs) raise AssertionError( '%s await not found' % expected_string |