diff options
author | Charles Harris <charlesr.harris@gmail.com> | 2018-09-30 08:09:26 -0500 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-30 08:09:26 -0500 |
commit | 4a85dc3e57aa1c9066b107780dbbf504f951eb99 (patch) | |
tree | 8eeb80c5f0ef5f1d57b8d48d023b19b29a728a2c | |
parent | 90ea6b37a4dee698fd7a1659c5bd601eae6fef64 (diff) | |
parent | 587717b89596004b7a62d96458fad929480678a0 (diff) | |
download | numpy-4a85dc3e57aa1c9066b107780dbbf504f951eb99.tar.gz |
Merge pull request #12042 from shoyer/py-overrides-cleanup
MAINT: cleanup and better document core/overrides.py
-rw-r--r-- | numpy/core/overrides.py | 97 |
1 files changed, 65 insertions, 32 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index c1d5e3864..17e3d475f 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -12,14 +12,28 @@ _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. - __array_function__ implementations should be called in order on the return - values from this function. + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of array-like arguments to check for __array_function__ + methods. + + Returns + ------- + overloaded_types : collection of types + Types of arguments from relevant_args with __array_function__ methods. + overloaded_args : list + Arguments from relevant_args on which to call __array_function__ + methods, in the order in which they should be called. """ # Runtime is O(num_arguments * num_unique_types) overloaded_types = [] overloaded_args = [] for arg in relevant_args: arg_type = type(arg) + # We only collect arguments if they have a unique type, which ensures + # reasonable performance even with a long list of possibly overloaded + # arguments. if (arg_type not in overloaded_types and hasattr(arg_type, '__array_function__')): @@ -35,7 +49,7 @@ def get_overloaded_types_and_args(relevant_args): break overloaded_args.insert(index, arg) - # Special handling for ndarray. + # Special handling for ndarray.__array_function__ overloaded_args = [ arg for arg in overloaded_args if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION @@ -44,43 +58,62 @@ def get_overloaded_types_and_args(relevant_args): return overloaded_types, overloaded_args -def array_function_override(overloaded_args, func, types, args, kwargs): - """Call __array_function__ implementations.""" +def array_function_implementation_or_override( + implementation, public_api, relevant_args, args, kwargs): + """Implement a function with checks for __array_function__ overrides. + + Arguments + --------- + implementation : function + Function that implements the operation on NumPy array without + overrides when called like ``implementation(*args, **kwargs)``. + public_api : function + Function exposed by NumPy's public API riginally called like + ``public_api(*args, **kwargs`` on which arguments are now being + checked. + relevant_args : iterable + Iterable of arguments to check for __array_function__ methods. + args : tuple + Arbitrary positional arguments originally passed into ``public_api``. + kwargs : tuple + Arbitrary keyword arguments originally passed into ``public_api``. + + Returns + ------- + Result from calling `implementation()` or an `__array_function__` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + """ + # Check for __array_function__ methods. + types, overloaded_args = get_overloaded_types_and_args(relevant_args) + if not overloaded_args: + return implementation(*args, **kwargs) + + # Call overrides for overloaded_arg in overloaded_args: - # Note that we're only calling __array_function__ on the *first* - # occurence of each argument type. This is necessary for reasonable - # performance with a possibly long list of overloaded arguments, for - # which each __array_function__ implementation might reasonably need to - # check all argument types. - result = overloaded_arg.__array_function__(func, types, args, kwargs) + # Use `public_api` instead of `implemenation` so __array_function__ + # implementations can do equality/identity comparisons. + result = overloaded_arg.__array_function__( + public_api, types, args, kwargs) if result is not NotImplemented: return result raise TypeError('no implementation found for {} on types that implement ' '__array_function__: {}' - .format(func, list(map(type, overloaded_args)))) + .format(public_api, list(map(type, overloaded_args)))) def array_function_dispatch(dispatcher): - """Wrap a function for dispatch with the __array_function__ protocol.""" - def decorator(func): - @functools.wraps(func) - def new_func(*args, **kwargs): - # Collect array-like arguments. - relevant_arguments = dispatcher(*args, **kwargs) - # Check for __array_function__ methods. - types, overloaded_args = get_overloaded_types_and_args( - relevant_arguments) - # Call overrides, if necessary. - if overloaded_args: - # new_func is the function exposed in NumPy's public API. We - # use it instead of func so __array_function__ implementations - # can do equality/identity comparisons. - return array_function_override( - overloaded_args, new_func, types, args, kwargs) - else: - return func(*args, **kwargs) - - return new_func + """Decorator for adding dispatch with the __array_function__ protocol.""" + def decorator(implementation): + @functools.wraps(implementation) + def public_api(*args, **kwargs): + relevant_args = dispatcher(*args, **kwargs) + return array_function_implementation_or_override( + implementation, public_api, relevant_args, args, kwargs) + return public_api return decorator |