diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-11-12 17:17:24 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-11-12 17:17:24 -0600 |
commit | 067264b20f8d2c043cf5b9cd3d1826384e558e79 (patch) | |
tree | 4714149704808160daa1ee9b2bdaa9cfab2672d8 /numpy/core | |
parent | e34f5bb4e3aad2bb35fb6d9e3a5c1bfc85eb97eb (diff) | |
parent | 9fa3a4e9802b32c985f6a1fc14dd315a2656ac38 (diff) | |
download | numpy-067264b20f8d2c043cf5b9cd3d1826384e558e79.tar.gz |
Merge pull request #12362 from shoyer/disable-array-function-by-default
MAINT: disable `__array_function__` dispatch unless environment variable set
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/arrayprint.py | 6 | ||||
-rw-r--r-- | numpy/core/overrides.py | 57 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 8 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 20 |
4 files changed, 83 insertions, 8 deletions
diff --git a/numpy/core/arrayprint.py b/numpy/core/arrayprint.py index ccc1468c4..b578fab54 100644 --- a/numpy/core/arrayprint.py +++ b/numpy/core/arrayprint.py @@ -1547,10 +1547,12 @@ def array_str(a, max_line_width=None, precision=None, suppress_small=None): a, max_line_width, precision, suppress_small) +# needed if __array_function__ is disabled +_array2string_impl = getattr(array2string, '__wrapped__', array2string) _default_array_str = functools.partial(_array_str_implementation, - array2string=array2string.__wrapped__) + array2string=_array2string_impl) _default_array_repr = functools.partial(_array_repr_implementation, - array2string=array2string.__wrapped__) + array2string=_array2string_impl) def set_string_function(f, repr=True): diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 088c15e65..e4d505f06 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -4,6 +4,7 @@ TODO: rewrite this in C for performance. """ import collections import functools +import os from numpy.core._multiarray_umath import ndarray from numpy.compat._inspect import getargspec @@ -12,6 +13,9 @@ from numpy.compat._inspect import getargspec _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ _NDARRAY_ONLY = [ndarray] +ENABLE_ARRAY_FUNCTION = bool( + int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) + def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. @@ -146,12 +150,57 @@ def verify_matching_signatures(implementation, dispatcher): 'default argument values') +def override_module(module): + """Decorator for overriding __module__ on a function or class. + + Example usage:: + + @override_module('numpy') + def example(): + pass + + assert example.__module__ == 'numpy' + """ + def decorator(func): + if module is not None: + func.__module__ = module + return func + return decorator + + def array_function_dispatch(dispatcher, module=None, verify=True): - """Decorator for adding dispatch with the __array_function__ protocol.""" + """Decorator for adding dispatch with the __array_function__ protocol. + + See NEP-18 for example usage. + + Parameters + ---------- + dispatcher : callable + Function that when called like ``dispatcher(*args, **kwargs)`` with + arguments from the NumPy function call returns an iterable of + array-like arguments to check for ``__array_function__``. + module : str, optional + __module__ attribute to set on new function, e.g., ``module='numpy'``. + By default, module is copied from the decorated function. + verify : bool, optional + If True, verify the that the signature of the dispatcher and decorated + function signatures match exactly: all required and optional arguments + should appear in order with the same names, but the default values for + all optional arguments should be ``None``. Only disable verification + if the dispatcher's signature needs to deviate for some particular + reason, e.g., because the function has a signature like + ``func(*args, **kwargs)``. + + Returns + ------- + Function suitable for decorating the implementation of a NumPy function. + """ + + if not ENABLE_ARRAY_FUNCTION: + # __array_function__ requires an explicit opt-in for now + return override_module(module) + def decorator(implementation): - # TODO: only do this check when the appropriate flag is enabled or for - # a dev install. We want this check for testing but don't want to - # slow down all numpy imports. if verify: verify_matching_signatures(implementation, dispatcher) diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 3edf0824e..6d234e527 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -217,6 +217,11 @@ def _arrays_for_stack_dispatcher(arrays, stacklevel=4): return arrays +def _warn_for_nonsequence(arrays): + if not overrides.ENABLE_ARRAY_FUNCTION: + _arrays_for_stack_dispatcher(arrays, stacklevel=4) + + def _vhstack_dispatcher(tup): return _arrays_for_stack_dispatcher(tup) @@ -274,6 +279,7 @@ def vstack(tup): [4]]) """ + _warn_for_nonsequence(tup) return _nx.concatenate([atleast_2d(_m) for _m in tup], 0) @@ -325,6 +331,7 @@ def hstack(tup): [3, 4]]) """ + _warn_for_nonsequence(tup) arrs = [atleast_1d(_m) for _m in tup] # As a special case, dimension 0 of 1-dimensional arrays is "horizontal" if arrs and arrs[0].ndim == 1: @@ -398,6 +405,7 @@ def stack(arrays, axis=0, out=None): [3, 4]]) """ + _warn_for_nonsequence(arrays) arrays = [asanyarray(arr) for arr in arrays] if not arrays: raise ValueError('need at least one array to stack') diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index ee6d5da4a..a32049ea1 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -1,5 +1,6 @@ from __future__ import division, absolute_import, print_function +import inspect import sys import numpy as np @@ -7,8 +8,14 @@ from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex) from numpy.core.overrides import ( get_overloaded_types_and_args, array_function_dispatch, - verify_matching_signatures) + verify_matching_signatures, ENABLE_ARRAY_FUNCTION) from numpy.core.numeric import pickle +import pytest + + +requires_array_function = pytest.mark.skipif( + not ENABLE_ARRAY_FUNCTION, + reason="__array_function__ dispatch not enabled.") def _get_overloaded_args(relevant_args): @@ -165,6 +172,7 @@ def dispatched_one_arg(array): return 'original' +@requires_array_function class TestArrayFunctionDispatch(object): def test_pickle(self): @@ -204,6 +212,7 @@ class TestArrayFunctionDispatch(object): dispatched_one_arg(array) +@requires_array_function class TestVerifyMatchingSignatures(object): def test_verify_matching_signatures(self): @@ -256,6 +265,7 @@ def _new_duck_type_and_implements(): return (MyArray, implements) +@requires_array_function class TestArrayFunctionImplementation(object): def test_one_arg(self): @@ -322,7 +332,7 @@ class TestNDArrayMethods(object): assert_equal(repr(array), 'MyArray(1)') assert_equal(str(array), '1') - + class TestNumPyFunctions(object): def test_module(self): @@ -331,6 +341,12 @@ class TestNumPyFunctions(object): assert_equal(np.fft.fft.__module__, 'numpy.fft') assert_equal(np.linalg.solve.__module__, 'numpy.linalg') + @pytest.mark.skipif(sys.version_info[0] < 3, reason="Python 3 only") + def test_inspect_sum(self): + signature = inspect.signature(np.sum) + assert_('axis' in signature.parameters) + + @requires_array_function def test_override_sum(self): MyArray, implements = _new_duck_type_and_implements() |