summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@gmail.com>2018-12-02 17:54:43 -0800
committerGitHub <noreply@github.com>2018-12-02 17:54:43 -0800
commit04c2f33843782cd27a16f919477b1b27c4b01e5a (patch)
treef59d7ebc6004f3804e30a0bdabd987a1e6f02152 /numpy/core/overrides.py
parentdc608ba2a85a7b89663bfd4d0ad977f40d0f813a (diff)
parentf7a82245fcd61a9b477d250d94951f33dfe2f097 (diff)
downloadnumpy-04c2f33843782cd27a16f919477b1b27c4b01e5a.tar.gz
Merge pull request #12328 from mhvk/ndarray_array_function_allow_subclasses
MAINT: Allow subclasses in `ndarray.__array_function__`.
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py16
1 files changed, 5 insertions, 11 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index bddf14310..0979858a1 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -62,16 +62,6 @@ def get_overloaded_types_and_args(relevant_args):
overloaded_types = [arg_type]
overloaded_args = [arg]
- # Short-cut for the common case of only ndarray.
- if overloaded_types == _NDARRAY_ONLY:
- return overloaded_types, []
-
- # Special handling for ndarray.__array_function__
- overloaded_args = [
- arg for arg in overloaded_args
- if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
- ]
-
return overloaded_types, overloaded_args
@@ -106,7 +96,11 @@ def array_function_implementation_or_override(
"""
# Check for __array_function__ methods.
types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- if not overloaded_args:
+ # Short-cut for common cases: no overload or only ndarray overload
+ # (directly or with subclasses that do not override __array_function__).
+ if (not overloaded_args or types == _NDARRAY_ONLY or
+ all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
+ for arg in overloaded_args)):
return implementation(*args, **kwargs)
# Call overrides