diff options
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 35 | ||||
-rw-r--r-- | numpy/core/src/private/ufunc_override.c | 87 | ||||
-rw-r--r-- | numpy/core/src/private/ufunc_override.h | 4 |
3 files changed, 82 insertions, 44 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index 36f48ce8f..6cfd05cd6 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -1010,37 +1010,48 @@ array_getarray(PyArrayObject *self, PyObject *args) static PyObject * array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds) { - PyObject *ufunc, *method_name, *normal_args, *ufunc_method, *result; + PyObject *ufunc, *method_name, *normal_args, *ufunc_method; + PyObject *result = NULL; if (PyTuple_Size(args) < 2) { PyErr_SetString(PyExc_TypeError, "__array_ufunc__ requires at least 2 arguments"); return NULL; } + normal_args = PyTuple_GetSlice(args, 2, PyTuple_GET_SIZE(args)); + if (normal_args == NULL) { + return NULL; + } + /* ndarray cannot handle overrides itself */ + if (PyUFunc_HasOverride(normal_args, kwds, NULL)) { + result = Py_NotImplemented; + Py_INCREF(Py_NotImplemented); + goto cleanup; + } + ufunc = PyTuple_GET_ITEM(args, 0); if (ufunc == NULL) { - return NULL; + goto cleanup; } method_name = PyTuple_GET_ITEM(args, 1); if (method_name == NULL) { - return NULL; - } - - normal_args = PyTuple_GetSlice(args, 2, PyTuple_GET_SIZE(args)); - if (normal_args == NULL) { - return NULL; + goto cleanup; } + /* + * TODO(?): call into UFunc code at a later point, since here arguments are + * already normalized and we do not have to look for __array_ufunc__ again. + */ ufunc_method = PyObject_GetAttr(ufunc, method_name); if (ufunc_method == NULL) { - Py_DECREF(normal_args); - return NULL; + goto cleanup; } - result = PyObject_Call(ufunc_method, normal_args, kwds); - Py_DECREF(normal_args); Py_DECREF(ufunc_method); + +cleanup: + Py_DECREF(normal_args); /* no need to DECREF borrowed references ufunc and method_name */ return result; } diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c index a7d71ee94..1db4e54b9 100644 --- a/numpy/core/src/private/ufunc_override.c +++ b/numpy/core/src/private/ufunc_override.c @@ -214,24 +214,16 @@ has_non_default_array_ufunc(PyObject *obj) } /* - * Check a set of args for the `__array_ufunc__` method. If more than one of - * the input arguments implements `__array_ufunc__`, they are tried in the - * order: subclasses before superclasses, otherwise left to right. The first - * (non-None) routine returning something other than `NotImplemented` - * determines the result. If all of the `__array_ufunc__` operations return - * `NotImplemented` (or are None), a `TypeError` is raised. - * - * Returns 0 on success and 1 on exception. On success, *result contains the - * result of the operation, if any. If *result is NULL, there is no override. + * Check whether a set of input and output args have a non-default + * `__array_ufunc__` method. Returns the number of overrides, setting + * corresponding objects in PyObject array with_override (if not NULL). + * returns -1 on failure. */ NPY_NO_EXPORT int -PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, - PyObject *args, PyObject *kwds, - PyObject **result, - int nin) +PyUFunc_HasOverride(PyObject *args, PyObject *kwds, + PyObject **with_override) { int i; - int j; int nargs; int nout_kwd = 0; @@ -240,16 +232,6 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, PyObject *obj; PyObject *out_kwd_obj = NULL; - PyObject *other_obj; - - PyObject *method_name = NULL; - PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */ - PyObject *normal_kwds = NULL; - - PyObject *override_args = NULL; - PyObject *with_override[NPY_MAXARGS]; - Py_ssize_t len; - /* * Check inputs */ @@ -266,7 +248,6 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, "to PyUFunc_CheckOverride"); goto fail; } - /* be sure to include possible 'out' keyword argument. */ if (kwds && PyDict_CheckExact(kwds)) { out_kwd_obj = PyDict_GetItemString(kwds, "out"); @@ -287,7 +268,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, } else { if (out_kwd_is_tuple) { - obj = PyTuple_GET_ITEM(out_kwd_obj, i-nargs); + obj = PyTuple_GET_ITEM(out_kwd_obj, i - nargs); } else { obj = out_kwd_obj; @@ -299,22 +280,60 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, * any ndarray subclass instances that did not override __array_ufunc__. */ if (has_non_default_array_ufunc(obj)) { - with_override[noa] = obj; + if (with_override != NULL) { + with_override[noa] = obj; + } ++noa; } } + return noa; +fail: + return -1; +} +/* + * Check a set of args for the `__array_ufunc__` method. If more than one of + * the input arguments implements `__array_ufunc__`, they are tried in the + * order: subclasses before superclasses, otherwise left to right. The first + * (non-None) routine returning something other than `NotImplemented` + * determines the result. If all of the `__array_ufunc__` operations return + * `NotImplemented` (or are None), a `TypeError` is raised. + * + * Returns 0 on success and 1 on exception. On success, *result contains the + * result of the operation, if any. If *result is NULL, there is no override. + */ +NPY_NO_EXPORT int +PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, + PyObject *args, PyObject *kwds, + PyObject **result, + int nin) +{ + int i; + int j; + + int noa; + PyObject *with_override[NPY_MAXARGS]; + + PyObject *obj; + PyObject *other_obj; + + PyObject *method_name = NULL; + PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */ + PyObject *normal_kwds = NULL; + + PyObject *override_args = NULL; + Py_ssize_t len; + + /* + * Check inputs for overrides + */ + noa = PyUFunc_HasOverride(args, kwds, with_override); /* No overrides, bail out.*/ if (noa == 0) { *result = NULL; return 0; } - method_name = PyUString_FromString(method); - if (method_name == NULL) { - goto fail; - } - /* * Normalize ufunc arguments. */ @@ -409,6 +428,10 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, Py_INCREF(ufunc); /* PyTuple_SET_ITEM steals reference */ PyTuple_SET_ITEM(override_args, 0, (PyObject *)ufunc); + method_name = PyUString_FromString(method); + if (method_name == NULL) { + goto fail; + } Py_INCREF(method_name); PyTuple_SET_ITEM(override_args, 1, method_name); for (i = 0; i < len; i++) { diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h index ee18e569c..db5e84fd5 100644 --- a/numpy/core/src/private/ufunc_override.h +++ b/numpy/core/src/private/ufunc_override.h @@ -5,6 +5,10 @@ #include "numpy/ufuncobject.h" NPY_NO_EXPORT int +PyUFunc_HasOverride(PyObject *args, PyObject *kwds, + PyObject **with_override); + +NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, PyObject *args, PyObject *kwds, PyObject **result, |