summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py104
1 files changed, 10 insertions, 94 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 8f535b213..c55174ecd 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -1,12 +1,10 @@
-"""Preliminary implementation of NEP-18.
-
-TODO: rewrite this in C for performance.
-"""
+"""Implementation of __array_function__ overrides from NEP-18."""
import collections
import functools
import os
-from numpy.core._multiarray_umath import add_docstring, ndarray
+from numpy.core._multiarray_umath import (
+ add_docstring, implement_array_function, _get_implementing_args)
from numpy.compat._inspect import getargspec
@@ -14,72 +12,8 @@ ENABLE_ARRAY_FUNCTION = bool(
int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 0)))
-def get_implementing_types_and_args(relevant_args):
- """Returns a list of arguments on which to call __array_function__.
- Parameters
- ----------
- relevant_args : iterable of array-like
- Iterable of array-like arguments to check for __array_function__
- methods.
- Returns
- -------
- implementing_types : collection of types
- Types of arguments from relevant_args with __array_function__ methods.
- implementing_args : list
- Arguments from relevant_args on which to call __array_function__
- methods, in the order in which they should be called.
- """
- # Runtime is O(num_arguments * num_unique_types)
- implementing_types = []
- implementing_args = []
- for arg in relevant_args:
- arg_type = type(arg)
- # We only collect arguments if they have a unique type, which ensures
- # reasonable performance even with a long list of possibly overloaded
- # arguments.
- if (arg_type not in implementing_types and
- hasattr(arg_type, '__array_function__')):
-
- # Create lists explicitly for the first type (usually the only one
- # done) to avoid setting up the iterator for implementing_args.
- if implementing_types:
- implementing_types.append(arg_type)
- # By default, insert argument at the end, but if it is
- # subclass of another argument, insert it before that argument.
- # This ensures "subclasses before superclasses".
- index = len(implementing_args)
- for i, old_arg in enumerate(implementing_args):
- if issubclass(arg_type, type(old_arg)):
- index = i
- break
- implementing_args.insert(index, arg)
- else:
- implementing_types = [arg_type]
- implementing_args = [arg]
-
- return implementing_types, implementing_args
-
-
-_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
-
-
-def any_overrides(relevant_args):
- """Are there any __array_function__ methods that need to be called?"""
- for arg in relevant_args:
- arg_type = type(arg)
- if (arg_type is not ndarray and
- getattr(arg_type, '__array_function__',
- _NDARRAY_ARRAY_FUNCTION)
- is not _NDARRAY_ARRAY_FUNCTION):
- return True
- return False
-
-
-_TUPLE_OR_LIST = {tuple, list}
-
-
-def implement_array_function(
- implementation, public_api, relevant_args, args, kwargs):
+add_docstring(
+ implement_array_function,
"""
Implement a function with checks for __array_function__ overrides.
@@ -109,28 +43,12 @@ def implement_array_function(
Raises
------
TypeError : if no implementation is found.
- """
- if type(relevant_args) not in _TUPLE_OR_LIST:
- relevant_args = tuple(relevant_args)
-
- if not any_overrides(relevant_args):
- return implementation(*args, **kwargs)
+ """)
- # Call overrides
- types, implementing_args = get_implementing_types_and_args(relevant_args)
- for arg in implementing_args:
- # Use `public_api` instead of `implemenation` so __array_function__
- # implementations can do equality/identity comparisons.
- result = arg.__array_function__(public_api, types, args, kwargs)
- if result is not NotImplemented:
- return result
- func_name = '{}.{}'.format(public_api.__module__, public_api.__name__)
- raise TypeError("no implementation found for '{}' on types that implement "
- '__array_function__: {}'.format(func_name, list(types)))
-
-
-def _get_implementing_args(relevant_args):
+# exposed for testing purposes; used internally by implement_array_function
+add_docstring(
+ _get_implementing_args,
"""
Collect arguments on which to call __array_function__.
@@ -144,9 +62,7 @@ def _get_implementing_args(relevant_args):
-------
Sequence of arguments with __array_function__ methods, in the order in
which they should be called.
- """
- _, args = get_implementing_types_and_args(relevant_args)
- return args
+ """)
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')