summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py95
1 files changed, 69 insertions, 26 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index c1d5e3864..f2997b600 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -12,8 +12,19 @@ _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
- __array_function__ implementations should be called in order on the return
- values from this function.
+ Parameters
+ ----------
+ relevant_args : iterable of array-like
+ Iterable of array-like arguments to check for __array_function__
+ methods.
+
+ Returns
+ -------
+ overloaded_types : collection of types
+ Types of arguments from relevant_args with __array_function__ methods.
+ overloaded_args : list
+ Arguments from relevant_args on which to call __array_function__
+ methods, in the order in which they should be called.
"""
# Runtime is O(num_arguments * num_unique_types)
overloaded_types = []
@@ -35,7 +46,7 @@ def get_overloaded_types_and_args(relevant_args):
break
overloaded_args.insert(index, arg)
- # Special handling for ndarray.
+ # Special handling for ndarray.__array_function__
overloaded_args = [
arg for arg in overloaded_args
if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
@@ -44,43 +55,75 @@ def get_overloaded_types_and_args(relevant_args):
return overloaded_types, overloaded_args
-def array_function_override(overloaded_args, func, types, args, kwargs):
- """Call __array_function__ implementations."""
+def array_function_implementation_or_override(
+ implementation_func, api_func, dispatcher, args, kwargs):
+ """Implement a function with checks for __array_function__ overrides.
+
+ Arguments
+ ---------
+ implementation_func : function
+ Function that implements the operation on NumPy array without
+ overrides when called like `implementation_func(*args, **kwargs)`.
+ api_func : function
+ Function exposed by NumPy's public API on which overrides are being
+ checked here.
+ dispatcher : callable
+ Function that when called like `dispatcher(*args, **kwargs)` returns an
+ iterable of relevant argument to check to for __array_function__
+ attributes.
+ args : tuple
+ Arbitrary positional arguments originally passed into api_func.
+ kwargs : tuple
+ Arbitrary keyword arguments originally passed into api_func.
+
+ Returns
+ -------
+ Result from calling `implementation_func()` or an `__array_function__`
+ method, as appropriate.
+
+ Raises
+ ------
+ TypeError : if no implementation is found.
+ """
+
+ # Collect array-like arguments.
+ relevant_arguments = dispatcher(*args, **kwargs)
+
+ # Check for __array_function__ methods.
+ types, overloaded_args = get_overloaded_types_and_args(
+ relevant_arguments)
+
+ # Fast path
+ if not overloaded_args:
+ return implementation_func(*args, **kwargs)
+
+ # Call overrides
for overloaded_arg in overloaded_args:
# Note that we're only calling __array_function__ on the *first*
# occurence of each argument type. This is necessary for reasonable
# performance with a possibly long list of overloaded arguments, for
# which each __array_function__ implementation might reasonably need to
# check all argument types.
- result = overloaded_arg.__array_function__(func, types, args, kwargs)
+ # api_func is the function exposed in NumPy's public API. We
+ # use it instead of func so __array_function__ implementations
+ # can do equality/identity comparisons.
+ result = overloaded_arg.__array_function__(
+ api_func, types, args, kwargs)
if result is not NotImplemented:
return result
raise TypeError('no implementation found for {} on types that implement '
'__array_function__: {}'
- .format(func, list(map(type, overloaded_args))))
+ .format(api_func, list(map(type, overloaded_args))))
def array_function_dispatch(dispatcher):
"""Wrap a function for dispatch with the __array_function__ protocol."""
- def decorator(func):
- @functools.wraps(func)
- def new_func(*args, **kwargs):
- # Collect array-like arguments.
- relevant_arguments = dispatcher(*args, **kwargs)
- # Check for __array_function__ methods.
- types, overloaded_args = get_overloaded_types_and_args(
- relevant_arguments)
- # Call overrides, if necessary.
- if overloaded_args:
- # new_func is the function exposed in NumPy's public API. We
- # use it instead of func so __array_function__ implementations
- # can do equality/identity comparisons.
- return array_function_override(
- overloaded_args, new_func, types, args, kwargs)
- else:
- return func(*args, **kwargs)
-
- return new_func
+ def decorator(implementation_func):
+ @functools.wraps(implementation_func)
+ def api_func(*args, **kwargs):
+ return array_function_implementation_or_override(
+ implementation_func, api_func, dispatcher, args, kwargs)
+ return api_func
return decorator