summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPauli Virtanen <pav@iki.fi>2013-09-22 17:25:35 +0300
committerPauli Virtanen <pav@iki.fi>2013-09-22 17:52:24 +0300
commitf42549fc6c80eb1ac7783393993bd9feec055f16 (patch)
treee790187d8b3c4008106219871243d813bdefc786
parentacf72ea8286172b141624b27f00571f52aa289d0 (diff)
downloadnumpy-f42549fc6c80eb1ac7783393993bd9feec055f16.tar.gz
REF: core: clean up PyUFunc_CheckOverride
Move argument normalization out from the inner loop, reduce complexity. Add error checks.
-rw-r--r--numpy/core/src/private/ufunc_override.h180
1 files changed, 91 insertions, 89 deletions
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index d1e8b5688..af7f6e46e 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -11,6 +11,9 @@
* routine returning something other than `NotImplemented` determines the
* result. If all of the `__numpy_ufunc__` operations returns `NotImplemented`,
* a `TypeError` is raised.
+ *
+ * Returns 0 on success and 1 on exception. On success, *result contains the
+ * result of the operation, if any. If *result is NULL, there is no override.
*/
static int
PyUFunc_CheckOverride(PyObject *ufunc, char *method,
@@ -21,31 +24,36 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method,
int i;
int override_pos; /* Position of override in args.*/
int j;
- int pos_in_with_override; /* Position of override in with_override.*/
int nargs = PyTuple_GET_SIZE(args);
int noa = 0; /* Number of overriding args.*/
- int normalized = 0; /* Is normalized flag.*/
PyObject *obj;
PyObject *other_obj;
- PyObject *override_args;
- PyObject *method_name = PyUString_FromString(method);
+ PyObject *method_name = NULL;
PyObject *normal_args = NULL; /* normal_* holds normalized arguments. */
PyObject *normal_kwds = NULL;
- PyObject *override_obj = NULL; /* overriding object */
- PyObject *numpy_ufunc = NULL; /* the __numpy_ufunc__ method */
PyObject *with_override[NPY_MAXARGS];
+
/* Pos of each override in args */
int with_override_pos[NPY_MAXARGS];
- /* Checks */
+ /*
+ * Check inputs
+ */
if (!PyTuple_Check(args)) {
+ PyErr_SetString(PyExc_ValueError,
+ "Internal Numpy error: call to PyUFunc_CheckOverride "
+ "with non-tuple");
goto fail;
}
+
if (PyTuple_GET_SIZE(args) > NPY_MAXARGS) {
+ PyErr_SetString(PyExc_ValueError,
+ "Internal Numpy error: too many arguments in call "
+ "to PyUFunc_CheckOverride");
goto fail;
}
@@ -63,12 +71,48 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method,
/* No overrides, bail out.*/
if (noa == 0) {
- Py_DECREF(method_name);
+ *result = NULL;
return 0;
}
+ /*
+ * Normalize ufunc arguments.
+ */
+ normal_args = PyTuple_GetSlice(args, 0, nin);
+ if (normal_args == NULL) {
+ goto fail;
+ }
+
+ /* Build new kwds */
+ if (kwds && PyDict_CheckExact(kwds)) {
+ normal_kwds = PyDict_Copy(kwds);
+ }
+ else {
+ normal_kwds = PyDict_New();
+ }
+ if (normal_kwds == NULL) {
+ goto fail;
+ }
+
+ /* If we have more args than nin, the last one must be `out`.*/
+ if (nargs > nin) {
+ obj = PyTuple_GET_ITEM(args, nargs - 1);
+ PyDict_SetItemString(normal_kwds, "out", obj);
+ }
+
+ method_name = PyUString_FromString(method);
+ if (method_name == NULL) {
+ goto fail;
+ }
+
+ /*
+ * Call __numpy_ufunc__ functions in correct order
+ */
while (1) {
- obj = NULL;
+ PyObject *numpy_ufunc;
+ PyObject *override_args;
+ PyObject *override_obj;
+
override_obj = NULL;
*result = NULL;
@@ -78,10 +122,10 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method,
if (obj == NULL) {
continue;
}
+
/* Get the first instance of an overriding arg.*/
override_pos = with_override_pos[i];
override_obj = obj;
- pos_in_with_override = i;
/* Check for sub-types to the right of obj. */
for (j = i + 1; j < noa; j++) {
@@ -93,109 +137,67 @@ PyUFunc_CheckOverride(PyObject *ufunc, char *method,
break;
}
}
+
/* override_obj had no subtypes to the right. */
if (override_obj) {
+ with_override[i] = NULL; /* We won't call this one again */
break;
}
}
- /* No good override_obj */
+
+ /* Check if there is a method left to call */
if (!override_obj) {
- break;
+ /* No acceptable override found. */
+ PyErr_SetString(PyExc_TypeError,
+ "__numpy_ufunc__ not implemented for this type.");
+ goto fail;
}
- /*
- * Normalize the ufuncs arguments. Returns a tuple of
- * (args, kwds).
- *
- * Test with and without kwds.
- */
- if (!normalized) {
- PyObject *out_arg;
-
- /* If we have more args than nin, the last one must be `out`.*/
- if (nargs > nin) {
- out_arg = PyTuple_GET_ITEM(args, nargs - 1);
-
- /* Build new args.*/
- normal_args = PyTuple_GetSlice(args, 0, nin);
-
- /* Build new kwds with out arg.*/
- if (kwds && PyDict_CheckExact(kwds)) {
- normal_kwds = PyDict_Copy(kwds);
- PyDict_SetItemString(normal_kwds, "out", out_arg);
- }
- else {
- normal_kwds = PyDict_New();
- PyDict_SetItemString(normal_kwds, "out", out_arg);
- }
-
- normalized = 1;
- }
- else {
- /* Copy args */
- normal_args = PyTuple_GetSlice(args, 0, nin);
- if (kwds && PyDict_CheckExact(kwds)) {
- normal_kwds = PyDict_Copy(kwds);
- }
- else {
- normal_kwds = PyDict_New();
- }
- normalized = 1;
- }
+ /* Call the override */
+ numpy_ufunc = PyObject_GetAttrString(override_obj,
+ "__numpy_ufunc__");
+ if (numpy_ufunc == NULL) {
+ goto fail;
}
- /* Calculate a result if we have a override. */
- if (override_obj) {
- numpy_ufunc = PyObject_GetAttrString(override_obj,
- "__numpy_ufunc__");
- override_args = Py_BuildValue("OOiO", ufunc, method_name,
- override_pos, normal_args);
- *result = PyObject_Call(numpy_ufunc, override_args, normal_kwds);
-
+ override_args = Py_BuildValue("OOiO", ufunc, method_name,
+ override_pos, normal_args);
+ if (override_args == NULL) {
Py_DECREF(numpy_ufunc);
- Py_DECREF(override_args);
-
- if (*result == NULL) {
- /* Exception occurred */
- Py_XDECREF(normal_args);
- Py_XDECREF(normal_kwds);
- goto fail;
- }
- if (*result == Py_NotImplemented) {
- /* Remove this arg if it gives not implemented */
- with_override[pos_in_with_override] = NULL;
- Py_DECREF(*result);
- continue;
- }
- else {
- /* Good result. */
- break;
- }
+ goto fail;
}
- /* All overrides checked. */
+ *result = PyObject_Call(numpy_ufunc, override_args, normal_kwds);
+
+ Py_DECREF(numpy_ufunc);
+ Py_DECREF(override_args);
+
+ if (*result == NULL) {
+ /* Exception occurred */
+ goto fail;
+ }
+ else if (*result == Py_NotImplemented) {
+ /* Try the next one */
+ Py_DECREF(*result);
+ continue;
+ }
else {
+ /* Good result. */
break;
}
}
- /* No acceptable override found. */
- if (!*result) {
- PyErr_SetString(PyExc_TypeError,
- "__numpy_ufunc__ not implemented for this type.");
- Py_XDECREF(normal_args);
- Py_XDECREF(normal_kwds);
- goto fail;
- }
+
/* Override found, return it. */
- Py_DECREF(method_name);
+ Py_XDECREF(method_name);
Py_XDECREF(normal_args);
Py_XDECREF(normal_kwds);
return 0;
fail:
- Py_DECREF(method_name);
+ Py_XDECREF(method_name);
+ Py_XDECREF(normal_args);
+ Py_XDECREF(normal_kwds);
return 1;
-
}
#endif