summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py114
1 files changed, 30 insertions, 84 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 0979858a1..c55174ecd 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -1,73 +1,23 @@
-"""Preliminary implementation of NEP-18
-
-TODO: rewrite this in C for performance.
-"""
+"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os
-from numpy.core._multiarray_umath import add_docstring, ndarray
+from numpy.core._multiarray_umath import (
+ add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
-_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
-_NDARRAY_ONLY = [ndarray]
-
ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
-def get_overloaded_types_and_args(relevant_args):
- """Returns a list of arguments on which to call __array_function__.
-
- Parameters
- ----------
- relevant_args : iterable of array-like
- Iterable of array-like arguments to check for __array_function__
- methods.
-
- Returns
- -------
- overloaded_types : collection of types
- Types of arguments from relevant_args with __array_function__ methods.
- overloaded_args : list
- Arguments from relevant_args on which to call __array_function__
- methods, in the order in which they should be called.
+add_docstring(
+ implement_array_function,
"""
- # Runtime is O(num_arguments * num_unique_types)
- overloaded_types = []
- overloaded_args = []
- for arg in relevant_args:
- arg_type = type(arg)
- # We only collect arguments if they have a unique type, which ensures
- # reasonable performance even with a long list of possibly overloaded
- # arguments.
- if (arg_type not in overloaded_types and
- hasattr(arg_type, '__array_function__')):
-
- # Create lists explicitly for the first type (usually the only one
- # done) to avoid setting up the iterator for overloaded_args.
- if overloaded_types:
- overloaded_types.append(arg_type)
- # By default, insert argument at the end, but if it is
- # subclass of another argument, insert it before that argument.
- # This ensures "subclasses before superclasses".
- index = len(overloaded_args)
- for i, old_arg in enumerate(overloaded_args):
- if issubclass(arg_type, type(old_arg)):
- index = i
- break
- overloaded_args.insert(index, arg)
- else:
- overloaded_types = [arg_type]
- overloaded_args = [arg]
-
- return overloaded_types, overloaded_args
-
-
-def array_function_implementation_or_override(
- implementation, public_api, relevant_args, args, kwargs):
- """Implement a function with checks for __array_function__ overrides.
+ Implement a function with checks for __array_function__ overrides.
+
+ All arguments are required, and can only be passed by position.
Arguments
---------
@@ -82,41 +32,37 @@ def array_function_implementation_or_override(
Iterable of arguments to check for __array_function__ methods.
args : tuple
Arbitrary positional arguments originally passed into ``public_api``.
- kwargs : tuple
+ kwargs : dict
Arbitrary keyword arguments originally passed into ``public_api``.
Returns
-------
- Result from calling `implementation()` or an `__array_function__`
+ Result from calling ``implementation()`` or an ``__array_function__``
method, as appropriate.
Raises
------
TypeError : if no implementation is found.
+ """)
+
+
+# exposed for testing purposes; used internally by implement_array_function
+add_docstring(
+ _get_implementing_args,
"""
- # Check for __array_function__ methods.
- types, overloaded_args = get_overloaded_types_and_args(relevant_args)
- # Short-cut for common cases: no overload or only ndarray overload
- # (directly or with subclasses that do not override __array_function__).
- if (not overloaded_args or types == _NDARRAY_ONLY or
- all(type(arg).__array_function__ is _NDARRAY_ARRAY_FUNCTION
- for arg in overloaded_args)):
- return implementation(*args, **kwargs)
-
- # Call overrides
- for overloaded_arg in overloaded_args:
- # Use `public_api` instead of `implemenation` so __array_function__
- # implementations can do equality/identity comparisons.
- result = overloaded_arg.__array_function__(
- public_api, types, args, kwargs)
-
- if result is not NotImplemented:
- return result
-
- func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
- raise TypeError("no implementation found for '{}' on types that implement "
- '__array_function__: {}'
- .format(func_name, list(map(type, overloaded_args))))
+ Collect arguments on which to call __array_function__.
+
+ Parameters
+ ----------
+ relevant_args : iterable of array-like
+ Iterable of possibly array-like arguments to check for
+ __array_function__ methods.
+
+ Returns
+ -------
+ Sequence of arguments with __array_function__ methods, in the order in
+ which they should be called.
+ """)
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')
@@ -215,7 +161,7 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
- return array_function_implementation_or_override(
+ return implement_array_function(
implementation, public_api, relevant_args, args, kwargs)
if module is not None: