summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/methods.c35
-rw-r--r--numpy/core/src/private/ufunc_override.c87
-rw-r--r--numpy/core/src/private/ufunc_override.h4
3 files changed, 82 insertions, 44 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index 36f48ce8f..6cfd05cd6 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -1010,37 +1010,48 @@ array_getarray(PyArrayObject *self, PyObject *args)
static PyObject *
array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
- PyObject *ufunc, *method_name, *normal_args, *ufunc_method, *result;
+ PyObject *ufunc, *method_name, *normal_args, *ufunc_method;
+ PyObject *result = NULL;
if (PyTuple_Size(args) < 2) {
PyErr_SetString(PyExc_TypeError,
"__array_ufunc__ requires at least 2 arguments");
return NULL;
}
+ normal_args = PyTuple_GetSlice(args, 2, PyTuple_GET_SIZE(args));
+ if (normal_args == NULL) {
+ return NULL;
+ }
+ /* ndarray cannot handle overrides itself */
+ if (PyUFunc_HasOverride(normal_args, kwds, NULL)) {
+ result = Py_NotImplemented;
+ Py_INCREF(Py_NotImplemented);
+ goto cleanup;
+ }
+
ufunc = PyTuple_GET_ITEM(args, 0);
if (ufunc == NULL) {
- return NULL;
+ goto cleanup;
}
method_name = PyTuple_GET_ITEM(args, 1);
if (method_name == NULL) {
- return NULL;
- }
-
- normal_args = PyTuple_GetSlice(args, 2, PyTuple_GET_SIZE(args));
- if (normal_args == NULL) {
- return NULL;
+ goto cleanup;
}
+ /*
+ * TODO(?): call into UFunc code at a later point, since here arguments are
+ * already normalized and we do not have to look for __array_ufunc__ again.
+ */
ufunc_method = PyObject_GetAttr(ufunc, method_name);
if (ufunc_method == NULL) {
- Py_DECREF(normal_args);
- return NULL;
+ goto cleanup;
}
-
result = PyObject_Call(ufunc_method, normal_args, kwds);
- Py_DECREF(normal_args);
Py_DECREF(ufunc_method);
+
+cleanup:
+ Py_DECREF(normal_args);
/* no need to DECREF borrowed references ufunc and method_name */
return result;
}
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c
index a7d71ee94..1db4e54b9 100644
--- a/numpy/core/src/private/ufunc_override.c
+++ b/numpy/core/src/private/ufunc_override.c
@@ -214,24 +214,16 @@ has_non_default_array_ufunc(PyObject *obj)
}
/*
- * Check a set of args for the `__array_ufunc__` method. If more than one of
- * the input arguments implements `__array_ufunc__`, they are tried in the
- * order: subclasses before superclasses, otherwise left to right. The first
- * (non-None) routine returning something other than `NotImplemented`
- * determines the result. If all of the `__array_ufunc__` operations return
- * `NotImplemented` (or are None), a `TypeError` is raised.
- *
- * Returns 0 on success and 1 on exception. On success, *result contains the
- * result of the operation, if any. If *result is NULL, there is no override.
+ * Check whether a set of input and output args have a non-default
+ * `__array_ufunc__` method. Returns the number of overrides, setting
+ * corresponding objects in PyObject array with_override (if not NULL).
+ * returns -1 on failure.
*/
NPY_NO_EXPORT int
-PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
- PyObject *args, PyObject *kwds,
- PyObject **result,
- int nin)
+PyUFunc_HasOverride(PyObject *args, PyObject *kwds,
+ PyObject **with_override)
{
int i;
- int j;
int nargs;
int nout_kwd = 0;
@@ -240,16 +232,6 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *obj;
PyObject *out_kwd_obj = NULL;
- PyObject *other_obj;
-
- PyObject *method_name = NULL;
- PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */
- PyObject *normal_kwds = NULL;
-
- PyObject *override_args = NULL;
- PyObject *with_override[NPY_MAXARGS];
- Py_ssize_t len;
-
/*
* Check inputs
*/
@@ -266,7 +248,6 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
"to PyUFunc_CheckOverride");
goto fail;
}
-
/* be sure to include possible 'out' keyword argument. */
if (kwds && PyDict_CheckExact(kwds)) {
out_kwd_obj = PyDict_GetItemString(kwds, "out");
@@ -287,7 +268,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
}
else {
if (out_kwd_is_tuple) {
- obj = PyTuple_GET_ITEM(out_kwd_obj, i-nargs);
+ obj = PyTuple_GET_ITEM(out_kwd_obj, i - nargs);
}
else {
obj = out_kwd_obj;
@@ -299,22 +280,60 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
* any ndarray subclass instances that did not override __array_ufunc__.
*/
if (has_non_default_array_ufunc(obj)) {
- with_override[noa] = obj;
+ if (with_override != NULL) {
+ with_override[noa] = obj;
+ }
++noa;
}
}
+ return noa;
+fail:
+ return -1;
+}
+/*
+ * Check a set of args for the `__array_ufunc__` method. If more than one of
+ * the input arguments implements `__array_ufunc__`, they are tried in the
+ * order: subclasses before superclasses, otherwise left to right. The first
+ * (non-None) routine returning something other than `NotImplemented`
+ * determines the result. If all of the `__array_ufunc__` operations return
+ * `NotImplemented` (or are None), a `TypeError` is raised.
+ *
+ * Returns 0 on success and 1 on exception. On success, *result contains the
+ * result of the operation, if any. If *result is NULL, there is no override.
+ */
+NPY_NO_EXPORT int
+PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
+ PyObject *args, PyObject *kwds,
+ PyObject **result,
+ int nin)
+{
+ int i;
+ int j;
+
+ int noa;
+ PyObject *with_override[NPY_MAXARGS];
+
+ PyObject *obj;
+ PyObject *other_obj;
+
+ PyObject *method_name = NULL;
+ PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */
+ PyObject *normal_kwds = NULL;
+
+ PyObject *override_args = NULL;
+ Py_ssize_t len;
+
+ /*
+ * Check inputs for overrides
+ */
+ noa = PyUFunc_HasOverride(args, kwds, with_override);
/* No overrides, bail out.*/
if (noa == 0) {
*result = NULL;
return 0;
}
- method_name = PyUString_FromString(method);
- if (method_name == NULL) {
- goto fail;
- }
-
/*
* Normalize ufunc arguments.
*/
@@ -409,6 +428,10 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
Py_INCREF(ufunc);
/* PyTuple_SET_ITEM steals reference */
PyTuple_SET_ITEM(override_args, 0, (PyObject *)ufunc);
+ method_name = PyUString_FromString(method);
+ if (method_name == NULL) {
+ goto fail;
+ }
Py_INCREF(method_name);
PyTuple_SET_ITEM(override_args, 1, method_name);
for (i = 0; i < len; i++) {
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index ee18e569c..db5e84fd5 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -5,6 +5,10 @@
#include "numpy/ufuncobject.h"
NPY_NO_EXPORT int
+PyUFunc_HasOverride(PyObject *args, PyObject *kwds,
+ PyObject **with_override);
+
+NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
PyObject **result,