summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2018-09-30 08:09:26 -0500
committerGitHub <noreply@github.com>2018-09-30 08:09:26 -0500
commit4a85dc3e57aa1c9066b107780dbbf504f951eb99 (patch)
tree8eeb80c5f0ef5f1d57b8d48d023b19b29a728a2c
parent90ea6b37a4dee698fd7a1659c5bd601eae6fef64 (diff)
parent587717b89596004b7a62d96458fad929480678a0 (diff)
downloadnumpy-4a85dc3e57aa1c9066b107780dbbf504f951eb99.tar.gz
Merge pull request #12042 from shoyer/py-overrides-cleanup
MAINT: cleanup and better document core/overrides.py
-rw-r--r--numpy/core/overrides.py97
1 files changed, 65 insertions, 32 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index c1d5e3864..17e3d475f 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -12,14 +12,28 @@ _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
- __array_function__ implementations should be called in order on the return
- values from this function.
+ Parameters
+ ----------
+ relevant_args : iterable of array-like
+ Iterable of array-like arguments to check for __array_function__
+ methods.
+
+ Returns
+ -------
+ overloaded_types : collection of types
+ Types of arguments from relevant_args with __array_function__ methods.
+ overloaded_args : list
+ Arguments from relevant_args on which to call __array_function__
+ methods, in the order in which they should be called.
"""
# Runtime is O(num_arguments * num_unique_types)
overloaded_types = []
overloaded_args = []
for arg in relevant_args:
arg_type = type(arg)
+ # We only collect arguments if they have a unique type, which ensures
+ # reasonable performance even with a long list of possibly overloaded
+ # arguments.
if (arg_type not in overloaded_types and
hasattr(arg_type, '__array_function__')):
@@ -35,7 +49,7 @@ def get_overloaded_types_and_args(relevant_args):
break
overloaded_args.insert(index, arg)
- # Special handling for ndarray.
+ # Special handling for ndarray.__array_function__
overloaded_args = [
arg for arg in overloaded_args
if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
@@ -44,43 +58,62 @@ def get_overloaded_types_and_args(relevant_args):
return overloaded_types, overloaded_args
-def array_function_override(overloaded_args, func, types, args, kwargs):
- """Call __array_function__ implementations."""
+def array_function_implementation_or_override(
+ implementation, public_api, relevant_args, args, kwargs):
+ """Implement a function with checks for __array_function__ overrides.
+
+ Arguments
+ ---------
+ implementation : function
+ Function that implements the operation on NumPy array without
+ overrides when called like ``implementation(*args, **kwargs)``.
+ public_api : function
+ Function exposed by NumPy's public API riginally called like
+ ``public_api(*args, **kwargs`` on which arguments are now being
+ checked.
+ relevant_args : iterable
+ Iterable of arguments to check for __array_function__ methods.
+ args : tuple
+ Arbitrary positional arguments originally passed into ``public_api``.
+ kwargs : tuple
+ Arbitrary keyword arguments originally passed into ``public_api``.
+
+ Returns
+ -------
+ Result from calling `implementation()` or an `__array_function__`
+ method, as appropriate.
+
+ Raises
+ ------
+ TypeError : if no implementation is found.
+ """
+ # Check for __array_function__ methods.
+ types, overloaded_args = get_overloaded_types_and_args(relevant_args)
+ if not overloaded_args:
+ return implementation(*args, **kwargs)
+
+ # Call overrides
for overloaded_arg in overloaded_args:
- # Note that we're only calling __array_function__ on the *first*
- # occurence of each argument type. This is necessary for reasonable
- # performance with a possibly long list of overloaded arguments, for
- # which each __array_function__ implementation might reasonably need to
- # check all argument types.
- result = overloaded_arg.__array_function__(func, types, args, kwargs)
+ # Use `public_api` instead of `implemenation` so __array_function__
+ # implementations can do equality/identity comparisons.
+ result = overloaded_arg.__array_function__(
+ public_api, types, args, kwargs)
if result is not NotImplemented:
return result
raise TypeError('no implementation found for {} on types that implement '
'__array_function__: {}'
- .format(func, list(map(type, overloaded_args))))
+ .format(public_api, list(map(type, overloaded_args))))
def array_function_dispatch(dispatcher):
- """Wrap a function for dispatch with the __array_function__ protocol."""
- def decorator(func):
- @functools.wraps(func)
- def new_func(*args, **kwargs):
- # Collect array-like arguments.
- relevant_arguments = dispatcher(*args, **kwargs)
- # Check for __array_function__ methods.
- types, overloaded_args = get_overloaded_types_and_args(
- relevant_arguments)
- # Call overrides, if necessary.
- if overloaded_args:
- # new_func is the function exposed in NumPy's public API. We
- # use it instead of func so __array_function__ implementations
- # can do equality/identity comparisons.
- return array_function_override(
- overloaded_args, new_func, types, args, kwargs)
- else:
- return func(*args, **kwargs)
-
- return new_func
+ """Decorator for adding dispatch with the __array_function__ protocol."""
+ def decorator(implementation):
+ @functools.wraps(implementation)
+ def public_api(*args, **kwargs):
+ relevant_args = dispatcher(*args, **kwargs)
+ return array_function_implementation_or_override(
+ implementation, public_api, relevant_args, args, kwargs)
+ return public_api
return decorator