diff options
Diffstat (limited to 'numpy/core/overrides.py')
-rw-r--r-- | numpy/core/overrides.py | 31 |
1 files changed, 28 insertions, 3 deletions
diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 1cc1ff8d8..bddf14310 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -6,7 +6,7 @@ import collections import functools import os -from numpy.core._multiarray_umath import ndarray +from numpy.core._multiarray_umath import add_docstring, ndarray from numpy.compat._inspect import getargspec @@ -168,7 +168,8 @@ def set_module(module): return decorator -def array_function_dispatch(dispatcher, module=None, verify=True): +def array_function_dispatch(dispatcher, module=None, verify=True, + docs_from_dispatcher=False): """Decorator for adding dispatch with the __array_function__ protocol. See NEP-18 for example usage. @@ -190,6 +191,10 @@ def array_function_dispatch(dispatcher, module=None, verify=True): if the dispatcher's signature needs to deviate for some particular reason, e.g., because the function has a signature like ``func(*args, **kwargs)``. + docs_from_dispatcher : bool, optional + If True, copy docs from the dispatcher function onto the dispatched + function, rather than from the implementation. This is useful for + functions defined in C, which otherwise don't have docstrings. Returns ------- @@ -198,12 +203,21 @@ def array_function_dispatch(dispatcher, module=None, verify=True): if not ENABLE_ARRAY_FUNCTION: # __array_function__ requires an explicit opt-in for now - return set_module(module) + def decorator(implementation): + if module is not None: + implementation.__module__ = module + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + return implementation + return decorator def decorator(implementation): if verify: verify_matching_signatures(implementation, dispatcher) + if docs_from_dispatcher: + add_docstring(implementation, dispatcher.__doc__) + @functools.wraps(implementation) def public_api(*args, **kwargs): relevant_args = dispatcher(*args, **kwargs) @@ -220,3 +234,14 @@ def array_function_dispatch(dispatcher, module=None, verify=True): return public_api return decorator + + +def array_function_from_dispatcher( + implementation, module=None, verify=True, docs_from_dispatcher=True): + """Like array_function_dispatcher, but with function arguments flipped.""" + + def decorator(dispatcher): + return array_function_dispatch( + dispatcher, module, verify=verify, + docs_from_dispatcher=docs_from_dispatcher)(implementation) + return decorator |