diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 112 |
1 files changed, 41 insertions, 71 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 46e1fbe2c..6403e65b0 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -6,14 +6,11 @@ import os from .._utils import set_module from .._utils._inspect import getargspec from numpy.core._multiarray_umath import ( - add_docstring, implement_array_function, _get_implementing_args) + add_docstring, _get_implementing_args, _ArrayFunctionDispatcher) ARRAY_FUNCTIONS = set() -ARRAY_FUNCTION_ENABLED = bool( - int(os.environ.get('NUMPY_EXPERIMENTAL_ARRAY_FUNCTION', 1))) - array_function_like_doc = ( """like : array_like, optional Reference object to allow the creation of arrays which are not @@ -33,40 +30,35 @@ def set_array_function_like_doc(public_api): add_docstring( - implement_array_function, + _ArrayFunctionDispatcher, """ - Implement a function with checks for __array_function__ overrides. + Class to wrap functions with checks for __array_function__ overrides. All arguments are required, and can only be passed by position. Parameters ---------- + dispatcher : function or None + The dispatcher function that returns a single sequence-like object + of all arguments relevant. It must have the same signature (except + the default values) as the actual implementation. + If ``None``, this is a ``like=`` dispatcher and the + ``_ArrayFunctionDispatcher`` must be called with ``like`` as the + first (additional and positional) argument. implementation : function - Function that implements the operation on NumPy array without - overrides when called like ``implementation(*args, **kwargs)``. - public_api : function - Function exposed by NumPy's public API originally called like - ``public_api(*args, **kwargs)`` on which arguments are now being - checked. - relevant_args : iterable - Iterable of arguments to check for __array_function__ methods. - args : tuple - Arbitrary positional arguments originally passed into ``public_api``. - kwargs : dict - Arbitrary keyword arguments originally passed into ``public_api``. - - Returns - ------- - Result from calling ``implementation()`` or an ``__array_function__`` - method, as appropriate. + Function that implements the operation on NumPy arrays without + overrides. Arguments passed calling the ``_ArrayFunctionDispatcher`` + will be forwarded to this (and the ``dispatcher``) as if using + ``*args, **kwargs``. - Raises - ------ - TypeError : if no implementation is found. + Attributes + ---------- + _implementation : function + The original implementation passed in. """) -# exposed for testing purposes; used internally by implement_array_function +# exposed for testing purposes; used internally by _ArrayFunctionDispatcher add_docstring( _get_implementing_args, """ @@ -110,7 +102,7 @@ def verify_matching_signatures(implementation, dispatcher): 'default argument values') -def array_function_dispatch(dispatcher, module=None, verify=True, +def array_function_dispatch(dispatcher=None, module=None, verify=True, docs_from_dispatcher=False): """Decorator for adding dispatch with the __array_function__ protocol. @@ -118,10 +110,14 @@ def array_function_dispatch(dispatcher, module=None, verify=True, Parameters ---------- - dispatcher : callable + dispatcher : callable or None Function that when called like ``dispatcher(*args, **kwargs)`` with arguments from the NumPy function call returns an iterable of array-like arguments to check for ``__array_function__``. + + If `None`, the first argument is used as the single `like=` argument + and not passed on. A function implementing `like=` must call its + dispatcher with `like` as the first non-keyword argument. module : str, optional __module__ attribute to set on new function, e.g., ``module='numpy'``. By default, module is copied from the decorated function. @@ -141,58 +137,32 @@ def array_function_dispatch(dispatcher, module=None, verify=True, Returns ------- Function suitable for decorating the implementation of a NumPy function. - """ - - if not ARRAY_FUNCTION_ENABLED: - def decorator(implementation): - if docs_from_dispatcher: - add_docstring(implementation, dispatcher.__doc__) - if module is not None: - implementation.__module__ = module - return implementation - return decorator + """ def decorator(implementation): if verify: - verify_matching_signatures(implementation, dispatcher) + if dispatcher is not None: + verify_matching_signatures(implementation, dispatcher) + else: + # Using __code__ directly similar to verify_matching_signature + co = implementation.__code__ + last_arg = co.co_argcount + co.co_kwonlyargcount - 1 + last_arg = co.co_varnames[last_arg] + if last_arg != "like" or co.co_kwonlyargcount == 0: + raise RuntimeError( + "__array_function__ expects `like=` to be the last " + "argument and a keyword-only argument. " + f"{implementation} does not seem to comply.") if docs_from_dispatcher: add_docstring(implementation, dispatcher.__doc__) - @functools.wraps(implementation) - def public_api(*args, **kwargs): - try: - relevant_args = dispatcher(*args, **kwargs) - except TypeError as exc: - # Try to clean up a signature related TypeError. Such an - # error will be something like: - # dispatcher.__name__() got an unexpected keyword argument - # - # So replace the dispatcher name in this case. In principle - # TypeErrors may be raised from _within_ the dispatcher, so - # we check that the traceback contains a string that starts - # with the name. (In principle we could also check the - # traceback length, as it would be deeper.) - msg = exc.args[0] - disp_name = dispatcher.__name__ - if not isinstance(msg, str) or not msg.startswith(disp_name): - raise - - # Replace with the correct name and re-raise: - new_msg = msg.replace(disp_name, public_api.__name__) - raise TypeError(new_msg) from None - - return implement_array_function( - implementation, public_api, relevant_args, args, kwargs) - - public_api.__code__ = public_api.__code__.replace( - co_name=implementation.__name__, - co_filename='<__array_function__ internals>') + public_api = _ArrayFunctionDispatcher(dispatcher, implementation) + public_api = functools.wraps(implementation)(public_api) + if module is not None: public_api.__module__ = module - public_api._implementation = implementation - ARRAY_FUNCTIONS.add(public_api) return public_api |