summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py16
1 files changed, 5 insertions, 11 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index bddf14310..0979858a1 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -62,16 +62,6 @@ def get_overloaded_types_and_args(relevant_args):
overloaded_types = [arg_type]
overloaded_args = [arg]
- # Short-cut for the common case of only ndarray.
- if overloaded_types == _NDARRAY_ONLY:
- return overloaded_types, []
-
- # Special handling for ndarray.__array_function__
- overloaded_args = [
- arg for arg in overloaded_args
- if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
- ]
-
return overloaded_types, overloaded_args
@@ -106,7 +96,11 @@ def array_function_implementation_or_override(
"""
# Check for __array_function__ methods.
types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- if not overloaded_args:
+ # Short-cut for common cases: no overload or only ndarray overload
+ # (directly or with subclasses that do not override __array_function__).
+ if (not overloaded_args or types == _NDARRAY_ONLY or
+ all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
+ for arg in overloaded_args)):
return implementation(*args, **kwargs)
# Call overrides