diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-12-19 10:08:01 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-12-19 10:46:40 -0800 |
commit | 45413455791c77465c5f33a5082053274eb18900 (patch) | |
tree | 7362d739f4558b4aece09bbde333df5c219d8f1a /numpy | |
parent | f4ddc2b3251d4606cc227761933e991131425416 (diff) | |
download | numpy-45413455791c77465c5f33a5082053274eb18900.tar.gz |
ENH: refactor __array_function__ pure Python implementation
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_methods.py | 2 | ||||
-rw-r--r-- | numpy/core/multiarray.py | 8 | ||||
-rw-r--r-- | numpy/core/overrides.py | 108 | ||||
-rw-r--r-- | numpy/core/shape_base.py | 7 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 67 | ||||
-rw-r--r-- | numpy/core/tests/test_shape_base.py | 4 |
6 files changed, 113 insertions, 83 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index baeab6383..c6c163e16 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -161,7 +161,7 @@ def _array_function(self, func, types, args, kwargs): # TODO: rewrite this in C # Cannot handle items that have __array_function__ other than our own. for t in types: - if not issubclass(t, mu.ndarray) and hasattr(t, '__array_function__'): + if not issubclass(t, mu.ndarray): return NotImplemented # The regular implementation can handle this, so we call it directly. diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py index 1b9c65782..4c2715892 100644 --- a/numpy/core/multiarray.py +++ b/numpy/core/multiarray.py @@ -211,9 +211,11 @@ def concatenate(arrays, axis=None, out=None): fill_value=999999) """ - for array in arrays: - yield array - yield out + if out is not None: + # optimize for the typical case where only arrays is provided + arrays = list(arrays) + arrays.append(out) + return arrays @array_function_from_c_func_and_dispatcher(_multiarray_umath.inner) diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 0979858a1..8f535b213 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -1,4 +1,4 @@ -"""Preliminary implementation of NEP-18 +"""Preliminary implementation of NEP-18. TODO: rewrite this in C for performance. """ @@ -10,64 +10,80 @@ from numpy.core._multiarray_umath import add_docstring, ndarray from numpy.compat._inspect import getargspec -_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ -_NDARRAY_ONLY = [ndarray] - ENABLE_ARRAY_FUNCTION = bool( int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) -def get_overloaded_types_and_args(relevant_args): +def get_implementing_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. - Parameters ---------- relevant_args : iterable of array-like Iterable of array-like arguments to check for __array_function__ methods. - Returns ------- - overloaded_types : collection of types + implementing_types : collection of types Types of arguments from relevant_args with __array_function__ methods. - overloaded_args : list + implementing_args : list Arguments from relevant_args on which to call __array_function__ methods, in the order in which they should be called. """ # Runtime is O(num_arguments * num_unique_types) - overloaded_types = [] - overloaded_args = [] + implementing_types = [] + implementing_args = [] for arg in relevant_args: arg_type = type(arg) # We only collect arguments if they have a unique type, which ensures # reasonable performance even with a long list of possibly overloaded # arguments. - if (arg_type not in overloaded_types and + if (arg_type not in implementing_types and hasattr(arg_type, '__array_function__')): # Create lists explicitly for the first type (usually the only one - # done) to avoid setting up the iterator for overloaded_args. - if overloaded_types: - overloaded_types.append(arg_type) + # done) to avoid setting up the iterator for implementing_args. + if implementing_types: + implementing_types.append(arg_type) # By default, insert argument at the end, but if it is # subclass of another argument, insert it before that argument. # This ensures "subclasses before superclasses". - index = len(overloaded_args) - for i, old_arg in enumerate(overloaded_args): + index = len(implementing_args) + for i, old_arg in enumerate(implementing_args): if issubclass(arg_type, type(old_arg)): index = i break - overloaded_args.insert(index, arg) + implementing_args.insert(index, arg) else: - overloaded_types = [arg_type] - overloaded_args = [arg] + implementing_types = [arg_type] + implementing_args = [arg] + + return implementing_types, implementing_args + + +_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ + - return overloaded_types, overloaded_args +def any_overrides(relevant_args): + """Are there any __array_function__ methods that need to be called?""" + for arg in relevant_args: + arg_type = type(arg) + if (arg_type is not ndarray and + getattr(arg_type, '__array_function__', + _NDARRAY_ARRAY_FUNCTION) + is not _NDARRAY_ARRAY_FUNCTION): + return True + return False -def array_function_implementation_or_override( +_TUPLE_OR_LIST = {tuple, list} + + +def implement_array_function( implementation, public_api, relevant_args, args, kwargs): - """Implement a function with checks for __array_function__ overrides. + """ + Implement a function with checks for __array_function__ overrides. + + All arguments are required, and can only be passed by position. Arguments --------- @@ -82,41 +98,55 @@ def array_function_implementation_or_override( Iterable of arguments to check for __array_function__ methods. args : tuple Arbitrary positional arguments originally passed into ``public_api``. - kwargs : tuple + kwargs : dict Arbitrary keyword arguments originally passed into ``public_api``. Returns ------- - Result from calling `implementation()` or an `__array_function__` + Result from calling ``implementation()`` or an ``__array_function__`` method, as appropriate. Raises ------ TypeError : if no implementation is found. """ - # Check for __array_function__ methods. - types, overloaded_args = get_overloaded_types_and_args(relevant_args) - # Short-cut for common cases: no overload or only ndarray overload - # (directly or with subclasses that do not override __array_function__). - if (not overloaded_args or types == _NDARRAY_ONLY or - all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION - for arg in overloaded_args)): + if type(relevant_args) not in _TUPLE_OR_LIST: + relevant_args = tuple(relevant_args) + + if not any_overrides(relevant_args): return implementation(*args, **kwargs) # Call overrides - for overloaded_arg in overloaded_args: + types, implementing_args = get_implementing_types_and_args(relevant_args) + for arg in implementing_args: # Use `public_api` instead of `implemenation` so __array_function__ # implementations can do equality/identity comparisons. - result = overloaded_arg.__array_function__( - public_api, types, args, kwargs) - + result = arg.__array_function__(public_api, types, args, kwargs) if result is not NotImplemented: return result func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) raise TypeError("no implementation found for '{}' on types that implement " - '__array_function__: {}' - .format(func_name, list(map(type, overloaded_args)))) + '__array_function__: {}'.format(func_name, list(types))) + + +def _get_implementing_args(relevant_args): + """ + Collect arguments on which to call __array_function__. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of possibly array-like arguments to check for + __array_function__ methods. + + Returns + ------- + Sequence of arguments with __array_function__ methods, in the order in + which they should be called. + """ + _, args = get_implementing_types_and_args(relevant_args) + return args ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') @@ -215,7 +245,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True, @functools.wraps(implementation) def public_api(*args, **kwargs): relevant_args = dispatcher(*args, **kwargs) - return array_function_implementation_or_override( + return implement_array_function( implementation, public_api, relevant_args, args, kwargs) if module is not None: diff --git a/numpy/core/shape_base.py b/numpy/core/shape_base.py index 0378d3c1f..f8332c362 100644 --- a/numpy/core/shape_base.py +++ b/numpy/core/shape_base.py @@ -342,10 +342,11 @@ def hstack(tup): def _stack_dispatcher(arrays, axis=None, out=None): arrays = _arrays_for_stack_dispatcher(arrays, stacklevel=6) - for a in arrays: - yield a if out is not None: - yield out + # optimize for the typical case where only arrays is provided + arrays = list(arrays) + arrays.append(out) + return arrays @array_function_dispatch(_stack_dispatcher) diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 62b2a3e53..8c7ef576e 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -7,7 +7,7 @@ import numpy as np from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex) from numpy.core.overrides import ( - get_overloaded_types_and_args, array_function_dispatch, + _get_implementing_args, array_function_dispatch, verify_matching_signatures, ENABLE_ARRAY_FUNCTION) from numpy.core.numeric import pickle import pytest @@ -18,11 +18,6 @@ requires_array_function = pytest.mark.skipif( reason="__array_function__ dispatch not enabled.") -def _get_overloaded_args(relevant_args): - types, args = get_overloaded_types_and_args(relevant_args) - return args - - def _return_not_implemented(self, *args, **kwargs): return NotImplemented @@ -41,26 +36,21 @@ def dispatched_two_arg(array1, array2): @requires_array_function -class TestGetOverloadedTypesAndArgs(object): +class TestGetImplementingArgs(object): def test_ndarray(self): array = np.array(1) - types, args = get_overloaded_types_and_args([array]) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([array]) assert_equal(list(args), [array]) - types, args = get_overloaded_types_and_args([array, array]) - assert_equal(len(types), 1) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([array, array]) assert_equal(list(args), [array]) - types, args = get_overloaded_types_and_args([array, 1]) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([array, 1]) assert_equal(list(args), [array]) - types, args = get_overloaded_types_and_args([1, array]) - assert_equal(set(types), {np.ndarray}) + args = _get_implementing_args([1, array]) assert_equal(list(args), [array]) def test_ndarray_subclasses(self): @@ -75,17 +65,14 @@ class TestGetOverloadedTypesAndArgs(object): override_sub = np.array(1).view(OverrideSub) no_override_sub = np.array(1).view(NoOverrideSub) - types, args = get_overloaded_types_and_args([array, override_sub]) - assert_equal(set(types), {np.ndarray, OverrideSub}) + args = _get_implementing_args([array, override_sub]) assert_equal(list(args), [override_sub, array]) - types, args = get_overloaded_types_and_args([array, no_override_sub]) - assert_equal(set(types), {np.ndarray, NoOverrideSub}) + args = _get_implementing_args([array, no_override_sub]) assert_equal(list(args), [no_override_sub, array]) - types, args = get_overloaded_types_and_args( + args = _get_implementing_args( [override_sub, no_override_sub]) - assert_equal(set(types), {OverrideSub, NoOverrideSub}) assert_equal(list(args), [override_sub, no_override_sub]) def test_ndarray_and_duck_array(self): @@ -96,12 +83,10 @@ class TestGetOverloadedTypesAndArgs(object): array = np.array(1) other = Other() - types, args = get_overloaded_types_and_args([other, array]) - assert_equal(set(types), {np.ndarray, Other}) + args = _get_implementing_args([other, array]) assert_equal(list(args), [other, array]) - types, args = get_overloaded_types_and_args([array, other]) - assert_equal(set(types), {np.ndarray, Other}) + args = _get_implementing_args([array, other]) assert_equal(list(args), [array, other]) def test_ndarray_subclass_and_duck_array(self): @@ -116,9 +101,9 @@ class TestGetOverloadedTypesAndArgs(object): subarray = np.array(1).view(OverrideSub) other = Other() - assert_equal(_get_overloaded_args([array, subarray, other]), + assert_equal(_get_implementing_args([array, subarray, other]), [subarray, array, other]) - assert_equal(_get_overloaded_args([array, other, subarray]), + assert_equal(_get_implementing_args([array, other, subarray]), [subarray, array, other]) def test_many_duck_arrays(self): @@ -140,15 +125,15 @@ class TestGetOverloadedTypesAndArgs(object): c = C() d = D() - assert_equal(_get_overloaded_args([1]), []) - assert_equal(_get_overloaded_args([a]), [a]) - assert_equal(_get_overloaded_args([a, 1]), [a]) - assert_equal(_get_overloaded_args([a, a, a]), [a]) - assert_equal(_get_overloaded_args([a, d, a]), [a, d]) - assert_equal(_get_overloaded_args([a, b]), [b, a]) - assert_equal(_get_overloaded_args([b, a]), [b, a]) - assert_equal(_get_overloaded_args([a, b, c]), [b, c, a]) - assert_equal(_get_overloaded_args([a, c, b]), [c, b, a]) + assert_equal(_get_implementing_args([1]), []) + assert_equal(_get_implementing_args([a]), [a]) + assert_equal(_get_implementing_args([a, 1]), [a]) + assert_equal(_get_implementing_args([a, a, a]), [a]) + assert_equal(_get_implementing_args([a, d, a]), [a, d]) + assert_equal(_get_implementing_args([a, b]), [b, a]) + assert_equal(_get_implementing_args([b, a]), [b, a]) + assert_equal(_get_implementing_args([a, b, c]), [b, c, a]) + assert_equal(_get_implementing_args([a, c, b]), [c, b, a]) @requires_array_function @@ -201,6 +186,14 @@ class TestNDArrayArrayFunction(object): result = np.concatenate((array, override_sub)) assert_equal(result, expected.view(OverrideSub)) + def test_no_wrapper(self): + array = np.array(1) + func = dispatched_one_arg.__wrapped__ + with assert_raises_regex(AttributeError, '__wrapped__'): + array.__array_function__(func=func, + types=(np.ndarray,), + args=(array,), kwargs={}) + @requires_array_function class TestArrayFunctionDispatch(object): diff --git a/numpy/core/tests/test_shape_base.py b/numpy/core/tests/test_shape_base.py index ef5c118ec..b996321c2 100644 --- a/numpy/core/tests/test_shape_base.py +++ b/numpy/core/tests/test_shape_base.py @@ -373,6 +373,10 @@ def test_stack(): # empty arrays assert_(stack([[], [], []]).shape == (3, 0)) assert_(stack([[], [], []], axis=1).shape == (0, 3)) + # out + out = np.zeros_like(r1) + np.stack((a, b), out=out) + assert_array_equal(out, r1) # edge cases assert_raises_regex(ValueError, 'need at least one array', stack, []) assert_raises_regex(ValueError, 'must have the same shape', |