diff options
-rw-r--r-- | numpy/core/src/private/ufunc_override.h | 180 |
1 files changed, 91 insertions, 89 deletions
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h index d1e8b5688..af7f6e46e 100644 --- a/numpy/core/src/private/ufunc_override.h +++ b/numpy/core/src/private/ufunc_override.h @@ -11,6 +11,9 @@ * routine returning something other than `NotImplemented` determines the * result. If all of the `__numpy_ufunc__` operations returns `NotImplemented`, * a `TypeError` is raised. + * + * Returns 0 on success and 1 on exception. On success, *result contains the + * result of the operation, if any. If *result is NULL, there is no override. */ static int PyUFunc_CheckOverride(PyObject *ufunc, char *method, @@ -21,31 +24,36 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method, int i; int override_pos; /* Position of override in args.*/ int j; - int pos_in_with_override; /* Position of override in with_override.*/ int nargs = PyTuple_GET_SIZE(args); int noa = 0; /* Number of overriding args.*/ - int normalized = 0; /* Is normalized flag.*/ PyObject *obj; PyObject *other_obj; - PyObject *override_args; - PyObject *method_name = PyUString_FromString(method); + PyObject *method_name = NULL; PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */ PyObject *normal_kwds = NULL; - PyObject *override_obj = NULL; /* overriding object */ - PyObject *numpy_ufunc = NULL; /* the __numpy_ufunc__ method */ PyObject *with_override[NPY_MAXARGS]; + /* Pos of each override in args */ int with_override_pos[NPY_MAXARGS]; - /* Checks */ + /* + * Check inputs + */ if (!PyTuple_Check(args)) { + PyErr_SetString(PyExc_ValueError, + "Internal Numpy error: call to PyUFunc_CheckOverride " + "with non-tuple"); goto fail; } + if (PyTuple_GET_SIZE(args) > NPY_MAXARGS) { + PyErr_SetString(PyExc_ValueError, + "Internal Numpy error: too many arguments in call " + "to PyUFunc_CheckOverride"); goto fail; } @@ -63,12 +71,48 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method, /* No overrides, bail out.*/ if (noa == 0) { - Py_DECREF(method_name); + *result = NULL; return 0; } + /* + * Normalize ufunc arguments. + */ + normal_args = PyTuple_GetSlice(args, 0, nin); + if (normal_args == NULL) { + goto fail; + } + + /* Build new kwds */ + if (kwds && PyDict_CheckExact(kwds)) { + normal_kwds = PyDict_Copy(kwds); + } + else { + normal_kwds = PyDict_New(); + } + if (normal_kwds == NULL) { + goto fail; + } + + /* If we have more args than nin, the last one must be `out`.*/ + if (nargs > nin) { + obj = PyTuple_GET_ITEM(args, nargs - 1); + PyDict_SetItemString(normal_kwds, "out", obj); + } + + method_name = PyUString_FromString(method); + if (method_name == NULL) { + goto fail; + } + + /* + * Call __numpy_ufunc__ functions in correct order + */ while (1) { - obj = NULL; + PyObject *numpy_ufunc; + PyObject *override_args; + PyObject *override_obj; + override_obj = NULL; *result = NULL; @@ -78,10 +122,10 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method, if (obj == NULL) { continue; } + /* Get the first instance of an overriding arg.*/ override_pos = with_override_pos[i]; override_obj = obj; - pos_in_with_override = i; /* Check for sub-types to the right of obj. */ for (j = i + 1; j < noa; j++) { @@ -93,109 +137,67 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method, break; } } + /* override_obj had no subtypes to the right. */ if (override_obj) { + with_override[i] = NULL; /* We won't call this one again */ break; } } - /* No good override_obj */ + + /* Check if there is a method left to call */ if (!override_obj) { - break; + /* No acceptable override found. */ + PyErr_SetString(PyExc_TypeError, + "__numpy_ufunc__ not implemented for this type."); + goto fail; } - /* - * Normalize the ufuncs arguments. Returns a tuple of - * (args, kwds). - * - * Test with and without kwds. - */ - if (!normalized) { - PyObject *out_arg; - - /* If we have more args than nin, the last one must be `out`.*/ - if (nargs > nin) { - out_arg = PyTuple_GET_ITEM(args, nargs - 1); - - /* Build new args.*/ - normal_args = PyTuple_GetSlice(args, 0, nin); - - /* Build new kwds with out arg.*/ - if (kwds && PyDict_CheckExact(kwds)) { - normal_kwds = PyDict_Copy(kwds); - PyDict_SetItemString(normal_kwds, "out", out_arg); - } - else { - normal_kwds = PyDict_New(); - PyDict_SetItemString(normal_kwds, "out", out_arg); - } - - normalized = 1; - } - else { - /* Copy args */ - normal_args = PyTuple_GetSlice(args, 0, nin); - if (kwds && PyDict_CheckExact(kwds)) { - normal_kwds = PyDict_Copy(kwds); - } - else { - normal_kwds = PyDict_New(); - } - normalized = 1; - } + /* Call the override */ + numpy_ufunc = PyObject_GetAttrString(override_obj, + "__numpy_ufunc__"); + if (numpy_ufunc == NULL) { + goto fail; } - /* Calculate a result if we have a override. */ - if (override_obj) { - numpy_ufunc = PyObject_GetAttrString(override_obj, - "__numpy_ufunc__"); - override_args = Py_BuildValue("OOiO", ufunc, method_name, - override_pos, normal_args); - *result = PyObject_Call(numpy_ufunc, override_args, normal_kwds); - + override_args = Py_BuildValue("OOiO", ufunc, method_name, + override_pos, normal_args); + if (override_args == NULL) { Py_DECREF(numpy_ufunc); - Py_DECREF(override_args); - - if (*result == NULL) { - /* Exception occurred */ - Py_XDECREF(normal_args); - Py_XDECREF(normal_kwds); - goto fail; - } - if (*result == Py_NotImplemented) { - /* Remove this arg if it gives not implemented */ - with_override[pos_in_with_override] = NULL; - Py_DECREF(*result); - continue; - } - else { - /* Good result. */ - break; - } + goto fail; } - /* All overrides checked. */ + *result = PyObject_Call(numpy_ufunc, override_args, normal_kwds); + + Py_DECREF(numpy_ufunc); + Py_DECREF(override_args); + + if (*result == NULL) { + /* Exception occurred */ + goto fail; + } + else if (*result == Py_NotImplemented) { + /* Try the next one */ + Py_DECREF(*result); + continue; + } else { + /* Good result. */ break; } } - /* No acceptable override found. */ - if (!*result) { - PyErr_SetString(PyExc_TypeError, - "__numpy_ufunc__ not implemented for this type."); - Py_XDECREF(normal_args); - Py_XDECREF(normal_kwds); - goto fail; - } + /* Override found, return it. */ - Py_DECREF(method_name); + Py_XDECREF(method_name); Py_XDECREF(normal_args); Py_XDECREF(normal_kwds); return 0; fail: - Py_DECREF(method_name); + Py_XDECREF(method_name); + Py_XDECREF(normal_args); + Py_XDECREF(normal_kwds); return 1; - } #endif |