diff options
author | Stephan Hoyer <shoyer@gmail.com> | 2018-12-02 17:54:43 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-12-02 17:54:43 -0800 |
commit | 04c2f33843782cd27a16f919477b1b27c4b01e5a (patch) | |
tree | f59d7ebc6004f3804e30a0bdabd987a1e6f02152 /numpy | |
parent | dc608ba2a85a7b89663bfd4d0ad977f40d0f813a (diff) | |
parent | f7a82245fcd61a9b477d250d94951f33dfe2f097 (diff) | |
download | numpy-04c2f33843782cd27a16f919477b1b27c4b01e5a.tar.gz |
Merge pull request #12328 from mhvk/ndarray_array_function_allow_subclasses
MAINT: Allow subclasses in `ndarray.__array_function__`.
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/_methods.py | 13 | ||||
-rw-r--r-- | numpy/core/overrides.py | 16 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 82 |
3 files changed, 66 insertions, 45 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py index 8974f0ce1..baeab6383 100644 --- a/numpy/core/_methods.py +++ b/numpy/core/_methods.py @@ -161,11 +161,8 @@ def _array_function(self, func, types, args, kwargs): # TODO: rewrite this in C # Cannot handle items that have __array_function__ other than our own. for t in types: - if t is not mu.ndarray: - method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION) - if method is not _NDARRAY_ARRAY_FUNCTION: - return NotImplemented - - # Arguments contain no overrides, so we can safely call the - # overloaded function again. - return func(*args, **kwargs) + if not issubclass(t, mu.ndarray) and hasattr(t, '__array_function__'): + return NotImplemented + + # The regular implementation can handle this, so we call it directly. + return func.__wrapped__(*args, **kwargs) diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index bddf14310..0979858a1 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -62,16 +62,6 @@ def get_overloaded_types_and_args(relevant_args): overloaded_types = [arg_type] overloaded_args = [arg] - # Short-cut for the common case of only ndarray. - if overloaded_types == _NDARRAY_ONLY: - return overloaded_types, [] - - # Special handling for ndarray.__array_function__ - overloaded_args = [ - arg for arg in overloaded_args - if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION - ] - return overloaded_types, overloaded_args @@ -106,7 +96,11 @@ def array_function_implementation_or_override( """ # Check for __array_function__ methods. types, overloaded_args = get_overloaded_types_and_args(relevant_args) - if not overloaded_args: + # Short-cut for common cases: no overload or only ndarray overload + # (directly or with subclasses that do not override __array_function__). + if (not overloaded_args or types == _NDARRAY_ONLY or + all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION + for arg in overloaded_args)): return implementation(*args, **kwargs) # Call overrides diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 7db551801..62b2a3e53 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -27,6 +27,20 @@ def _return_not_implemented(self, *args, **kwargs): return NotImplemented +# need to define this at the top level to test pickling +@array_function_dispatch(lambda array: (array,)) +def dispatched_one_arg(array): + """Docstring.""" + return 'original' + + +@array_function_dispatch(lambda array1, array2: (array1, array2)) +def dispatched_two_arg(array1, array2): + """Docstring.""" + return 'original' + + +@requires_array_function class TestGetOverloadedTypesAndArgs(object): def test_ndarray(self): @@ -34,20 +48,20 @@ class TestGetOverloadedTypesAndArgs(object): types, args = get_overloaded_types_and_args([array]) assert_equal(set(types), {np.ndarray}) - assert_equal(list(args), []) + assert_equal(list(args), [array]) types, args = get_overloaded_types_and_args([array, array]) assert_equal(len(types), 1) assert_equal(set(types), {np.ndarray}) - assert_equal(list(args), []) + assert_equal(list(args), [array]) types, args = get_overloaded_types_and_args([array, 1]) assert_equal(set(types), {np.ndarray}) - assert_equal(list(args), []) + assert_equal(list(args), [array]) types, args = get_overloaded_types_and_args([1, array]) assert_equal(set(types), {np.ndarray}) - assert_equal(list(args), []) + assert_equal(list(args), [array]) def test_ndarray_subclasses(self): @@ -63,16 +77,16 @@ class TestGetOverloadedTypesAndArgs(object): types, args = get_overloaded_types_and_args([array, override_sub]) assert_equal(set(types), {np.ndarray, OverrideSub}) - assert_equal(list(args), [override_sub]) + assert_equal(list(args), [override_sub, array]) types, args = get_overloaded_types_and_args([array, no_override_sub]) assert_equal(set(types), {np.ndarray, NoOverrideSub}) - assert_equal(list(args), []) + assert_equal(list(args), [no_override_sub, array]) types, args = get_overloaded_types_and_args( [override_sub, no_override_sub]) assert_equal(set(types), {OverrideSub, NoOverrideSub}) - assert_equal(list(args), [override_sub]) + assert_equal(list(args), [override_sub, no_override_sub]) def test_ndarray_and_duck_array(self): @@ -84,11 +98,11 @@ class TestGetOverloadedTypesAndArgs(object): types, args = get_overloaded_types_and_args([other, array]) assert_equal(set(types), {np.ndarray, Other}) - assert_equal(list(args), [other]) + assert_equal(list(args), [other, array]) types, args = get_overloaded_types_and_args([array, other]) assert_equal(set(types), {np.ndarray, Other}) - assert_equal(list(args), [other]) + assert_equal(list(args), [array, other]) def test_ndarray_subclass_and_duck_array(self): @@ -103,9 +117,9 @@ class TestGetOverloadedTypesAndArgs(object): other = Other() assert_equal(_get_overloaded_args([array, subarray, other]), - [subarray, other]) + [subarray, array, other]) assert_equal(_get_overloaded_args([array, other, subarray]), - [subarray, other]) + [subarray, array, other]) def test_many_duck_arrays(self): @@ -137,39 +151,55 @@ class TestGetOverloadedTypesAndArgs(object): assert_equal(_get_overloaded_args([a, c, b]), [c, b, a]) +@requires_array_function class TestNDArrayArrayFunction(object): def test_method(self): - class SubOverride(np.ndarray): + class Other(object): __array_function__ = _return_not_implemented class NoOverrideSub(np.ndarray): pass - array = np.array(1) + class OverrideSub(np.ndarray): + __array_function__ = _return_not_implemented - def func(): - return 'original' + array = np.array([1]) + other = Other() + no_override_sub = array.view(NoOverrideSub) + override_sub = array.view(OverrideSub) - result = array.__array_function__( - func=func, types=(np.ndarray,), args=(), kwargs={}) + result = array.__array_function__(func=dispatched_two_arg, + types=(np.ndarray,), + args=(array, 1.), kwargs={}) assert_equal(result, 'original') - result = array.__array_function__( - func=func, types=(np.ndarray, SubOverride), args=(), kwargs={}) + result = array.__array_function__(func=dispatched_two_arg, + types=(np.ndarray, Other), + args=(array, other), kwargs={}) assert_(result is NotImplemented) - result = array.__array_function__( - func=func, types=(np.ndarray, NoOverrideSub), args=(), kwargs={}) + result = array.__array_function__(func=dispatched_two_arg, + types=(np.ndarray, NoOverrideSub), + args=(array, no_override_sub), + kwargs={}) assert_equal(result, 'original') + result = array.__array_function__(func=dispatched_two_arg, + types=(np.ndarray, OverrideSub), + args=(array, override_sub), + kwargs={}) + assert_equal(result, 'original') -# need to define this at the top level to test pickling -@array_function_dispatch(lambda array: (array,)) -def dispatched_one_arg(array): - """Docstring.""" - return 'original' + with assert_raises_regex(TypeError, 'no implementation found'): + np.concatenate((array, other)) + + expected = np.concatenate((array, array)) + result = np.concatenate((array, no_override_sub)) + assert_equal(result, expected.view(NoOverrideSub)) + result = np.concatenate((array, override_sub)) + assert_equal(result, expected.view(OverrideSub)) @requires_array_function |