diff options
-rw-r--r-- | doc/neps/nep-0013-ufunc-overrides.rst | 11 | ||||
-rw-r--r-- | doc/release/upcoming_changes/23240.compatibility.rst | 10 | ||||
-rw-r--r-- | doc/source/reference/arrays.classes.rst | 5 | ||||
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 14 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 5 | ||||
-rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.h | 1 | ||||
-rw-r--r-- | numpy/core/src/umath/override.c | 16 | ||||
-rw-r--r-- | numpy/core/src/umath/override.h | 2 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 6 | ||||
-rw-r--r-- | numpy/core/tests/test_overrides.py | 13 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 73 |
11 files changed, 140 insertions, 16 deletions
diff --git a/doc/neps/nep-0013-ufunc-overrides.rst b/doc/neps/nep-0013-ufunc-overrides.rst index c132113db..d69af6924 100644 --- a/doc/neps/nep-0013-ufunc-overrides.rst +++ b/doc/neps/nep-0013-ufunc-overrides.rst @@ -20,6 +20,8 @@ NEP 13 — A mechanism for overriding Ufuncs :Date: 2017-03-31 :Status: Final +:Updated: 2023-02-19 +:Author: Roy Smart Executive summary ================= @@ -173,12 +175,12 @@ where in all current cases only a single output makes sense). The function dispatch proceeds as follows: -- If one of the input or output arguments implements +- If one of the input, output, or ``where`` arguments implements ``__array_ufunc__``, it is executed instead of the ufunc. - If more than one of the arguments implements ``__array_ufunc__``, they are tried in the following order: subclasses before superclasses, - inputs before outputs, otherwise left to right. + inputs before outputs, outputs before ``where``, otherwise left to right. - The first ``__array_ufunc__`` method returning something else than :obj:`NotImplemented` determines the return value of the Ufunc. @@ -326,7 +328,10 @@ equivalent to:: def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): # Cannot handle items that have __array_ufunc__ (other than our own). outputs = kwargs.get('out', ()) - for item in inputs + outputs: + objs = inputs + outputs + if "where" in kwargs: + objs = objs + (kwargs["where"], ) + for item in objs: if (hasattr(item, '__array_ufunc__') and type(item).__array_ufunc__ is not ndarray.__array_ufunc__): return NotImplemented diff --git a/doc/release/upcoming_changes/23240.compatibility.rst b/doc/release/upcoming_changes/23240.compatibility.rst new file mode 100644 index 000000000..28536a020 --- /dev/null +++ b/doc/release/upcoming_changes/23240.compatibility.rst @@ -0,0 +1,10 @@ +Array-likes that define ``__array_ufunc__`` can now override ufuncs if used as ``where`` +---------------------------------------------------------------------------------------- +If the ``where`` keyword argument of a :class:`numpy.ufunc` is a subclass of +:class:`numpy.ndarray` or is a duck type that defines +:func:`numpy.class.__array_ufunc__` it can override the behavior of the ufunc +using the same mechanism as the input and output arguments. +Note that for this to work properly, the ``where.__array_ufunc__`` +implementation will have to unwrap the ``where`` argument to pass it into the +default implementation of the ``ufunc`` or, for :class:`numpy.ndarray` +subclasses before using ``super().__array_ufunc__``.
\ No newline at end of file diff --git a/doc/source/reference/arrays.classes.rst b/doc/source/reference/arrays.classes.rst index 2cce595e0..34da83670 100644 --- a/doc/source/reference/arrays.classes.rst +++ b/doc/source/reference/arrays.classes.rst @@ -71,10 +71,11 @@ NumPy provides several hooks that classes can customize: The method should return either the result of the operation, or :obj:`NotImplemented` if the operation requested is not implemented. - If one of the input or output arguments has a :func:`__array_ufunc__` + If one of the input, output, or ``where`` arguments has a :func:`__array_ufunc__` method, it is executed *instead* of the ufunc. If more than one of the arguments implements :func:`__array_ufunc__`, they are tried in the - order: subclasses before superclasses, inputs before outputs, otherwise + order: subclasses before superclasses, inputs before outputs, + outputs before ``where``, otherwise left to right. The first routine returning something other than :obj:`NotImplemented` determines the result. If all of the :func:`__array_ufunc__` operations return :obj:`NotImplemented`, a diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index f518f3a02..93b290020 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -28,6 +28,7 @@ #include "strfuncs.h" #include "array_assign.h" #include "npy_dlpack.h" +#include "multiarraymodule.h" #include "methods.h" #include "alloc.h" @@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds) int nin, nout; PyObject *out_kwd_obj; PyObject *fast; - PyObject **in_objs, **out_objs; + PyObject **in_objs, **out_objs, *where_obj; /* check inputs */ nin = PyTuple_Size(args); @@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds) } } Py_DECREF(out_kwd_obj); + /* check where if it exists */ + where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where); + if (where_obj == NULL) { + if (PyErr_Occurred()) { + return -1; + } + } else { + if (PyUFunc_HasOverride(where_obj)){ + return 1; + } + } return 0; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index e85f8affa..ac8e641b7 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL; +NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL; static int intern_strings(void) @@ -4899,6 +4900,10 @@ intern_strings(void) if (npy_ma_str_numpy == NULL) { return -1; } + npy_ma_str_where = PyUnicode_InternFromString("where"); + if (npy_ma_str_where == NULL) { + return -1; + } return 0; } diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h index 992acd09f..9ba2a1831 100644 --- a/numpy/core/src/multiarray/multiarraymodule.h +++ b/numpy/core/src/multiarray/multiarraymodule.h @@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy; +NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where; #endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */ diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c index d247c2639..167164163 100644 --- a/numpy/core/src/umath/override.c +++ b/numpy/core/src/umath/override.c @@ -23,18 +23,19 @@ * Returns -1 on failure. */ static int -get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, +get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj, PyObject **with_override, PyObject **methods) { int i; int num_override_args = 0; - int narg, nout; + int narg, nout, nwhere; narg = (int)PyTuple_GET_SIZE(in_args); /* It is valid for out_args to be NULL: */ nout = (out_args != NULL) ? (int)PyTuple_GET_SIZE(out_args) : 0; + nwhere = (wheremask_obj != NULL) ? 1: 0; - for (i = 0; i < narg + nout; ++i) { + for (i = 0; i < narg + nout + nwhere; ++i) { PyObject *obj; int j; int new_class = 1; @@ -42,9 +43,12 @@ get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, if (i < narg) { obj = PyTuple_GET_ITEM(in_args, i); } - else { + else if (i < narg + nout){ obj = PyTuple_GET_ITEM(out_args, i - narg); } + else { + obj = wheremask_obj; + } /* * Have we seen this class before? If so, ignore. */ @@ -208,7 +212,7 @@ copy_positional_args_to_kwargs(const char **keywords, */ NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, - PyObject *in_args, PyObject *out_args, + PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames, PyObject **result) { @@ -227,7 +231,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, * Check inputs for overrides */ num_override_args = get_array_ufunc_overrides( - in_args, out_args, with_override, array_ufunc_methods); + in_args, out_args, wheremask_obj, with_override, array_ufunc_methods); if (num_override_args == -1) { goto fail; } diff --git a/numpy/core/src/umath/override.h b/numpy/core/src/umath/override.h index 4e9a323ca..20621bb19 100644 --- a/numpy/core/src/umath/override.h +++ b/numpy/core/src/umath/override.h @@ -6,7 +6,7 @@ NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, - PyObject *in_args, PyObject *out_args, + PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj, PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames, PyObject **result); diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index a159003de..a5e8f4cbe 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -4071,7 +4071,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, /* We now have all the information required to check for Overrides */ PyObject *override = NULL; int errval = PyUFunc_CheckOverride(ufunc, _reduce_type[operation], - full_args.in, full_args.out, args, len_args, kwnames, &override); + full_args.in, full_args.out, wheremask_obj, args, len_args, kwnames, &override); if (errval) { return NULL; } @@ -4843,7 +4843,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc, /* We now have all the information required to check for Overrides */ PyObject *override = NULL; errval = PyUFunc_CheckOverride(ufunc, method, - full_args.in, full_args.out, + full_args.in, full_args.out, where_obj, args, len_args, kwnames, &override); if (errval) { goto fail; @@ -6261,7 +6261,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) return NULL; } errval = PyUFunc_CheckOverride(ufunc, "at", - args, NULL, NULL, 0, NULL, &override); + args, NULL, NULL, NULL, 0, NULL, &override); if (errval) { return NULL; diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index ae4cddb0e..25f551f6f 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -241,6 +241,19 @@ class TestArrayFunctionDispatch: with assert_raises_regex(TypeError, 'no implementation found'): dispatched_one_arg(array) + def test_where_dispatch(self): + + class DuckArray: + def __array_function__(self, ufunc, method, *inputs, **kwargs): + return "overridden" + + array = np.array(1) + duck_array = DuckArray() + + result = np.std(array, where=duck_array) + + assert_equal(result, "overridden") + class TestVerifyMatchingSignatures: diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 13f7375c2..0ed64d72a 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -3497,6 +3497,79 @@ class TestSpecialMethods: assert_raises(ValueError, np.modf, a, out=('one', 'two', 'three')) assert_raises(ValueError, np.modf, a, out=('one',)) + def test_ufunc_override_where(self): + + class OverriddenArrayOld(np.ndarray): + + def _unwrap(self, objs): + cls = type(self) + result = [] + for obj in objs: + if isinstance(obj, cls): + obj = np.array(obj) + elif type(obj) != np.ndarray: + return NotImplemented + result.append(obj) + return result + + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + + inputs = self._unwrap(inputs) + if inputs is NotImplemented: + return NotImplemented + + kwargs = kwargs.copy() + if "out" in kwargs: + kwargs["out"] = self._unwrap(kwargs["out"]) + if kwargs["out"] is NotImplemented: + return NotImplemented + + r = super().__array_ufunc__(ufunc, method, *inputs, **kwargs) + if r is not NotImplemented: + r = r.view(type(self)) + + return r + + class OverriddenArrayNew(OverriddenArrayOld): + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + + kwargs = kwargs.copy() + if "where" in kwargs: + kwargs["where"] = self._unwrap((kwargs["where"], )) + if kwargs["where"] is NotImplemented: + return NotImplemented + else: + kwargs["where"] = kwargs["where"][0] + + r = super().__array_ufunc__(ufunc, method, *inputs, **kwargs) + if r is not NotImplemented: + r = r.view(type(self)) + + return r + + ufunc = np.negative + + array = np.array([1, 2, 3]) + where = np.array([True, False, True]) + expected = ufunc(array, where=where) + + with pytest.raises(TypeError): + ufunc(array, where=where.view(OverriddenArrayOld)) + + result_1 = ufunc( + array, + where=where.view(OverriddenArrayNew) + ) + assert isinstance(result_1, OverriddenArrayNew) + assert np.all(np.array(result_1) == expected, where=where) + + result_2 = ufunc( + array.view(OverriddenArrayNew), + where=where.view(OverriddenArrayNew) + ) + assert isinstance(result_2, OverriddenArrayNew) + assert np.all(np.array(result_2) == expected, where=where) + def test_ufunc_override_exception(self): class A: |