diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 16 |
1 files changed, 10 insertions, 6 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index f9c48a156..85a8c32bb 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -5,7 +5,7 @@ TODO: rewrite this in C for performance. import collections import functools -from numpy.core.multiarray import ndarray +from numpy.core._multiarray_umath import ndarray from numpy.compat._inspect import getargspec @@ -71,8 +71,8 @@ def array_function_implementation_or_override( Function that implements the operation on NumPy array without overrides when called like ``implementation(*args, **kwargs)``. public_api : function - Function exposed by NumPy's public API riginally called like - ``public_api(*args, **kwargs`` on which arguments are now being + Function exposed by NumPy's public API originally called like + ``public_api(*args, **kwargs)`` on which arguments are now being checked. relevant_args : iterable Iterable of arguments to check for __array_function__ methods. @@ -105,9 +105,10 @@ def array_function_implementation_or_override( if result is not NotImplemented: return result - raise TypeError('no implementation found for {} on types that implement ' + func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) + raise TypeError("no implementation found for '{}' on types that implement " '__array_function__: {}' - .format(public_api, list(map(type, overloaded_args)))) + .format(func_name, list(map(type, overloaded_args)))) ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') @@ -135,7 +136,7 @@ def verify_matching_signatures(implementation, dispatcher): 'default argument values') -def array_function_dispatch(dispatcher, verify=True): +def array_function_dispatch(dispatcher, module=None, verify=True): """Decorator for adding dispatch with the __array_function__ protocol.""" def decorator(implementation): # TODO: only do this check when the appropriate flag is enabled or for @@ -150,6 +151,9 @@ def array_function_dispatch(dispatcher, verify=True): return array_function_implementation_or_override( implementation, public_api, relevant_args, args, kwargs) + if module is not None: + public_api.__module__ = module + # TODO: remove this when we drop Python 2 support (functools.wraps # adds __wrapped__ automatically in later versions) public_api.__wrapped__ = implementation |