diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 66 |
1 files changed, 37 insertions, 29 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 06e005102..c1d5e3864 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -6,6 +6,9 @@ import functools from numpy.core.multiarray import ndarray +_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ + + def get_overloaded_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. @@ -17,36 +20,32 @@ def get_overloaded_types_and_args(relevant_args): overloaded_args = [] for arg in relevant_args: arg_type = type(arg) - if arg_type not in overloaded_types: - try: - array_function = arg_type.__array_function__ - except AttributeError: - continue + if (arg_type not in overloaded_types and + hasattr(arg_type, '__array_function__')): overloaded_types.append(arg_type) - if array_function is not ndarray.__array_function__: - # By default, insert this argument at the end, but if it is - # subclass of another argument, insert it before that argument. - # This ensures "subclasses before superclasses". - index = len(overloaded_args) - for i, old_arg in enumerate(overloaded_args): - if issubclass(arg_type, type(old_arg)): - index = i - break - overloaded_args.insert(index, arg) + # By default, insert this argument at the end, but if it is + # subclass of another argument, insert it before that argument. + # This ensures "subclasses before superclasses". + index = len(overloaded_args) + for i, old_arg in enumerate(overloaded_args): + if issubclass(arg_type, type(old_arg)): + index = i + break + overloaded_args.insert(index, arg) - return tuple(overloaded_types), tuple(overloaded_args) + # Special handling for ndarray. + overloaded_args = [ + arg for arg in overloaded_args + if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION + ] + return overloaded_types, overloaded_args -def try_array_function_override(func, relevant_arguments, args, kwargs): - # TODO: consider simplifying the interface, to only require either `types` - # (by calling __array_function__ a classmethod) or `overloaded_args` (by - # dropping `types` from the signature of __array_function__) - types, overloaded_args = get_overloaded_types_and_args(relevant_arguments) - if not overloaded_args: - return False, None +def array_function_override(overloaded_args, func, types, args, kwargs): + """Call __array_function__ implementations.""" for overloaded_arg in overloaded_args: # Note that we're only calling __array_function__ on the *first* # occurence of each argument type. This is necessary for reasonable @@ -56,7 +55,7 @@ def try_array_function_override(func, relevant_arguments, args, kwargs): result = overloaded_arg.__array_function__(func, types, args, kwargs) if result is not NotImplemented: - return True, result + return result raise TypeError('no implementation found for {} on types that implement ' '__array_function__: {}' @@ -68,11 +67,20 @@ def array_function_dispatch(dispatcher): def decorator(func): @functools.wraps(func) def new_func(*args, **kwargs): + # Collect array-like arguments. relevant_arguments = dispatcher(*args, **kwargs) - success, value = try_array_function_override( - new_func, relevant_arguments, args, kwargs) - if success: - return value - return func(*args, **kwargs) + # Check for __array_function__ methods. + types, overloaded_args = get_overloaded_types_and_args( + relevant_arguments) + # Call overrides, if necessary. + if overloaded_args: + # new_func is the function exposed in NumPy's public API. We + # use it instead of func so __array_function__ implementations + # can do equality/identity comparisons. + return array_function_override( + overloaded_args, new_func, types, args, kwargs) + else: + return func(*args, **kwargs) + return new_func return decorator |