diff options
author | Stephan Hoyer <shoyer@gmail.com> | 2018-12-02 17:54:43 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-12-02 17:54:43 -0800 |
commit | 04c2f33843782cd27a16f919477b1b27c4b01e5a (patch) | |
tree | f59d7ebc6004f3804e30a0bdabd987a1e6f02152 /numpy/core/overrides.py | |
parent | dc608ba2a85a7b89663bfd4d0ad977f40d0f813a (diff) | |
parent | f7a82245fcd61a9b477d250d94951f33dfe2f097 (diff) | |
download | numpy-04c2f33843782cd27a16f919477b1b27c4b01e5a.tar.gz |
Merge pull request #12328 from mhvk/ndarray_array_function_allow_subclasses
MAINT: Allow subclasses in `ndarray.__array_function__`.
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 16 |
1 files changed, 5 insertions, 11 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index bddf14310..0979858a1 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -62,16 +62,6 @@ def get_overloaded_types_and_args(relevant_args): overloaded_types = [arg_type] overloaded_args = [arg] - # Short-cut for the common case of only ndarray. - if overloaded_types == _NDARRAY_ONLY: - return overloaded_types, [] - - # Special handling for ndarray.__array_function__ - overloaded_args = [ - arg for arg in overloaded_args - if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION - ] - return overloaded_types, overloaded_args @@ -106,7 +96,11 @@ def array_function_implementation_or_override( """ # Check for __array_function__ methods. types, overloaded_args = get_overloaded_types_and_args(relevant_args) - if not overloaded_args: + # Short-cut for common cases: no overload or only ndarray overload + # (directly or with subclasses that do not override __array_function__). + if (not overloaded_args or types == _NDARRAY_ONLY or + all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION + for arg in overloaded_args)): return implementation(*args, **kwargs) # Call overrides |