summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2018-12-02 17:54:43 -0800
committerGitHub <noreply@github.com>2018-12-02 17:54:43 -0800
commit04c2f33843782cd27a16f919477b1b27c4b01e5a (patch)
treef59d7ebc6004f3804e30a0bdabd987a1e6f02152 /numpy
parentdc608ba2a85a7b89663bfd4d0ad977f40d0f813a (diff)
parentf7a82245fcd61a9b477d250d94951f33dfe2f097 (diff)
downloadnumpy-04c2f33843782cd27a16f919477b1b27c4b01e5a.tar.gz
Merge pull request #12328 from mhvk/ndarray_array_function_allow_subclasses
MAINT: Allow subclasses in `ndarray.__array_function__`.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/_methods.py13
-rw-r--r--numpy/core/overrides.py16
-rw-r--r--numpy/core/tests/test_overrides.py82
3 files changed, 66 insertions, 45 deletions
diff --git a/numpy/core/_methods.py b/numpy/core/_methods.py
index 8974f0ce1..baeab6383 100644
--- a/numpy/core/_methods.py
+++ b/numpy/core/_methods.py
@@ -161,11 +161,8 @@ def _array_function(self, func, types, args, kwargs):
# TODO: rewrite this in C
# Cannot handle items that have __array_function__ other than our own.
for t in types:
- if t is not mu.ndarray:
- method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION)
- if method is not _NDARRAY_ARRAY_FUNCTION:
- return NotImplemented
-
- # Arguments contain no overrides, so we can safely call the
- # overloaded function again.
- return func(*args, **kwargs)
+ if not issubclass(t, mu.ndarray) and hasattr(t, '__array_function__'):
+ return NotImplemented
+
+ # The regular implementation can handle this, so we call it directly.
+ return func.__wrapped__(*args, **kwargs)
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index bddf14310..0979858a1 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -62,16 +62,6 @@ def get_overloaded_types_and_args(relevant_args):
overloaded_types = [arg_type]
overloaded_args = [arg]
- # Short-cut for the common case of only ndarray.
- if overloaded_types == _NDARRAY_ONLY:
- return overloaded_types, []
-
- # Special handling for ndarray.__array_function__
- overloaded_args = [
- arg for arg in overloaded_args
- if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
- ]
-
return overloaded_types, overloaded_args
@@ -106,7 +96,11 @@ def array_function_implementation_or_override(
"""
# Check for __array_function__ methods.
types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- if not overloaded_args:
+ # Short-cut for common cases: no overload or only ndarray overload
+ # (directly or with subclasses that do not override __array_function__).
+ if (not overloaded_args or types == _NDARRAY_ONLY or
+ all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
+ for arg in overloaded_args)):
return implementation(*args, **kwargs)
# Call overrides
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 7db551801..62b2a3e53 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -27,6 +27,20 @@ def _return_not_implemented(self, *args, **kwargs):
return NotImplemented
+# need to define this at the top level to test pickling
+@array_function_dispatch(lambda array: (array,))
+def dispatched_one_arg(array):
+ """Docstring."""
+ return 'original'
+
+
+@array_function_dispatch(lambda array1, array2: (array1, array2))
+def dispatched_two_arg(array1, array2):
+ """Docstring."""
+ return 'original'
+
+
+@requires_array_function
class TestGetOverloadedTypesAndArgs(object):
def test_ndarray(self):
@@ -34,20 +48,20 @@ class TestGetOverloadedTypesAndArgs(object):
types, args = get_overloaded_types_and_args([array])
assert_equal(set(types), {np.ndarray})
- assert_equal(list(args), [])
+ assert_equal(list(args), [array])
types, args = get_overloaded_types_and_args([array, array])
assert_equal(len(types), 1)
assert_equal(set(types), {np.ndarray})
- assert_equal(list(args), [])
+ assert_equal(list(args), [array])
types, args = get_overloaded_types_and_args([array, 1])
assert_equal(set(types), {np.ndarray})
- assert_equal(list(args), [])
+ assert_equal(list(args), [array])
types, args = get_overloaded_types_and_args([1, array])
assert_equal(set(types), {np.ndarray})
- assert_equal(list(args), [])
+ assert_equal(list(args), [array])
def test_ndarray_subclasses(self):
@@ -63,16 +77,16 @@ class TestGetOverloadedTypesAndArgs(object):
types, args = get_overloaded_types_and_args([array, override_sub])
assert_equal(set(types), {np.ndarray, OverrideSub})
- assert_equal(list(args), [override_sub])
+ assert_equal(list(args), [override_sub, array])
types, args = get_overloaded_types_and_args([array, no_override_sub])
assert_equal(set(types), {np.ndarray, NoOverrideSub})
- assert_equal(list(args), [])
+ assert_equal(list(args), [no_override_sub, array])
types, args = get_overloaded_types_and_args(
[override_sub, no_override_sub])
assert_equal(set(types), {OverrideSub, NoOverrideSub})
- assert_equal(list(args), [override_sub])
+ assert_equal(list(args), [override_sub, no_override_sub])
def test_ndarray_and_duck_array(self):
@@ -84,11 +98,11 @@ class TestGetOverloadedTypesAndArgs(object):
types, args = get_overloaded_types_and_args([other, array])
assert_equal(set(types), {np.ndarray, Other})
- assert_equal(list(args), [other])
+ assert_equal(list(args), [other, array])
types, args = get_overloaded_types_and_args([array, other])
assert_equal(set(types), {np.ndarray, Other})
- assert_equal(list(args), [other])
+ assert_equal(list(args), [array, other])
def test_ndarray_subclass_and_duck_array(self):
@@ -103,9 +117,9 @@ class TestGetOverloadedTypesAndArgs(object):
other = Other()
assert_equal(_get_overloaded_args([array, subarray, other]),
- [subarray, other])
+ [subarray, array, other])
assert_equal(_get_overloaded_args([array, other, subarray]),
- [subarray, other])
+ [subarray, array, other])
def test_many_duck_arrays(self):
@@ -137,39 +151,55 @@ class TestGetOverloadedTypesAndArgs(object):
assert_equal(_get_overloaded_args([a, c, b]), [c, b, a])
+@requires_array_function
class TestNDArrayArrayFunction(object):
def test_method(self):
- class SubOverride(np.ndarray):
+ class Other(object):
__array_function__ = _return_not_implemented
class NoOverrideSub(np.ndarray):
pass
- array = np.array(1)
+ class OverrideSub(np.ndarray):
+ __array_function__ = _return_not_implemented
- def func():
- return 'original'
+ array = np.array([1])
+ other = Other()
+ no_override_sub = array.view(NoOverrideSub)
+ override_sub = array.view(OverrideSub)
- result = array.__array_function__(
- func=func, types=(np.ndarray,), args=(), kwargs={})
+ result = array.__array_function__(func=dispatched_two_arg,
+ types=(np.ndarray,),
+ args=(array, 1.), kwargs={})
assert_equal(result, 'original')
- result = array.__array_function__(
- func=func, types=(np.ndarray, SubOverride), args=(), kwargs={})
+ result = array.__array_function__(func=dispatched_two_arg,
+ types=(np.ndarray, Other),
+ args=(array, other), kwargs={})
assert_(result is NotImplemented)
- result = array.__array_function__(
- func=func, types=(np.ndarray, NoOverrideSub), args=(), kwargs={})
+ result = array.__array_function__(func=dispatched_two_arg,
+ types=(np.ndarray, NoOverrideSub),
+ args=(array, no_override_sub),
+ kwargs={})
assert_equal(result, 'original')
+ result = array.__array_function__(func=dispatched_two_arg,
+ types=(np.ndarray, OverrideSub),
+ args=(array, override_sub),
+ kwargs={})
+ assert_equal(result, 'original')
-# need to define this at the top level to test pickling
-@array_function_dispatch(lambda array: (array,))
-def dispatched_one_arg(array):
- """Docstring."""
- return 'original'
+ with assert_raises_regex(TypeError, 'no implementation found'):
+ np.concatenate((array, other))
+
+ expected = np.concatenate((array, array))
+ result = np.concatenate((array, no_override_sub))
+ assert_equal(result, expected.view(NoOverrideSub))
+ result = np.concatenate((array, override_sub))
+ assert_equal(result, expected.view(OverrideSub))
@requires_array_function