diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-09-26 17:53:41 -0700 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-09-26 17:57:05 -0700 |
commit | 70bc9ee43858c5484372dd6f164ca110bf5368e4 (patch) | |
tree | 4699a271883dff590cf74f54e858ede8cfeea823 /numpy/core/overrides.py | |
parent | fe1c1fbac3dbd0d314b96ad95447e370dceae079 (diff) | |
download | numpy-70bc9ee43858c5484372dd6f164ca110bf5368e4.tar.gz |
CLN: cleanup and better document core/overrides.py
I think this organization makes a little more sense, especially when thinking
about what the functions we want to write in C should look like.
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 95 |
1 files changed, 69 insertions, 26 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index c1d5e3864..f2997b600 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -12,8 +12,19 @@ _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. - __array_function__ implementations should be called in order on the return - values from this function. + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of array-like arguments to check for __array_function__ + methods. + + Returns + ------- + overloaded_types : collection of types + Types of arguments from relevant_args with __array_function__ methods. + overloaded_args : list + Arguments from relevant_args on which to call __array_function__ + methods, in the order in which they should be called. """ # Runtime is O(num_arguments * num_unique_types) overloaded_types = [] @@ -35,7 +46,7 @@ def get_overloaded_types_and_args(relevant_args): break overloaded_args.insert(index, arg) - # Special handling for ndarray. + # Special handling for ndarray.__array_function__ overloaded_args = [ arg for arg in overloaded_args if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION @@ -44,43 +55,75 @@ def get_overloaded_types_and_args(relevant_args): return overloaded_types, overloaded_args -def array_function_override(overloaded_args, func, types, args, kwargs): - """Call __array_function__ implementations.""" +def array_function_implementation_or_override( + implementation_func, api_func, dispatcher, args, kwargs): + """Implement a function with checks for __array_function__ overrides. + + Arguments + --------- + implementation_func : function + Function that implements the operation on NumPy array without + overrides when called like `implementation_func(*args, **kwargs)`. + api_func : function + Function exposed by NumPy's public API on which overrides are being + checked here. + dispatcher : callable + Function that when called like `dispatcher(*args, **kwargs)` returns an + iterable of relevant argument to check to for __array_function__ + attributes. + args : tuple + Arbitrary positional arguments originally passed into api_func. + kwargs : tuple + Arbitrary keyword arguments originally passed into api_func. + + Returns + ------- + Result from calling `implementation_func()` or an `__array_function__` + method, as appropriate. + + Raises + ------ + TypeError : if no implementation is found. + """ + + # Collect array-like arguments. + relevant_arguments = dispatcher(*args, **kwargs) + + # Check for __array_function__ methods. + types, overloaded_args = get_overloaded_types_and_args( + relevant_arguments) + + # Fast path + if not overloaded_args: + return implementation_func(*args, **kwargs) + + # Call overrides for overloaded_arg in overloaded_args: # Note that we're only calling __array_function__ on the *first* # occurence of each argument type. This is necessary for reasonable # performance with a possibly long list of overloaded arguments, for # which each __array_function__ implementation might reasonably need to # check all argument types. - result = overloaded_arg.__array_function__(func, types, args, kwargs) + # api_func is the function exposed in NumPy's public API. We + # use it instead of func so __array_function__ implementations + # can do equality/identity comparisons. + result = overloaded_arg.__array_function__( + api_func, types, args, kwargs) if result is not NotImplemented: return result raise TypeError('no implementation found for {} on types that implement ' '__array_function__: {}' - .format(func, list(map(type, overloaded_args)))) + .format(api_func, list(map(type, overloaded_args)))) def array_function_dispatch(dispatcher): """Wrap a function for dispatch with the __array_function__ protocol.""" - def decorator(func): - @functools.wraps(func) - def new_func(*args, **kwargs): - # Collect array-like arguments. - relevant_arguments = dispatcher(*args, **kwargs) - # Check for __array_function__ methods. - types, overloaded_args = get_overloaded_types_and_args( - relevant_arguments) - # Call overrides, if necessary. - if overloaded_args: - # new_func is the function exposed in NumPy's public API. We - # use it instead of func so __array_function__ implementations - # can do equality/identity comparisons. - return array_function_override( - overloaded_args, new_func, types, args, kwargs) - else: - return func(*args, **kwargs) - - return new_func + def decorator(implementation_func): + @functools.wraps(implementation_func) + def api_func(*args, **kwargs): + return array_function_implementation_or_override( + implementation_func, api_func, dispatcher, args, kwargs) + return api_func return decorator |