diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2022-11-29 17:56:10 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-11-29 17:56:10 +0100 |
commit | 2cf8c40727b08c029cdaf7fce803d031a60499a2 (patch) | |
tree | a59d6662d17e7bfad1e94b36ab4a77930574c3a0 /numpy/testing | |
parent | 7f0f045625022c3f816911cd80f8635ac2a36f21 (diff) | |
parent | 1a8d3ca45f0a7294784bc200ec436dc8563f654a (diff) | |
download | numpy-2cf8c40727b08c029cdaf7fce803d031a60499a2.tar.gz |
Merge pull request #22533 from ngoldbaum/ufunc-and-function-listing
API: Add numpy.testing.overrides to aid testing of custom array containers
Diffstat (limited to 'numpy/testing')
-rw-r--r-- | numpy/testing/overrides.py | 82 |
1 files changed, 82 insertions, 0 deletions
diff --git a/numpy/testing/overrides.py b/numpy/testing/overrides.py new file mode 100644 index 000000000..d20ed60e5 --- /dev/null +++ b/numpy/testing/overrides.py @@ -0,0 +1,82 @@ +"""Tools for testing implementations of __array_function__ and ufunc overrides + + +""" + +from numpy.core.overrides import ARRAY_FUNCTIONS as _array_functions +from numpy import ufunc as _ufunc +import numpy.core.umath as _umath + +def get_overridable_numpy_ufuncs(): + """List all numpy ufuncs overridable via `__array_ufunc__` + + Parameters + ---------- + None + + Returns + ------- + set + A set containing all overridable ufuncs in the public numpy API. + """ + ufuncs = {obj for obj in _umath.__dict__.values() + if isinstance(obj, _ufunc)} + + +def allows_array_ufunc_override(func): + """Determine if a function can be overriden via `__array_ufunc__` + + Parameters + ---------- + func : callable + Function that may be overridable via `__array_ufunc__` + + Returns + ------- + bool + `True` if `func` is overridable via `__array_ufunc__` and + `False` otherwise. + + Note + ---- + This function is equivalent to `isinstance(func, np.ufunc)` and + will work correctly for ufuncs defined outside of Numpy. + + """ + return isinstance(func, np.ufunc) + + +def get_overridable_numpy_array_functions(): + """List all numpy functions overridable via `__array_function__` + + Parameters + ---------- + None + + Returns + ------- + set + A set containing all functions in the public numpy API that are + overridable via `__array_function__`. + + """ + # 'import numpy' doesn't import recfunctions, so make sure it's imported + # so ufuncs defined there show up in the ufunc listing + from numpy.lib import recfunctions + return _array_functions.copy() + +def allows_array_function_override(func): + """Determine if a Numpy function can be overriden via `__array_function__` + + Parameters + ---------- + func : callable + Function that may be overridable via `__array_function__` + + Returns + ------- + bool + `True` if `func` is a function in the Numpy API that is + overridable via `__array_function__` and `False` otherwise. + """ + return func in _array_functions |