diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-11-10 13:06:41 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-11-10 13:16:43 -0800 |
commit | 3d49d86e0c83da1733386888c65b6490fb53497b (patch) | |
tree | 4156bba0ff9d8568188795ce63111770e6e27008 /numpy/core | |
parent | d0e2f1ac0c0fb4aaf791c9082cff1d6b04545410 (diff) | |
download | numpy-3d49d86e0c83da1733386888c65b6490fb53497b.tar.gz |
MAINT: disable __array_function__ dispatch unless environment variable set
Per discussion on the mailing list, __array_function__ isn't quite ready to
release as part of NumPy 1.16:
https://mail.python.org/pipermail/numpy-discussion/2018-November/078949.html
We'd like to improve performance a bit, and it will be easier to support
introspection on NumPy functions if we support Python 3 only.
So for now, you need to set the environment variable
``NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1`` to enable dispatching.
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/overrides.py | 27 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 14 |
2 files changed, 28 insertions, 13 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 088c15e65..0bbb39123 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -4,6 +4,7 @@ TODO: rewrite this in C for performance. """ import collections import functools +import os from numpy.core._multiarray_umath import ndarray from numpy.compat._inspect import getargspec @@ -12,6 +13,9 @@ from numpy.compat._inspect import getargspec _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ _NDARRAY_ONLY = [ndarray] +ENABLE_ARRAY_FUNCTION = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) + def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. @@ -149,17 +153,18 @@ def verify_matching_signatures(implementation, dispatcher): def array_function_dispatch(dispatcher, module=None, verify=True): """Decorator for adding dispatch with the __array_function__ protocol.""" def decorator(implementation): - # TODO: only do this check when the appropriate flag is enabled or for - # a dev install. We want this check for testing but don't want to - # slow down all numpy imports. - if verify: - verify_matching_signatures(implementation, dispatcher) - - @functools.wraps(implementation) - def public_api(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) - return array_function_implementation_or_override( - implementation, public_api, relevant_args, args, kwargs) + if ENABLE_ARRAY_FUNCTION: + # __array_function__ requires an explicit opt-in for now + public_api = implementation + else: + if verify: + verify_matching_signatures(implementation, dispatcher) + + @functools.wraps(implementation) + def public_api(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return array_function_implementation_or_override( + implementation, public_api, relevant_args, args, kwargs) if module is not None: public_api.__module__ = module diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index ee6d5da4a..1c8ca760a 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -7,8 +7,14 @@ from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex) from numpy.core.overrides import ( get_overloaded_types_and_args, array_function_dispatch, - verify_matching_signatures) + verify_matching_signatures, ENABLE_ARRAY_FUNCTION) from numpy.core.numeric import pickle +import pytest + + +requires_array_function = pytest.mark.skipif( + not ENABLE_ARRAY_FUNCTION, + reason="__array_function__ dispatch not enabled.") def _get_overloaded_args(relevant_args): @@ -165,6 +171,7 @@ def dispatched_one_arg(array): return 'original' +@requires_array_function class TestArrayFunctionDispatch(object): def test_pickle(self): @@ -204,6 +211,7 @@ class TestArrayFunctionDispatch(object): dispatched_one_arg(array) +@requires_array_function class TestVerifyMatchingSignatures(object): def test_verify_matching_signatures(self): @@ -256,6 +264,7 @@ def _new_duck_type_and_implements(): return (MyArray, implements) +@requires_array_function class TestArrayFunctionImplementation(object): def test_one_arg(self): @@ -322,7 +331,7 @@ class TestNDArrayMethods(object): assert_equal(repr(array), 'MyArray(1)') assert_equal(str(array), '1') - + class TestNumPyFunctions(object): def test_module(self): @@ -331,6 +340,7 @@ class TestNumPyFunctions(object): assert_equal(np.fft.fft.__module__, 'numpy.fft') assert_equal(np.linalg.solve.__module__, 'numpy.linalg') + @requires_array_function def test_override_sum(self): MyArray, implements = _new_duck_type_and_implements() |