summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py66
1 files changed, 37 insertions, 29 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 06e005102..c1d5e3864 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -6,6 +6,9 @@ import functools
from numpy.core.multiarray import ndarray
+_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
+
+
def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
@@ -17,36 +20,32 @@ def get_overloaded_types_and_args(relevant_args):
overloaded_args = []
for arg in relevant_args:
arg_type = type(arg)
- if arg_type not in overloaded_types:
- try:
- array_function = arg_type.__array_function__
- except AttributeError:
- continue
+ if (arg_type not in overloaded_types and
+ hasattr(arg_type, '__array_function__')):
overloaded_types.append(arg_type)
- if array_function is not ndarray.__array_function__:
- # By default, insert this argument at the end, but if it is
- # subclass of another argument, insert it before that argument.
- # This ensures "subclasses before superclasses".
- index = len(overloaded_args)
- for i, old_arg in enumerate(overloaded_args):
- if issubclass(arg_type, type(old_arg)):
- index = i
- break
- overloaded_args.insert(index, arg)
+ # By default, insert this argument at the end, but if it is
+ # subclass of another argument, insert it before that argument.
+ # This ensures "subclasses before superclasses".
+ index = len(overloaded_args)
+ for i, old_arg in enumerate(overloaded_args):
+ if issubclass(arg_type, type(old_arg)):
+ index = i
+ break
+ overloaded_args.insert(index, arg)
- return tuple(overloaded_types), tuple(overloaded_args)
+ # Special handling for ndarray.
+ overloaded_args = [
+ arg for arg in overloaded_args
+ if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
+ ]
+ return overloaded_types, overloaded_args
-def try_array_function_override(func, relevant_arguments, args, kwargs):
- # TODO: consider simplifying the interface, to only require either `types`
- # (by calling __array_function__ a classmethod) or `overloaded_args` (by
- # dropping `types` from the signature of __array_function__)
- types, overloaded_args = get_overloaded_types_and_args(relevant_arguments)
- if not overloaded_args:
- return False, None
+def array_function_override(overloaded_args, func, types, args, kwargs):
+ """Call __array_function__ implementations."""
for overloaded_arg in overloaded_args:
# Note that we're only calling __array_function__ on the *first*
# occurence of each argument type. This is necessary for reasonable
@@ -56,7 +55,7 @@ def try_array_function_override(func, relevant_arguments, args, kwargs):
result = overloaded_arg.__array_function__(func, types, args, kwargs)
if result is not NotImplemented:
- return True, result
+ return result
raise TypeError('no implementation found for {} on types that implement '
'__array_function__: {}'
@@ -68,11 +67,20 @@ def array_function_dispatch(dispatcher):
def decorator(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
+ # Collect array-like arguments.
relevant_arguments = dispatcher(*args, **kwargs)
- success, value = try_array_function_override(
- new_func, relevant_arguments, args, kwargs)
- if success:
- return value
- return func(*args, **kwargs)
+ # Check for __array_function__ methods.
+ types, overloaded_args = get_overloaded_types_and_args(
+ relevant_arguments)
+ # Call overrides, if necessary.
+ if overloaded_args:
+ # new_func is the function exposed in NumPy's public API. We
+ # use it instead of func so __array_function__ implementations
+ # can do equality/identity comparisons.
+ return array_function_override(
+ overloaded_args, new_func, types, args, kwargs)
+ else:
+ return func(*args, **kwargs)
+
return new_func
return decorator