summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
authorStephan Hoyer <shoyer@google.com>2018-12-19 10:08:01 -0800
committerStephan Hoyer <shoyer@google.com>2018-12-19 10:46:40 -0800
commit45413455791c77465c5f33a5082053274eb18900 (patch)
tree7362d739f4558b4aece09bbde333df5c219d8f1a /numpy/core/overrides.py
parentf4ddc2b3251d4606cc227761933e991131425416 (diff)
downloadnumpy-45413455791c77465c5f33a5082053274eb18900.tar.gz
ENH: refactor __array_function__ pure Python implementation
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py108
1 files changed, 69 insertions, 39 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 0979858a1..8f535b213 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -1,4 +1,4 @@
-"""Preliminary implementation of NEP-18
+"""Preliminary implementation of NEP-18.
TODO: rewrite this in C for performance.
"""
@@ -10,64 +10,80 @@ from numpy.core._multiarray_umath import add_docstring, ndarray
from numpy.compat._inspect import getargspec
-_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
-_NDARRAY_ONLY = [ndarray]
-
ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
-def get_overloaded_types_and_args(relevant_args):
+def get_implementing_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.
-
Parameters
----------
relevant_args : iterable of array-like
Iterable of array-like arguments to check for __array_function__
methods.
-
Returns
-------
- overloaded_types : collection of types
+ implementing_types : collection of types
Types of arguments from relevant_args with __array_function__ methods.
- overloaded_args : list
+ implementing_args : list
Arguments from relevant_args on which to call __array_function__
methods, in the order in which they should be called.
"""
# Runtime is O(num_arguments * num_unique_types)
- overloaded_types = []
- overloaded_args = []
+ implementing_types = []
+ implementing_args = []
for arg in relevant_args:
arg_type = type(arg)
# We only collect arguments if they have a unique type, which ensures
# reasonable performance even with a long list of possibly overloaded
# arguments.
- if (arg_type not in overloaded_types and
+ if (arg_type not in implementing_types and
hasattr(arg_type, '__array_function__')):
# Create lists explicitly for the first type (usually the only one
- # done) to avoid setting up the iterator for overloaded_args.
- if overloaded_types:
- overloaded_types.append(arg_type)
+ # done) to avoid setting up the iterator for implementing_args.
+ if implementing_types:
+ implementing_types.append(arg_type)
# By default, insert argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
- index = len(overloaded_args)
- for i, old_arg in enumerate(overloaded_args):
+ index = len(implementing_args)
+ for i, old_arg in enumerate(implementing_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
- overloaded_args.insert(index, arg)
+ implementing_args.insert(index, arg)
else:
- overloaded_types = [arg_type]
- overloaded_args = [arg]
+ implementing_types = [arg_type]
+ implementing_args = [arg]
+
+ return implementing_types, implementing_args
+
+
+_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
+
- return overloaded_types, overloaded_args
+def any_overrides(relevant_args):
+ """Are there any __array_function__ methods that need to be called?"""
+ for arg in relevant_args:
+ arg_type = type(arg)
+ if (arg_type is not ndarray and
+ getattr(arg_type, '__array_function__',
+ _NDARRAY_ARRAY_FUNCTION)
+ is not _NDARRAY_ARRAY_FUNCTION):
+ return True
+ return False
-def array_function_implementation_or_override(
+_TUPLE_OR_LIST = {tuple, list}
+
+
+def implement_array_function(
implementation, public_api, relevant_args, args, kwargs):
- """Implement a function with checks for __array_function__ overrides.
+ """
+ Implement a function with checks for __array_function__ overrides.
+
+ All arguments are required, and can only be passed by position.
Arguments
---------
@@ -82,41 +98,55 @@ def array_function_implementation_or_override(
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
- kwargs : tuple
+ kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
- Result from calling `implementation()` or an `__array_function__`
+ Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
"""
- # Check for __array_function__ methods.
- types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- # Short-cut for common cases: no overload or only ndarray overload
- # (directly or with subclasses that do not override __array_function__).
- if (not overloaded_args or types == _NDARRAY_ONLY or
- all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
- for arg in overloaded_args)):
+ if type(relevant_args) not in _TUPLE_OR_LIST:
+ relevant_args = tuple(relevant_args)
+
+ if not any_overrides(relevant_args):
return implementation(*args, **kwargs)
# Call overrides
- for overloaded_arg in overloaded_args:
+ types, implementing_args = get_implementing_types_and_args(relevant_args)
+ for arg in implementing_args:
# Use `public_api` instead of `implemenation` so __array_function__
# implementations can do equality/identity comparisons.
- result = overloaded_arg.__array_function__(
- public_api, types, args, kwargs)
-
+ result = arg.__array_function__(public_api, types, args, kwargs)
if result is not NotImplemented:
return result
func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
raise TypeError("no implementation found for '{}' on types that implement "
- '__array_function__: {}'
- .format(func_name, list(map(type, overloaded_args))))
+ '__array_function__: {}'.format(func_name, list(types)))
+
+
+def _get_implementing_args(relevant_args):
+ """
+ Collect arguments on which to call __array_function__.
+
+ Parameters
+ ----------
+ relevant_args : iterable of array-like
+ Iterable of possibly array-like arguments to check for
+ __array_function__ methods.
+
+ Returns
+ -------
+ Sequence of arguments with __array_function__ methods, in the order in
+ which they should be called.
+ """
+ _, args = get_implementing_types_and_args(relevant_args)
+ return args
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
@@ -215,7 +245,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
- return array_function_implementation_or_override(
+ return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
if module is not None: