diff options
| author | Sebastian Berg <sebastianb@nvidia.com> | 2023-03-22 12:20:00 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-03-22 12:20:00 +0100 |
| commit | b35aac2c35ccfd5efadd7f72a090c9ad99308a60 (patch) | |
| tree | b1bd47eb0b5f68488379d8ea631aaa9c360a7552 /numpy/core/src/multiarray | |
| parent | 294c7f2c893b7e5ef783fc1cb1912d06404b452b (diff) | |
| parent | f3f108d313a8b8a4f7a90fb932867f17dc48b1f6 (diff) | |
| download | numpy-b35aac2c35ccfd5efadd7f72a090c9ad99308a60.tar.gz | |
Merge pull request #23240 from byrdie/bugfix/ufunc_where_propagation
ENH: Allow ``where`` argument to override ``__array_ufunc__``
Diffstat (limited to 'numpy/core/src/multiarray')
| -rw-r--r-- | numpy/core/src/multiarray/methods.c | 14 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.c | 5 | ||||
| -rw-r--r-- | numpy/core/src/multiarray/multiarraymodule.h | 1 |
3 files changed, 19 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index f518f3a02..93b290020 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -28,6 +28,7 @@ #include "strfuncs.h" #include "array_assign.h" #include "npy_dlpack.h" +#include "multiarraymodule.h" #include "methods.h" #include "alloc.h" @@ -1102,7 +1103,7 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds) int nin, nout; PyObject *out_kwd_obj; PyObject *fast; - PyObject **in_objs, **out_objs; + PyObject **in_objs, **out_objs, *where_obj; /* check inputs */ nin = PyTuple_Size(args); @@ -1133,6 +1134,17 @@ any_array_ufunc_overrides(PyObject *args, PyObject *kwds) } } Py_DECREF(out_kwd_obj); + /* check where if it exists */ + where_obj = PyDict_GetItemWithError(kwds, npy_ma_str_where); + if (where_obj == NULL) { + if (PyErr_Occurred()) { + return -1; + } + } else { + if (PyUFunc_HasOverride(where_obj)){ + return 1; + } + } return 0; } diff --git a/numpy/core/src/multiarray/multiarraymodule.c b/numpy/core/src/multiarray/multiarraymodule.c index e85f8affa..ac8e641b7 100644 --- a/numpy/core/src/multiarray/multiarraymodule.c +++ b/numpy/core/src/multiarray/multiarraymodule.c @@ -4843,6 +4843,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis1 = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_axis2 = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_like = NULL; NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_numpy = NULL; +NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_where = NULL; static int intern_strings(void) @@ -4899,6 +4900,10 @@ intern_strings(void) if (npy_ma_str_numpy == NULL) { return -1; } + npy_ma_str_where = PyUnicode_InternFromString("where"); + if (npy_ma_str_where == NULL) { + return -1; + } return 0; } diff --git a/numpy/core/src/multiarray/multiarraymodule.h b/numpy/core/src/multiarray/multiarraymodule.h index 992acd09f..9ba2a1831 100644 --- a/numpy/core/src/multiarray/multiarraymodule.h +++ b/numpy/core/src/multiarray/multiarraymodule.h @@ -16,5 +16,6 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis1; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_axis2; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_like; NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_numpy; +NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_where; #endif /* NUMPY_CORE_SRC_MULTIARRAY_MULTIARRAYMODULE_H_ */ |
