summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/neps/nep-0013-ufunc-overrides.rst11
-rw-r--r--doc/release/upcoming_changes/23240.compatibility.rst10
-rw-r--r--doc/source/reference/arrays.classes.rst5
-rw-r--r--numpy/core/src/multiarray/methods.c14
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c5
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.h1
-rw-r--r--numpy/core/src/umath/override.c16
-rw-r--r--numpy/core/src/umath/override.h2
-rw-r--r--numpy/core/src/umath/ufunc_object.c6
-rw-r--r--numpy/core/tests/test_overrides.py13
-rw-r--r--numpy/core/tests/test_umath.py73
11 files changed, 140 insertions, 16 deletions
diff --git a/doc/neps/nep-0013-ufunc-overrides.rst b/doc/neps/nep-0013-ufunc-overrides.rst
index c132113db..d69af6924 100644
--- a/doc/neps/nep-0013-ufunc-overrides.rst
+++ b/doc/neps/nep-0013-ufunc-overrides.rst
@@ -20,6 +20,8 @@ NEP 13 — A mechanism for overriding Ufuncs
:Date: 2017-03-31
:Status: Final
+:Updated: 2023-02-19
+:Author: Roy Smart
Executive summary
=================
@@ -173,12 +175,12 @@ where in all current cases only a single output makes sense).
The function dispatch proceeds as follows:
-- If one of the input or output arguments implements
+- If one of the input, output, or ``where`` arguments implements
``__array_ufunc__``, it is executed instead of the ufunc.
- If more than one of the arguments implements ``__array_ufunc__``,
they are tried in the following order: subclasses before superclasses,
- inputs before outputs, otherwise left to right.
+ inputs before outputs, outputs before ``where``, otherwise left to right.
- The first ``__array_ufunc__`` method returning something else than
:obj:`NotImplemented` determines the return value of the Ufunc.
@@ -326,7 +328,10 @@ equivalent to::
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
# Cannot handle items that have __array_ufunc__ (other than our own).
outputs = kwargs.get('out', ())
- for item in inputs + outputs:
+ objs = inputs + outputs
+ if "where" in kwargs:
+ objs = objs + (kwargs["where"], )
+ for item in objs:
if (hasattr(item, '__array_ufunc__') and
type(item).__array_ufunc__ is not ndarray.__array_ufunc__):
return NotImplemented
diff --git a/doc/release/upcoming_changes/23240.compatibility.rst b/doc/release/upcoming_changes/23240.compatibility.rst
new file mode 100644
index 000000000..28536a020
--- /dev/null
+++ b/doc/release/upcoming_changes/23240.compatibility.rst
@@ -0,0 +1,10 @@
+Array-likes that define ``__array_ufunc__`` can now override ufuncs if used as ``where``
+----------------------------------------------------------------------------------------
+If the ``where`` keyword argument of a :class:`numpy.ufunc` is a subclass of
+:class:`numpy.ndarray` or is a duck type that defines
+:func:`numpy.class.__array_ufunc__` it can override the behavior of the ufunc
+using the same mechanism as the input and output arguments.
+Note that for this to work properly, the ``where.__array_ufunc__``
+implementation will have to unwrap the ``where`` argument to pass it into the
+default implementation of the ``ufunc`` or, for :class:`numpy.ndarray`
+subclasses before using ``super().__array_ufunc__``. \ No newline at end of file
diff --git a/doc/source/reference/arrays.classes.rst b/doc/source/reference/arrays.classes.rst
index 2cce595e0..34da83670 100644
--- a/doc/source/reference/arrays.classes.rst
+++ b/doc/source/reference/arrays.classes.rst
@@ -71,10 +71,11 @@ NumPy provides several hooks that classes can customize:
The method should return either the result of the operation, or
:obj:`NotImplemented` if the operation requested is not implemented.
- If one of the input or output arguments has a :func:`__array_ufunc__`
+ If one of the input, output, or ``where`` arguments has a :func:`__array_ufunc__`
method, it is executed *instead* of the ufunc. If more than one of the
arguments implements :func:`__array_ufunc__`, they are tried in the
- order: subclasses before superclasses, inputs before outputs, otherwise
+ order: subclasses before superclasses, inputs before outputs,
+ outputs before ``where``, otherwise
left to right. The first routine returning something other than
:obj:`NotImplemented` determines the result. If all of the
:func:`__array_ufunc__` operations return :obj:`NotImplemented`, a
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index f518f3a02..93b290020 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -28,6 +28,7 @@
#include "strfuncs.h"
#include "array_assign.h"
#include "npy_dlpack.h"
+#include "multiarraymodule.h"
#include "methods.h"
#include "alloc.h"
@@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
int nin, nout;
PyObject *out_kwd_obj;
PyObject *fast;
- PyObject **in_objs, **out_objs;
+ PyObject **in_objs, **out_objs, *where_obj;
/* check inputs */
nin = PyTuple_Size(args);
@@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
}
}
Py_DECREF(out_kwd_obj);
+ /* check where if it exists */
+ where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where);
+ if (where_obj == NULL) {
+ if (PyErr_Occurred()) {
+ return -1;
+ }
+ } else {
+ if (PyUFunc_HasOverride(where_obj)){
+ return 1;
+ }
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index e85f8affa..ac8e641b7 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL;
+NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL;
static int
intern_strings(void)
@@ -4899,6 +4900,10 @@ intern_strings(void)
if (npy_ma_str_numpy == NULL) {
return -1;
}
+ npy_ma_str_where = PyUnicode_InternFromString("where");
+ if (npy_ma_str_where == NULL) {
+ return -1;
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h
index 992acd09f..9ba2a1831 100644
--- a/numpy/core/src/multiarray/multiarraymodule.h
+++ b/numpy/core/src/multiarray/multiarraymodule.h
@@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy;
+NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where;
#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */
diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c
index d247c2639..167164163 100644
--- a/numpy/core/src/umath/override.c
+++ b/numpy/core/src/umath/override.c
@@ -23,18 +23,19 @@
* Returns -1 on failure.
*/
static int
-get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args,
+get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject **with_override, PyObject **methods)
{
int i;
int num_override_args = 0;
- int narg, nout;
+ int narg, nout, nwhere;
narg = (int)PyTuple_GET_SIZE(in_args);
/* It is valid for out_args to be NULL: */
nout = (out_args != NULL) ? (int)PyTuple_GET_SIZE(out_args) : 0;
+ nwhere = (wheremask_obj != NULL) ? 1: 0;
- for (i = 0; i < narg + nout; ++i) {
+ for (i = 0; i < narg + nout + nwhere; ++i) {
PyObject *obj;
int j;
int new_class = 1;
@@ -42,9 +43,12 @@ get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args,
if (i < narg) {
obj = PyTuple_GET_ITEM(in_args, i);
}
- else {
+ else if (i < narg + nout){
obj = PyTuple_GET_ITEM(out_args, i - narg);
}
+ else {
+ obj = wheremask_obj;
+ }
/*
* Have we seen this class before? If so, ignore.
*/
@@ -208,7 +212,7 @@ copy_positional_args_to_kwargs(const char **keywords,
*/
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *in_args, PyObject *out_args,
+ PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames,
PyObject **result)
{
@@ -227,7 +231,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
* Check inputs for overrides
*/
num_override_args = get_array_ufunc_overrides(
- in_args, out_args, with_override, array_ufunc_methods);
+ in_args, out_args, wheremask_obj, with_override, array_ufunc_methods);
if (num_override_args == -1) {
goto fail;
}
diff --git a/numpy/core/src/umath/override.h b/numpy/core/src/umath/override.h
index 4e9a323ca..20621bb19 100644
--- a/numpy/core/src/umath/override.h
+++ b/numpy/core/src/umath/override.h
@@ -6,7 +6,7 @@
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *in_args, PyObject *out_args,
+ PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames,
PyObject **result);
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index a159003de..a5e8f4cbe 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4071,7 +4071,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc,
/* We now have all the information required to check for Overrides */
PyObject *override = NULL;
int errval = PyUFunc_CheckOverride(ufunc, _reduce_type[operation],
- full_args.in, full_args.out, args, len_args, kwnames, &override);
+ full_args.in, full_args.out, wheremask_obj, args, len_args, kwnames, &override);
if (errval) {
return NULL;
}
@@ -4843,7 +4843,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
/* We now have all the information required to check for Overrides */
PyObject *override = NULL;
errval = PyUFunc_CheckOverride(ufunc, method,
- full_args.in, full_args.out,
+ full_args.in, full_args.out, where_obj,
args, len_args, kwnames, &override);
if (errval) {
goto fail;
@@ -6261,7 +6261,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
return NULL;
}
errval = PyUFunc_CheckOverride(ufunc, "at",
- args, NULL, NULL, 0, NULL, &override);
+ args, NULL, NULL, NULL, 0, NULL, &override);
if (errval) {
return NULL;
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index ae4cddb0e..25f551f6f 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -241,6 +241,19 @@ class TestArrayFunctionDispatch:
with assert_raises_regex(TypeError, 'no implementation found'):
dispatched_one_arg(array)
+ def test_where_dispatch(self):
+
+ class DuckArray:
+ def __array_function__(self, ufunc, method, *inputs, **kwargs):
+ return "overridden"
+
+ array = np.array(1)
+ duck_array = DuckArray()
+
+ result = np.std(array, where=duck_array)
+
+ assert_equal(result, "overridden")
+
class TestVerifyMatchingSignatures:
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 13f7375c2..0ed64d72a 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -3497,6 +3497,79 @@ class TestSpecialMethods:
assert_raises(ValueError, np.modf, a, out=('one', 'two', 'three'))
assert_raises(ValueError, np.modf, a, out=('one',))
+ def test_ufunc_override_where(self):
+
+ class OverriddenArrayOld(np.ndarray):
+
+ def _unwrap(self, objs):
+ cls = type(self)
+ result = []
+ for obj in objs:
+ if isinstance(obj, cls):
+ obj = np.array(obj)
+ elif type(obj) != np.ndarray:
+ return NotImplemented
+ result.append(obj)
+ return result
+
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+
+ inputs = self._unwrap(inputs)
+ if inputs is NotImplemented:
+ return NotImplemented
+
+ kwargs = kwargs.copy()
+ if "out" in kwargs:
+ kwargs["out"] = self._unwrap(kwargs["out"])
+ if kwargs["out"] is NotImplemented:
+ return NotImplemented
+
+ r = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
+ if r is not NotImplemented:
+ r = r.view(type(self))
+
+ return r
+
+ class OverriddenArrayNew(OverriddenArrayOld):
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+
+ kwargs = kwargs.copy()
+ if "where" in kwargs:
+ kwargs["where"] = self._unwrap((kwargs["where"], ))
+ if kwargs["where"] is NotImplemented:
+ return NotImplemented
+ else:
+ kwargs["where"] = kwargs["where"][0]
+
+ r = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
+ if r is not NotImplemented:
+ r = r.view(type(self))
+
+ return r
+
+ ufunc = np.negative
+
+ array = np.array([1, 2, 3])
+ where = np.array([True, False, True])
+ expected = ufunc(array, where=where)
+
+ with pytest.raises(TypeError):
+ ufunc(array, where=where.view(OverriddenArrayOld))
+
+ result_1 = ufunc(
+ array,
+ where=where.view(OverriddenArrayNew)
+ )
+ assert isinstance(result_1, OverriddenArrayNew)
+ assert np.all(np.array(result_1) == expected, where=where)
+
+ result_2 = ufunc(
+ array.view(OverriddenArrayNew),
+ where=where.view(OverriddenArrayNew)
+ )
+ assert isinstance(result_2, OverriddenArrayNew)
+ assert np.all(np.array(result_2) == expected, where=where)
+
def test_ufunc_override_exception(self):
class A: