diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index bddf14310..0979858a1 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -62,16 +62,6 @@ def get_overloaded_types_and_args(relevant_args): overloaded_types = [arg_type] overloaded_args = [arg] - # Short-cut for the common case of only ndarray. - if overloaded_types == _NDARRAY_ONLY: - return overloaded_types, [] - - # Special handling for ndarray.__array_function__ - overloaded_args = [ - arg for arg in overloaded_args - if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION - ] - return overloaded_types, overloaded_args @@ -106,7 +96,11 @@ def array_function_implementation_or_override( """ # Check for __array_function__ methods. types, overloaded_args = get_overloaded_types_and_args(relevant_args) - if not overloaded_args: + # Short-cut for common cases: no overload or only ndarray overload + # (directly or with subclasses that do not override __array_function__). + if (not overloaded_args or types == _NDARRAY_ONLY or + all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION + for arg in overloaded_args)): return implementation(*args, **kwargs) # Call overrides |