summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorRoy Smart <roytsmart@gmail.com>2023-02-17 15:29:41 -0700
committerRoy Smart <roytsmart@gmail.com>2023-02-24 01:21:55 -0700
commitf3f108d313a8b8a4f7a90fb932867f17dc48b1f6 (patch)
tree95738be8cf2ae971bc40d6820328972bbf53a7f9 /numpy
parentd92cc2d1c7c7153525e03c4d10377714d85cfde6 (diff)
downloadnumpy-f3f108d313a8b8a4f7a90fb932867f17dc48b1f6.tar.gz
ENH: Modified `PyUFunc_CheckOverride` to allow the `where` argument to override `__array_ufunc__`.
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/methods.c14
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c5
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.h1
-rw-r--r--numpy/core/src/umath/override.c16
-rw-r--r--numpy/core/src/umath/override.h2
-rw-r--r--numpy/core/src/umath/ufunc_object.c6
-rw-r--r--numpy/core/tests/test_overrides.py13
-rw-r--r--numpy/core/tests/test_umath.py73
8 files changed, 119 insertions, 11 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index f518f3a02..93b290020 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -28,6 +28,7 @@
#include "strfuncs.h"
#include "array_assign.h"
#include "npy_dlpack.h"
+#include "multiarraymodule.h"
#include "methods.h"
#include "alloc.h"
@@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
int nin, nout;
PyObject *out_kwd_obj;
PyObject *fast;
- PyObject **in_objs, **out_objs;
+ PyObject **in_objs, **out_objs, *where_obj;
/* check inputs */
nin = PyTuple_Size(args);
@@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
}
}
Py_DECREF(out_kwd_obj);
+ /* check where if it exists */
+ where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where);
+ if (where_obj == NULL) {
+ if (PyErr_Occurred()) {
+ return -1;
+ }
+ } else {
+ if (PyUFunc_HasOverride(where_obj)){
+ return 1;
+ }
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index e85f8affa..ac8e641b7 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL;
+NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL;
static int
intern_strings(void)
@@ -4899,6 +4900,10 @@ intern_strings(void)
if (npy_ma_str_numpy == NULL) {
return -1;
}
+ npy_ma_str_where = PyUnicode_InternFromString("where");
+ if (npy_ma_str_where == NULL) {
+ return -1;
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h
index 992acd09f..9ba2a1831 100644
--- a/numpy/core/src/multiarray/multiarraymodule.h
+++ b/numpy/core/src/multiarray/multiarraymodule.h
@@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy;
+NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where;
#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */
diff --git a/numpy/core/src/umath/override.c b/numpy/core/src/umath/override.c
index d247c2639..167164163 100644
--- a/numpy/core/src/umath/override.c
+++ b/numpy/core/src/umath/override.c
@@ -23,18 +23,19 @@
* Returns -1 on failure.
*/
static int
-get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args,
+get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject **with_override, PyObject **methods)
{
int i;
int num_override_args = 0;
- int narg, nout;
+ int narg, nout, nwhere;
narg = (int)PyTuple_GET_SIZE(in_args);
/* It is valid for out_args to be NULL: */
nout = (out_args != NULL) ? (int)PyTuple_GET_SIZE(out_args) : 0;
+ nwhere = (wheremask_obj != NULL) ? 1: 0;
- for (i = 0; i < narg + nout; ++i) {
+ for (i = 0; i < narg + nout + nwhere; ++i) {
PyObject *obj;
int j;
int new_class = 1;
@@ -42,9 +43,12 @@ get_array_ufunc_overrides(PyObject *in_args, PyObject *out_args,
if (i < narg) {
obj = PyTuple_GET_ITEM(in_args, i);
}
- else {
+ else if (i < narg + nout){
obj = PyTuple_GET_ITEM(out_args, i - narg);
}
+ else {
+ obj = wheremask_obj;
+ }
/*
* Have we seen this class before? If so, ignore.
*/
@@ -208,7 +212,7 @@ copy_positional_args_to_kwargs(const char **keywords,
*/
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *in_args, PyObject *out_args,
+ PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames,
PyObject **result)
{
@@ -227,7 +231,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
* Check inputs for overrides
*/
num_override_args = get_array_ufunc_overrides(
- in_args, out_args, with_override, array_ufunc_methods);
+ in_args, out_args, wheremask_obj, with_override, array_ufunc_methods);
if (num_override_args == -1) {
goto fail;
}
diff --git a/numpy/core/src/umath/override.h b/numpy/core/src/umath/override.h
index 4e9a323ca..20621bb19 100644
--- a/numpy/core/src/umath/override.h
+++ b/numpy/core/src/umath/override.h
@@ -6,7 +6,7 @@
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *in_args, PyObject *out_args,
+ PyObject *in_args, PyObject *out_args, PyObject *wheremask_obj,
PyObject *const *args, Py_ssize_t len_args, PyObject *kwnames,
PyObject **result);
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index a159003de..a5e8f4cbe 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4071,7 +4071,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc,
/* We now have all the information required to check for Overrides */
PyObject *override = NULL;
int errval = PyUFunc_CheckOverride(ufunc, _reduce_type[operation],
- full_args.in, full_args.out, args, len_args, kwnames, &override);
+ full_args.in, full_args.out, wheremask_obj, args, len_args, kwnames, &override);
if (errval) {
return NULL;
}
@@ -4843,7 +4843,7 @@ ufunc_generic_fastcall(PyUFuncObject *ufunc,
/* We now have all the information required to check for Overrides */
PyObject *override = NULL;
errval = PyUFunc_CheckOverride(ufunc, method,
- full_args.in, full_args.out,
+ full_args.in, full_args.out, where_obj,
args, len_args, kwnames, &override);
if (errval) {
goto fail;
@@ -6261,7 +6261,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
return NULL;
}
errval = PyUFunc_CheckOverride(ufunc, "at",
- args, NULL, NULL, 0, NULL, &override);
+ args, NULL, NULL, NULL, 0, NULL, &override);
if (errval) {
return NULL;
diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py
index 90d917701..42e3d70bd 100644
--- a/numpy/core/tests/test_overrides.py
+++ b/numpy/core/tests/test_overrides.py
@@ -248,6 +248,19 @@ class TestArrayFunctionDispatch:
with assert_raises_regex(TypeError, 'no implementation found'):
dispatched_one_arg(array)
+ def test_where_dispatch(self):
+
+ class DuckArray:
+ def __array_function__(self, ufunc, method, *inputs, **kwargs):
+ return "overridden"
+
+ array = np.array(1)
+ duck_array = DuckArray()
+
+ result = np.std(array, where=duck_array)
+
+ assert_equal(result, "overridden")
+
@requires_array_function
class TestVerifyMatchingSignatures:
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index e504ddd6e..2d5604a4f 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -3493,6 +3493,79 @@ class TestSpecialMethods:
assert_raises(ValueError, np.modf, a, out=('one', 'two', 'three'))
assert_raises(ValueError, np.modf, a, out=('one',))
+ def test_ufunc_override_where(self):
+
+ class OverriddenArrayOld(np.ndarray):
+
+ def _unwrap(self, objs):
+ cls = type(self)
+ result = []
+ for obj in objs:
+ if isinstance(obj, cls):
+ obj = np.array(obj)
+ elif type(obj) != np.ndarray:
+ return NotImplemented
+ result.append(obj)
+ return result
+
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+
+ inputs = self._unwrap(inputs)
+ if inputs is NotImplemented:
+ return NotImplemented
+
+ kwargs = kwargs.copy()
+ if "out" in kwargs:
+ kwargs["out"] = self._unwrap(kwargs["out"])
+ if kwargs["out"] is NotImplemented:
+ return NotImplemented
+
+ r = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
+ if r is not NotImplemented:
+ r = r.view(type(self))
+
+ return r
+
+ class OverriddenArrayNew(OverriddenArrayOld):
+ def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
+
+ kwargs = kwargs.copy()
+ if "where" in kwargs:
+ kwargs["where"] = self._unwrap((kwargs["where"], ))
+ if kwargs["where"] is NotImplemented:
+ return NotImplemented
+ else:
+ kwargs["where"] = kwargs["where"][0]
+
+ r = super().__array_ufunc__(ufunc, method, *inputs, **kwargs)
+ if r is not NotImplemented:
+ r = r.view(type(self))
+
+ return r
+
+ ufunc = np.negative
+
+ array = np.array([1, 2, 3])
+ where = np.array([True, False, True])
+ expected = ufunc(array, where=where)
+
+ with pytest.raises(TypeError):
+ ufunc(array, where=where.view(OverriddenArrayOld))
+
+ result_1 = ufunc(
+ array,
+ where=where.view(OverriddenArrayNew)
+ )
+ assert isinstance(result_1, OverriddenArrayNew)
+ assert np.all(np.array(result_1) == expected, where=where)
+
+ result_2 = ufunc(
+ array.view(OverriddenArrayNew),
+ where=where.view(OverriddenArrayNew)
+ )
+ assert isinstance(result_2, OverriddenArrayNew)
+ assert np.all(np.array(result_2) == expected, where=where)
+
def test_ufunc_override_exception(self):
class A: