diff options
author | Stephan Hoyer <shoyer@gmail.com> | 2018-10-08 12:07:29 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-08 12:07:29 -0700 |
commit | 2f4bc6f1d561230e9e295ce0d49fef5c4deb7ea0 (patch) | |
tree | 3c227c971ddb9f534abfbe65c4cf673b613d3468 /numpy/core/overrides.py | |
parent | 1681fda2d2dfd032ca249e2b2ced1818bf5f5d95 (diff) | |
download | numpy-2f4bc6f1d561230e9e295ce0d49fef5c4deb7ea0.tar.gz |
ENH: Validate dispatcher functions in array_function_dispatch (#12099)
* Validate dispatcher functions in array_function_dispatch
They should have the same signature as the decorated function.
Note: eventually these checks should be optional -- we really only need them
to be run as part of NumPy's test suite, not every time numpy is imported.
* ENH: make signature checking in array_function_dispatch optional
* Change verify_signature keyword argument to verify
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 37 |
1 files changed, 36 insertions, 1 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 17e3d475f..906292613 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -2,8 +2,11 @@ TODO: rewrite this in C for performance. """ +import collections import functools + from numpy.core.multiarray import ndarray +from numpy.compat._inspect import getargspec _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ @@ -107,13 +110,45 @@ def array_function_implementation_or_override( .format(public_api, list(map(type, overloaded_args)))) -def array_function_dispatch(dispatcher): +ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') + + +def verify_matching_signatures(implementation, dispatcher): + """Verify that a dispatcher function has the right signature.""" + implementation_spec = ArgSpec(*getargspec(implementation)) + dispatcher_spec = ArgSpec(*getargspec(dispatcher)) + + if (implementation_spec.args != dispatcher_spec.args or + implementation_spec.varargs != dispatcher_spec.varargs or + implementation_spec.keywords != dispatcher_spec.keywords or + (bool(implementation_spec.defaults) != + bool(dispatcher_spec.defaults)) or + (implementation_spec.defaults is not None and + len(implementation_spec.defaults) != + len(dispatcher_spec.defaults))): + raise RuntimeError('implementation and dispatcher for %s have ' + 'different function signatures' % implementation) + + if implementation_spec.defaults is not None: + if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): + raise RuntimeError('dispatcher functions can only use None for ' + 'default argument values') + + +def array_function_dispatch(dispatcher, verify=True): """Decorator for adding dispatch with the __array_function__ protocol.""" def decorator(implementation): + # TODO: only do this check when the appropriate flag is enabled or for + # a dev install. We want this check for testing but don't want to + # slow down all numpy imports. + if verify: + verify_matching_signatures(implementation, dispatcher) + @functools.wraps(implementation) def public_api(*args, **kwargs): relevant_args = dispatcher(*args, **kwargs) return array_function_implementation_or_override( implementation, public_api, relevant_args, args, kwargs) return public_api + return decorator |