diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-12-19 10:08:01 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-12-19 10:46:40 -0800 |
commit | 45413455791c77465c5f33a5082053274eb18900 (patch) | |
tree | 7362d739f4558b4aece09bbde333df5c219d8f1a /numpy/core/overrides.py | |
parent | f4ddc2b3251d4606cc227761933e991131425416 (diff) | |
download | numpy-45413455791c77465c5f33a5082053274eb18900.tar.gz |
ENH: refactor __array_function__ pure Python implementation
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 108 |
1 files changed, 69 insertions, 39 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 0979858a1..8f535b213 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -1,4 +1,4 @@ -"""Preliminary implementation of NEP-18 +"""Preliminary implementation of NEP-18. TODO: rewrite this in C for performance. """ @@ -10,64 +10,80 @@ from numpy.core._multiarray_umath import add_docstring, ndarray from numpy.compat._inspect import getargspec -_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ -_NDARRAY_ONLY = [ndarray] - ENABLE_ARRAY_FUNCTION = bool( int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) -def get_overloaded_types_and_args(relevant_args): +def get_implementing_types_and_args(relevant_args): """Returns a list of arguments on which to call __array_function__. - Parameters ---------- relevant_args : iterable of array-like Iterable of array-like arguments to check for __array_function__ methods. - Returns ------- - overloaded_types : collection of types + implementing_types : collection of types Types of arguments from relevant_args with __array_function__ methods. - overloaded_args : list + implementing_args : list Arguments from relevant_args on which to call __array_function__ methods, in the order in which they should be called. """ # Runtime is O(num_arguments * num_unique_types) - overloaded_types = [] - overloaded_args = [] + implementing_types = [] + implementing_args = [] for arg in relevant_args: arg_type = type(arg) # We only collect arguments if they have a unique type, which ensures # reasonable performance even with a long list of possibly overloaded # arguments. - if (arg_type not in overloaded_types and + if (arg_type not in implementing_types and hasattr(arg_type, '__array_function__')): # Create lists explicitly for the first type (usually the only one - # done) to avoid setting up the iterator for overloaded_args. - if overloaded_types: - overloaded_types.append(arg_type) + # done) to avoid setting up the iterator for implementing_args. + if implementing_types: + implementing_types.append(arg_type) # By default, insert argument at the end, but if it is # subclass of another argument, insert it before that argument. # This ensures "subclasses before superclasses". - index = len(overloaded_args) - for i, old_arg in enumerate(overloaded_args): + index = len(implementing_args) + for i, old_arg in enumerate(implementing_args): if issubclass(arg_type, type(old_arg)): index = i break - overloaded_args.insert(index, arg) + implementing_args.insert(index, arg) else: - overloaded_types = [arg_type] - overloaded_args = [arg] + implementing_types = [arg_type] + implementing_args = [arg] + + return implementing_types, implementing_args + + +_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ + - return overloaded_types, overloaded_args +def any_overrides(relevant_args): + """Are there any __array_function__ methods that need to be called?""" + for arg in relevant_args: + arg_type = type(arg) + if (arg_type is not ndarray and + getattr(arg_type, '__array_function__', + _NDARRAY_ARRAY_FUNCTION) + is not _NDARRAY_ARRAY_FUNCTION): + return True + return False -def array_function_implementation_or_override( +_TUPLE_OR_LIST = {tuple, list} + + +def implement_array_function( implementation, public_api, relevant_args, args, kwargs): - """Implement a function with checks for __array_function__ overrides. + """ + Implement a function with checks for __array_function__ overrides. + + All arguments are required, and can only be passed by position. Arguments --------- @@ -82,41 +98,55 @@ def array_function_implementation_or_override( Iterable of arguments to check for __array_function__ methods. args : tuple Arbitrary positional arguments originally passed into ``public_api``. - kwargs : tuple + kwargs : dict Arbitrary keyword arguments originally passed into ``public_api``. Returns ------- - Result from calling `implementation()` or an `__array_function__` + Result from calling ``implementation()`` or an ``__array_function__`` method, as appropriate. Raises ------ TypeError : if no implementation is found. """ - # Check for __array_function__ methods. - types, overloaded_args = get_overloaded_types_and_args(relevant_args) - # Short-cut for common cases: no overload or only ndarray overload - # (directly or with subclasses that do not override __array_function__). - if (not overloaded_args or types == _NDARRAY_ONLY or - all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION - for arg in overloaded_args)): + if type(relevant_args) not in _TUPLE_OR_LIST: + relevant_args = tuple(relevant_args) + + if not any_overrides(relevant_args): return implementation(*args, **kwargs) # Call overrides - for overloaded_arg in overloaded_args: + types, implementing_args = get_implementing_types_and_args(relevant_args) + for arg in implementing_args: # Use `public_api` instead of `implemenation` so __array_function__ # implementations can do equality/identity comparisons. - result = overloaded_arg.__array_function__( - public_api, types, args, kwargs) - + result = arg.__array_function__(public_api, types, args, kwargs) if result is not NotImplemented: return result func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) raise TypeError("no implementation found for '{}' on types that implement " - '__array_function__: {}' - .format(func_name, list(map(type, overloaded_args)))) + '__array_function__: {}'.format(func_name, list(types))) + + +def _get_implementing_args(relevant_args): + """ + Collect arguments on which to call __array_function__. + + Parameters + ---------- + relevant_args : iterable of array-like + Iterable of possibly array-like arguments to check for + __array_function__ methods. + + Returns + ------- + Sequence of arguments with __array_function__ methods, in the order in + which they should be called. + """ + _, args = get_implementing_types_and_args(relevant_args) + return args ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') @@ -215,7 +245,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True, @functools.wraps(implementation) def public_api(*args, **kwargs): relevant_args = dispatcher(*args, **kwargs) - return array_function_implementation_or_override( + return implement_array_function( implementation, public_api, relevant_args, args, kwargs) if module is not None: |