diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 27 |
1 files changed, 16 insertions, 11 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 088c15e65..0bbb39123 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -4,6 +4,7 @@ TODO: rewrite this in C for performance. """ import collections import functools +import os from numpy.core._multiarray_umath import ndarray from numpy.compat._inspect import getargspec @@ -12,6 +13,9 @@ from numpy.compat._inspect import getargspec _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ _NDARRAY_ONLY = [ndarray] +ENABLE_ARRAY_FUNCTION = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) + def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. @@ -149,17 +153,18 @@ def verify_matching_signatures(implementation, dispatcher): def array_function_dispatch(dispatcher, module=None, verify=True): """Decorator for adding dispatch with the __array_function__ protocol.""" def decorator(implementation): - # TODO: only do this check when the appropriate flag is enabled or for - # a dev install. We want this check for testing but don't want to - # slow down all numpy imports. - if verify: - verify_matching_signatures(implementation, dispatcher) - - @functools.wraps(implementation) - def public_api(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) - return array_function_implementation_or_override( - implementation, public_api, relevant_args, args, kwargs) + if ENABLE_ARRAY_FUNCTION: + # __array_function__ requires an explicit opt-in for now + public_api = implementation + else: + if verify: + verify_matching_signatures(implementation, dispatcher) + + @functools.wraps(implementation) + def public_api(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return array_function_implementation_or_override( + implementation, public_api, relevant_args, args, kwargs) if module is not None: public_api.__module__ = module |