summaryrefslogtreecommitdiff
path: root/numpy/core/src/multiarray
diff options
context:
space:
mode:
authorSebastian Berg <sebastianb@nvidia.com>2023-03-22 12:20:00 +0100
committerGitHub <noreply@github.com>2023-03-22 12:20:00 +0100
commitb35aac2c35ccfd5efadd7f72a090c9ad99308a60 (patch)
treeb1bd47eb0b5f68488379d8ea631aaa9c360a7552 /numpy/core/src/multiarray
parent294c7f2c893b7e5ef783fc1cb1912d06404b452b (diff)
parentf3f108d313a8b8a4f7a90fb932867f17dc48b1f6 (diff)
downloadnumpy-b35aac2c35ccfd5efadd7f72a090c9ad99308a60.tar.gz
Merge pull request #23240 from byrdie/bugfix/ufunc_where_propagation
ENH: Allow ``where`` argument to override ``__array_ufunc__``
Diffstat (limited to 'numpy/core/src/multiarray')
-rw-r--r--numpy/core/src/multiarray/methods.c14
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.c5
-rw-r--r--numpy/core/src/multiarray/multiarraymodule.h1
3 files changed, 19 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index f518f3a02..93b290020 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -28,6 +28,7 @@
#include "strfuncs.h"
#include "array_assign.h"
#include "npy_dlpack.h"
+#include "multiarraymodule.h"
#include "methods.h"
#include "alloc.h"
@@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
int nin, nout;
PyObject *out_kwd_obj;
PyObject *fast;
- PyObject **in_objs, **out_objs;
+ PyObject **in_objs, **out_objs, *where_obj;
/* check inputs */
nin = PyTuple_Size(args);
@@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds)
}
}
Py_DECREF(out_kwd_obj);
+ /* check where if it exists */
+ where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where);
+ if (where_obj == NULL) {
+ if (PyErr_Occurred()) {
+ return -1;
+ }
+ } else {
+ if (PyUFunc_HasOverride(where_obj)){
+ return 1;
+ }
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c
index e85f8affa..ac8e641b7 100644
--- a/numpy/core/src/multiarray/multiarraymodule.c
+++ b/numpy/core/src/multiarray/multiarraymodule.c
@@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL;
+NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL;
static int
intern_strings(void)
@@ -4899,6 +4900,10 @@ intern_strings(void)
if (npy_ma_str_numpy == NULL) {
return -1;
}
+ npy_ma_str_where = PyUnicode_InternFromString("where");
+ if (npy_ma_str_where == NULL) {
+ return -1;
+ }
return 0;
}
diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h
index 992acd09f..9ba2a1831 100644
--- a/numpy/core/src/multiarray/multiarraymodule.h
+++ b/numpy/core/src/multiarray/multiarraymodule.h
@@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy;
+NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where;
#endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */