summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py37
1 files changed, 36 insertions, 1 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 17e3d475f..906292613 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -2,8 +2,11 @@
TODO: rewrite this in C for performance.
"""
+import collections
import functools
+
from numpy.core.multiarray import ndarray
+from numpy.compat._inspect import getargspec
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
@@ -107,13 +110,45 @@ def array_function_implementation_or_override(
.format(public_api, list(map(type, overloaded_args))))
-def array_function_dispatch(dispatcher):
+ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
+
+
+def verify_matching_signatures(implementation, dispatcher):
+ """Verify that a dispatcher function has the right signature."""
+ implementation_spec = ArgSpec(*getargspec(implementation))
+ dispatcher_spec = ArgSpec(*getargspec(dispatcher))
+
+ if (implementation_spec.args != dispatcher_spec.args or
+ implementation_spec.varargs != dispatcher_spec.varargs or
+ implementation_spec.keywords != dispatcher_spec.keywords or
+ (bool(implementation_spec.defaults) !=
+ bool(dispatcher_spec.defaults)) or
+ (implementation_spec.defaults is not None and
+ len(implementation_spec.defaults) !=
+ len(dispatcher_spec.defaults))):
+ raise RuntimeError('implementation and dispatcher for %s have '
+ 'different function signatures' % implementation)
+
+ if implementation_spec.defaults is not None:
+ if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
+ raise RuntimeError('dispatcher functions can only use None for '
+ 'default argument values')
+
+
+def array_function_dispatch(dispatcher, verify=True):
"""Decorator for adding dispatch with the __array_function__ protocol."""
def decorator(implementation):
+ # TODO: only do this check when the appropriate flag is enabled or for
+ # a dev install. We want this check for testing but don't want to
+ # slow down all numpy imports.
+ if verify:
+ verify_matching_signatures(implementation, dispatcher)
+
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return array_function_implementation_or_override(
implementation, public_api, relevant_args, args, kwargs)
return public_api
+
return decorator