summaryrefslogtreecommitdiff
path: root/numpy/core/overrides.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r--numpy/core/overrides.py112
1 files changed, 41 insertions, 71 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py
index 46e1fbe2c..6403e65b0 100644
--- a/numpy/core/overrides.py
+++ b/numpy/core/overrides.py
@@ -6,14 +6,11 @@ import os
from .._utils import set_module
from .._utils._inspect import getargspec
from numpy.core._multiarray_umath import (
- add_docstring, implement_array_function, _get_implementing_args)
+ add_docstring, _get_implementing_args, _ArrayFunctionDispatcher)
ARRAY_FUNCTIONS = set()
-ARRAY_FUNCTION_ENABLED = bool(
- int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1)))
-
array_function_like_doc = (
"""like : array_like, optional
Reference object to allow the creation of arrays which are not
@@ -33,40 +30,35 @@ def set_array_function_like_doc(public_api):
add_docstring(
- implement_array_function,
+ _ArrayFunctionDispatcher,
"""
- Implement a function with checks for __array_function__ overrides.
+ Class to wrap functions with checks for __array_function__ overrides.
All arguments are required, and can only be passed by position.
Parameters
----------
+ dispatcher : function or None
+ The dispatcher function that returns a single sequence-like object
+ of all arguments relevant. It must have the same signature (except
+ the default values) as the actual implementation.
+ If ``None``, this is a ``like=`` dispatcher and the
+ ``_ArrayFunctionDispatcher`` must be called with ``like`` as the
+ first (additional and positional) argument.
implementation : function
- Function that implements the operation on NumPy array without
- overrides when called like ``implementation(*args, **kwargs)``.
- public_api : function
- Function exposed by NumPy's public API originally called like
- ``public_api(*args, **kwargs)`` on which arguments are now being
- checked.
- relevant_args : iterable
- Iterable of arguments to check for __array_function__ methods.
- args : tuple
- Arbitrary positional arguments originally passed into ``public_api``.
- kwargs : dict
- Arbitrary keyword arguments originally passed into ``public_api``.
-
- Returns
- -------
- Result from calling ``implementation()`` or an ``__array_function__``
- method, as appropriate.
+ Function that implements the operation on NumPy arrays without
+ overrides. Arguments passed calling the ``_ArrayFunctionDispatcher``
+ will be forwarded to this (and the ``dispatcher``) as if using
+ ``*args, **kwargs``.
- Raises
- ------
- TypeError : if no implementation is found.
+ Attributes
+ ----------
+ _implementation : function
+ The original implementation passed in.
""")
-# exposed for testing purposes; used internally by implement_array_function
+# exposed for testing purposes; used internally by _ArrayFunctionDispatcher
add_docstring(
_get_implementing_args,
"""
@@ -110,7 +102,7 @@ def verify_matching_signatures(implementation, dispatcher):
'default argument values')
-def array_function_dispatch(dispatcher, module=None, verify=True,
+def array_function_dispatch(dispatcher=None, module=None, verify=True,
docs_from_dispatcher=False):
"""Decorator for adding dispatch with the __array_function__ protocol.
@@ -118,10 +110,14 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
Parameters
----------
- dispatcher : callable
+ dispatcher : callable or None
Function that when called like ``dispatcher(*args, **kwargs)`` with
arguments from the NumPy function call returns an iterable of
array-like arguments to check for ``__array_function__``.
+
+ If `None`, the first argument is used as the single `like=` argument
+ and not passed on. A function implementing `like=` must call its
+ dispatcher with `like` as the first non-keyword argument.
module : str, optional
__module__ attribute to set on new function, e.g., ``module='numpy'``.
By default, module is copied from the decorated function.
@@ -141,58 +137,32 @@ def array_function_dispatch(dispatcher, module=None, verify=True,
Returns
-------
Function suitable for decorating the implementation of a NumPy function.
- """
-
- if not ARRAY_FUNCTION_ENABLED:
- def decorator(implementation):
- if docs_from_dispatcher:
- add_docstring(implementation, dispatcher.__doc__)
- if module is not None:
- implementation.__module__ = module
- return implementation
- return decorator
+ """
def decorator(implementation):
if verify:
- verify_matching_signatures(implementation, dispatcher)
+ if dispatcher is not None:
+ verify_matching_signatures(implementation, dispatcher)
+ else:
+ # Using __code__ directly similar to verify_matching_signature
+ co = implementation.__code__
+ last_arg = co.co_argcount + co.co_kwonlyargcount - 1
+ last_arg = co.co_varnames[last_arg]
+ if last_arg != "like" or co.co_kwonlyargcount == 0:
+ raise RuntimeError(
+ "__array_function__ expects `like=` to be the last "
+ "argument and a keyword-only argument. "
+ f"{implementation} does not seem to comply.")
if docs_from_dispatcher:
add_docstring(implementation, dispatcher.__doc__)
- @functools.wraps(implementation)
- def public_api(*args, **kwargs):
- try:
- relevant_args = dispatcher(*args, **kwargs)
- except TypeError as exc:
- # Try to clean up a signature related TypeError. Such an
- # error will be something like:
- # dispatcher.__name__() got an unexpected keyword argument
- #
- # So replace the dispatcher name in this case. In principle
- # TypeErrors may be raised from _within_ the dispatcher, so
- # we check that the traceback contains a string that starts
- # with the name. (In principle we could also check the
- # traceback length, as it would be deeper.)
- msg = exc.args[0]
- disp_name = dispatcher.__name__
- if not isinstance(msg, str) or not msg.startswith(disp_name):
- raise
-
- # Replace with the correct name and re-raise:
- new_msg = msg.replace(disp_name, public_api.__name__)
- raise TypeError(new_msg) from None
-
- return implement_array_function(
- implementation, public_api, relevant_args, args, kwargs)
-
- public_api.__code__ = public_api.__code__.replace(
- co_name=implementation.__name__,
- co_filename='<__array_function__ internals>')
+ public_api = _ArrayFunctionDispatcher(dispatcher, implementation)
+ public_api = functools.wraps(implementation)(public_api)
+
if module is not None:
public_api.__module__ = module
- public_api._implementation = implementation
-
ARRAY_FUNCTIONS.add(public_api)
return public_api