summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2018-10-08 12:07:29 -0700
committerGitHub <noreply@github.com>2018-10-08 12:07:29 -0700
commit2f4bc6f1d561230e9e295ce0d49fef5c4deb7ea0 (patch)
tree3c227c971ddb9f534abfbe65c4cf673b613d3468 /numpy/core/overrides.py
parent1681fda2d2dfd032ca249e2b2ced1818bf5f5d95 (diff)
downloadnumpy-2f4bc6f1d561230e9e295ce0d49fef5c4deb7ea0.tar.gz
ENH: Validate dispatcher functions in array_function_dispatch (#12099)
* Validate dispatcher functions in array_function_dispatch They should have the same signature as the decorated function. Note: eventually these checks should be optional -- we really only need them to be run as part of NumPy's test suite, not every time numpy is imported. * ENH: make signature checking in array_function_dispatch optional * Change verify_signature keyword argument to verify
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 17e3d475f..906292613 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -2,8 +2,11 @@
TODO: rewrite this in C for performance.
"""
+import collections
import functools
+
from numpy.core.multiarray import ndarray
+from numpy.compat._inspect import getargspec
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
@@ -107,13 +110,45 @@ def array_function_implementation_or_override(
.format(public_api, list(map(type, overloaded_args))))
-def array_function_dispatch(dispatcher):
+ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
+
+
+def verify_matching_signatures(implementation, dispatcher):
+ """Verify that a dispatcher function has the right signature."""
+ implementation_spec = ArgSpec(*getargspec(implementation))
+ dispatcher_spec = ArgSpec(*getargspec(dispatcher))
+
+ if (implementation_spec.args != dispatcher_spec.args or
+ implementation_spec.varargs != dispatcher_spec.varargs or
+ implementation_spec.keywords != dispatcher_spec.keywords or
+ (bool(implementation_spec.defaults) !=
+ bool(dispatcher_spec.defaults)) or
+ (implementation_spec.defaults is not None and
+ len(implementation_spec.defaults) !=
+ len(dispatcher_spec.defaults))):
+ raise RuntimeError('implementation and dispatcher for %s have '
+ 'different function signatures' % implementation)
+
+ if implementation_spec.defaults is not None:
+ if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
+ raise RuntimeError('dispatcher functions can only use None for '
+ 'default argument values')
+
+
+def array_function_dispatch(dispatcher, verify=True):
"""Decorator for adding dispatch with the __array_function__ protocol."""
def decorator(implementation):
+ # TODO: only do this check when the appropriate flag is enabled or for
+ # a dev install. We want this check for testing but don't want to
+ # slow down all numpy imports.
+ if verify:
+ verify_matching_signatures(implementation, dispatcher)
+
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return array_function_implementation_or_override(
implementation, public_api, relevant_args, args, kwargs)
return public_api
+
return decorator