summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-11-10 13:06:41 -0800
committerStephan Hoyer <shoyer@google.com>2018-11-10 13:16:43 -0800
commit3d49d86e0c83da1733386888c65b6490fb53497b (patch)
tree4156bba0ff9d8568188795ce63111770e6e27008 /numpy/core
parentd0e2f1ac0c0fb4aaf791c9082cff1d6b04545410 (diff)
downloadnumpy-3d49d86e0c83da1733386888c65b6490fb53497b.tar.gz
MAINT: disable __array_function__ dispatch unless environment variable set
Per discussion on the mailing list, __array_function__ isn't quite ready to release as part of NumPy 1.16: https://mail.python.org/pipermail/numpy-discussion/2018-November/078949.html We'd like to improve performance a bit, and it will be easier to support introspection on NumPy functions if we support Python 3 only. So for now, you need to set the environment variable ``NUMPY_EXPERIMENTAL_ARRAY_FUNCTION=1`` to enable dispatching.
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/overrides.py27
-rw-r--r--numpy/core/tests/test_overrides.py14
2 files changed, 28 insertions, 13 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 088c15e65..0bbb39123 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -4,6 +4,7 @@ TODO: rewrite this in C for performance.
"""
import collections
import functools
+import os
from numpy.core._multiarray_umath import ndarray
from numpy.compat._inspect import getargspec
@@ -12,6 +13,9 @@ from numpy.compat._inspect import getargspec
_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
_NDARRAY_ONLY = [ndarray]
+ENABLE_ARRAY_FUNCTION = bool(
+ int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
+
def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
@@ -149,17 +153,18 @@ def verify_matching_signatures(implementation, dispatcher):
def array_function_dispatch(dispatcher, module=None, verify=True):
"""Decorator for adding dispatch with the __array_function__ protocol."""
def decorator(implementation):
- # TODO: only do this check when the appropriate flag is enabled or for
- # a dev install. We want this check for testing but don't want to
- # slow down all numpy imports.
- if verify:
- verify_matching_signatures(implementation, dispatcher)
-
- @functools.wraps(implementation)
- def public_api(*args, **kwargs):
- relevant_args = dispatcher(*args, **kwargs)
- return array_function_implementation_or_override(
- implementation, public_api, relevant_args, args, kwargs)
+ if ENABLE_ARRAY_FUNCTION:
+ # __array_function__ requires an explicit opt-in for now
+ public_api = implementation
+ else:
+ if verify:
+ verify_matching_signatures(implementation, dispatcher)
+
+ @functools.wraps(implementation)
+ def public_api(*args, **kwargs):
+ relevant_args = dispatcher(*args, **kwargs)
+ return array_function_implementation_or_override(
+ implementation, public_api, relevant_args, args, kwargs)
if module is not None:
public_api.__module__ = module
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index ee6d5da4a..1c8ca760a 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -7,8 +7,14 @@ from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex)
from numpy.core.overrides import (
get_overloaded_types_and_args, array_function_dispatch,
- verify_matching_signatures)
+ verify_matching_signatures, ENABLE_ARRAY_FUNCTION)
from numpy.core.numeric import pickle
+import pytest
+
+
+requires_array_function = pytest.mark.skipif(
+ not ENABLE_ARRAY_FUNCTION,
+ reason="__array_function__ dispatch not enabled.")
def _get_overloaded_args(relevant_args):
@@ -165,6 +171,7 @@ def dispatched_one_arg(array):
return 'original'
+@requires_array_function
class TestArrayFunctionDispatch(object):
def test_pickle(self):
@@ -204,6 +211,7 @@ class TestArrayFunctionDispatch(object):
dispatched_one_arg(array)
+@requires_array_function
class TestVerifyMatchingSignatures(object):
def test_verify_matching_signatures(self):
@@ -256,6 +264,7 @@ def _new_duck_type_and_implements():
return (MyArray, implements)
+@requires_array_function
class TestArrayFunctionImplementation(object):
def test_one_arg(self):
@@ -322,7 +331,7 @@ class TestNDArrayMethods(object):
assert_equal(repr(array), 'MyArray(1)')
assert_equal(str(array), '1')
-
+
class TestNumPyFunctions(object):
def test_module(self):
@@ -331,6 +340,7 @@ class TestNumPyFunctions(object):
assert_equal(np.fft.fft.__module__, 'numpy.fft')
assert_equal(np.linalg.solve.__module__, 'numpy.linalg')
+ @requires_array_function
def test_override_sum(self):
MyArray, implements = _new_duck_type_and_implements()