diff options
author | Stephan Hoyer <shoyer@google.com> | 2018-12-19 10:48:10 -0800 |
---|---|---|
committer | Stephan Hoyer <shoyer@google.com> | 2018-12-19 10:48:10 -0800 |
commit | 7104a3e9fe989f43f20765d740cf22dabae563d4 (patch) | |
tree | cde09169afa607600abd09d68989fe62a4f37ccf /numpy/core/overrides.py | |
parent | 45413455791c77465c5f33a5082053274eb18900 (diff) | |
download | numpy-7104a3e9fe989f43f20765d740cf22dabae563d4.tar.gz |
ENH: port __array_function__ overrides to C
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 104 |
1 files changed, 10 insertions, 94 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 8f535b213..c55174ecd 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -1,12 +1,10 @@ -"""Preliminary implementation of NEP-18. - -TODO: rewrite this in C for performance. -""" +"""Implementation of __array_function__ overrides from NEP-18.""" import collections import functools import os -from numpy.core._multiarray_umath import add_docstring, ndarray +from numpy.core._multiarray_umath import ( + add_docstring, implement_array_function, _get_implementing_args) from numpy.compat._inspect import getargspec @@ -14,72 +12,8 @@ ENABLE_ARRAY_FUNCTION = bool( int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0))) -def get_implementing_types_and_args(relevant_args): - """Returns a list of arguments on which to call __array_function__. - Parameters - ---------- - relevant_args : iterable of array-like - Iterable of array-like arguments to check for __array_function__ - methods. - Returns - ------- - implementing_types : collection of types - Types of arguments from relevant_args with __array_function__ methods. - implementing_args : list - Arguments from relevant_args on which to call __array_function__ - methods, in the order in which they should be called. - """ - # Runtime is O(num_arguments * num_unique_types) - implementing_types = [] - implementing_args = [] - for arg in relevant_args: - arg_type = type(arg) - # We only collect arguments if they have a unique type, which ensures - # reasonable performance even with a long list of possibly overloaded - # arguments. - if (arg_type not in implementing_types and - hasattr(arg_type, '__array_function__')): - - # Create lists explicitly for the first type (usually the only one - # done) to avoid setting up the iterator for implementing_args. - if implementing_types: - implementing_types.append(arg_type) - # By default, insert argument at the end, but if it is - # subclass of another argument, insert it before that argument. - # This ensures "subclasses before superclasses". - index = len(implementing_args) - for i, old_arg in enumerate(implementing_args): - if issubclass(arg_type, type(old_arg)): - index = i - break - implementing_args.insert(index, arg) - else: - implementing_types = [arg_type] - implementing_args = [arg] - - return implementing_types, implementing_args - - -_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ - - -def any_overrides(relevant_args): - """Are there any __array_function__ methods that need to be called?""" - for arg in relevant_args: - arg_type = type(arg) - if (arg_type is not ndarray and - getattr(arg_type, '__array_function__', - _NDARRAY_ARRAY_FUNCTION) - is not _NDARRAY_ARRAY_FUNCTION): - return True - return False - - -_TUPLE_OR_LIST = {tuple, list} - - -def implement_array_function( - implementation, public_api, relevant_args, args, kwargs): +add_docstring( + implement_array_function, """ Implement a function with checks for __array_function__ overrides. @@ -109,28 +43,12 @@ def implement_array_function( Raises ------ TypeError : if no implementation is found. - """ - if type(relevant_args) not in _TUPLE_OR_LIST: - relevant_args = tuple(relevant_args) - - if not any_overrides(relevant_args): - return implementation(*args, **kwargs) + """) - # Call overrides - types, implementing_args = get_implementing_types_and_args(relevant_args) - for arg in implementing_args: - # Use `public_api` instead of `implemenation` so __array_function__ - # implementations can do equality/identity comparisons. - result = arg.__array_function__(public_api, types, args, kwargs) - if result is not NotImplemented: - return result - func_name = '{}.{}'.format(public_api.__module__, public_api.__name__) - raise TypeError("no implementation found for '{}' on types that implement " - '__array_function__: {}'.format(func_name, list(types))) - - -def _get_implementing_args(relevant_args): +# exposed for testing purposes; used internally by implement_array_function +add_docstring( + _get_implementing_args, """ Collect arguments on which to call __array_function__. @@ -144,9 +62,7 @@ def _get_implementing_args(relevant_args): ------- Sequence of arguments with __array_function__ methods, in the order in which they should be called. - """ - _, args = get_implementing_types_and_args(relevant_args) - return args + """) ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') |