summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/methods.c15
-rw-r--r--numpy/core/src/private/ufunc_override.c202
-rw-r--r--numpy/core/src/private/ufunc_override.h3
3 files changed, 144 insertions, 76 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c
index d6f2577a3..2e836d1d0 100644
--- a/numpy/core/src/multiarray/methods.c
+++ b/numpy/core/src/multiarray/methods.c
@@ -976,9 +976,12 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
PyObject *ufunc, *method_name, *normal_args, *ufunc_method;
PyObject *result = NULL;
- int num_override_args;
+ int has_override;
- if (PyTuple_Size(args) < 2) {
+ assert(PyTuple_CheckExact(args));
+ assert(kwds == NULL || PyDict_CheckExact(kwds));
+
+ if (PyTuple_GET_SIZE(args) < 2) {
PyErr_SetString(PyExc_TypeError,
"__array_ufunc__ requires at least 2 arguments");
return NULL;
@@ -988,11 +991,11 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
return NULL;
}
/* ndarray cannot handle overrides itself */
- num_override_args = PyUFunc_WithOverride(normal_args, kwds, NULL, NULL);
- if (num_override_args == -1) {
- return NULL;
+ has_override = PyUFunc_HasOverride(normal_args, kwds);
+ if (has_override < 0) {
+ goto cleanup;
}
- if (num_override_args) {
+ else if (has_override) {
result = Py_NotImplemented;
Py_INCREF(Py_NotImplemented);
goto cleanup;
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c
index 522b1744d..69c3cc56c 100644
--- a/numpy/core/src/private/ufunc_override.c
+++ b/numpy/core/src/private/ufunc_override.c
@@ -54,11 +54,71 @@ get_non_default_array_ufunc(PyObject *obj)
}
/*
- * Check whether a set of input and output args have a non-default
- * `__array_ufunc__` method. Return the number of overrides, setting
- * corresponding objects in PyObject array with_override and the corresponding
- * __array_ufunc__ methods in methods (both only if not NULL, and both using
- * new references).
+ * Check whether an object has __array_ufunc__ defined on its class and it
+ * is not the default, i.e., the object is not an ndarray, and its
+ * __array_ufunc__ is not the same as that of ndarray.
+ *
+ * Returns 1 if this is the case, 0 if not.
+ */
+
+static int
+has_non_default_array_ufunc(PyObject * obj)
+{
+ PyObject *method = get_non_default_array_ufunc(obj);
+ if (method) {
+ Py_DECREF(method);
+ return 1;
+ }
+ else {
+ return 0;
+ }
+}
+
+/*
+ * Get possible out argument from kwds, and returns the number of outputs
+ * contained within it: if a tuple, the number of elements in it, 1 otherwise.
+ * The out argument itself is returned in out_kwd_obj, and the outputs
+ * in the out_obj array (all as borrowed references).
+ *
+ * Returns -1 if kwds is not a dict, 0 if no outputs found.
+ */
+static int
+get_out_objects(PyObject *kwds, PyObject **out_kwd_obj, PyObject ***out_objs)
+{
+ if (kwds == NULL) {
+ return 0;
+ }
+ if (!PyDict_CheckExact(kwds)) {
+ PyErr_SetString(PyExc_TypeError,
+ "Internal Numpy error: call to PyUFunc_WithOverride "
+ "with non-dict kwds");
+ return -1;
+ }
+ /* borrowed reference */
+ *out_kwd_obj = PyDict_GetItemString(kwds, "out");
+ if (*out_kwd_obj == NULL) {
+ return 0;
+ }
+ if (PyTuple_CheckExact(*out_kwd_obj)) {
+ *out_objs = PySequence_Fast_ITEMS(*out_kwd_obj);
+ return PySequence_Fast_GET_SIZE(*out_kwd_obj);
+ }
+ else {
+ *out_objs = out_kwd_obj;
+ return 1;
+ }
+}
+
+/*
+ * For each positional argument and each argument in a possible "out"
+ * keyword, look for overrides of the standard ufunc behaviour, i.e.,
+ * non-default __array_ufunc__ methods.
+ *
+ * Returns the number of overrides, setting corresponding objects
+ * in PyObject array ``with_override`` and the corresponding
+ * __array_ufunc__ methods in ``methods`` (both using new references).
+ *
+ * Only the first override for a given class is returned.
*
* returns -1 on failure.
*/
@@ -67,68 +127,40 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
PyObject **with_override, PyObject **methods)
{
int i;
-
- int nargs;
- int nout_kwd = 0;
- int out_kwd_is_tuple = 0;
int num_override_args = 0;
+ int narg, nout = 0;
+ PyObject *out_kwd_obj;
+ PyObject **arg_objs, **out_objs;
- PyObject *out_kwd_obj = NULL;
- /*
- * Check inputs
- */
- if (!PyTuple_Check(args)) {
- PyErr_SetString(PyExc_TypeError,
- "Internal Numpy error: call to PyUFunc_HasOverride "
- "with non-tuple");
- goto fail;
+ narg = PyTuple_Size(args);
+ if (narg < 0) {
+ return -1;
}
- nargs = PyTuple_GET_SIZE(args);
- if (nargs > NPY_MAXARGS) {
- PyErr_SetString(PyExc_TypeError,
- "Internal Numpy error: too many arguments in call "
- "to PyUFunc_HasOverride");
- goto fail;
- }
- /* be sure to include possible 'out' keyword argument. */
- if (kwds && PyDict_CheckExact(kwds)) {
- out_kwd_obj = PyDict_GetItemString(kwds, "out");
- if (out_kwd_obj != NULL) {
- out_kwd_is_tuple = PyTuple_CheckExact(out_kwd_obj);
- if (out_kwd_is_tuple) {
- nout_kwd = PyTuple_GET_SIZE(out_kwd_obj);
- }
- else {
- nout_kwd = 1;
- }
- }
+ arg_objs = PySequence_Fast_ITEMS(args);
+
+ nout = get_out_objects(kwds, &out_kwd_obj, &out_objs);
+ if (nout < 0) {
+ return -1;
}
- for (i = 0; i < nargs + nout_kwd; ++i) {
+ for (i = 0; i < narg + nout; ++i) {
PyObject *obj;
+ int j;
int new_class = 1;
- if (i < nargs) {
- obj = PyTuple_GET_ITEM(args, i);
+ if (i < narg) {
+ obj = arg_objs[i];
}
else {
- if (out_kwd_is_tuple) {
- obj = PyTuple_GET_ITEM(out_kwd_obj, i - nargs);
- }
- else {
- obj = out_kwd_obj;
- }
+ obj = out_objs[i - narg];
}
/*
* Have we seen this class before? If so, ignore.
*/
- if (with_override != NULL) {
- int j;
- for (j = 0; j < num_override_args; j++) {
- new_class = (Py_TYPE(obj) != Py_TYPE(with_override[j]));
- if (!new_class) {
- break;
- }
+ for (j = 0; j < num_override_args; j++) {
+ new_class = (Py_TYPE(obj) != Py_TYPE(with_override[j]));
+ if (!new_class) {
+ break;
}
}
if (new_class) {
@@ -149,31 +181,61 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
Py_DECREF(method);
goto fail;
}
- if (with_override != NULL) {
- Py_INCREF(obj);
- with_override[num_override_args] = obj;
- }
- if (methods != NULL) {
- methods[num_override_args] = method;
- }
- else {
- Py_DECREF(method);
- }
+ Py_INCREF(obj);
+ with_override[num_override_args] = obj;
+ methods[num_override_args] = method;
++num_override_args;
}
}
return num_override_args;
fail:
- if (methods != NULL) {
- for (i = 0; i < num_override_args; i++) {
- Py_DECREF(methods[i]);
+ for (i = 0; i < num_override_args; i++) {
+ Py_DECREF(with_override[i]);
+ Py_DECREF(methods[i]);
+ }
+ return -1;
+}
+
+/*
+ * Check whether any of a set of input and output args have a non-default
+ * __array_ufunc__ method. Return 1 if so, 0 if not.
+ *
+ * This function primarily exists to help ndarray.__array_ufunc__ determine
+ * whether it can support a ufunc (which is the case only if none of the
+ * operands have an override). Thus, unlike in PyUFunc_CheckOverride, the
+ * actual overrides are not needed and one can stop looking once one is found.
+ *
+ * TODO: move this function and has_non_default_array_ufunc closer to ndarray.
+ */
+NPY_NO_EXPORT int
+PyUFunc_HasOverride(PyObject *args, PyObject *kwds)
+{
+ int i;
+ int nin, nout;
+ PyObject *out_kwd_obj;
+ PyObject **in_objs, **out_objs;
+
+ /* check inputs */
+ nin = PyTuple_Size(args);
+ if (nin < 0) {
+ return -1;
+ }
+ in_objs = PySequence_Fast_ITEMS(args);
+ for (i = 0; i < nin; ++i) {
+ if (has_non_default_array_ufunc(in_objs[i])) {
+ return 1;
}
}
- if (with_override != NULL) {
- for (i = 0; i < num_override_args; i++) {
- Py_DECREF(with_override[i]);
+ /* check outputs, if any */
+ nout = get_out_objects(kwds, &out_kwd_obj, &out_objs);
+ if (nout < 0) {
+ return -1;
+ }
+ for (i = 0; i < nout; i++) {
+ if (has_non_default_array_ufunc(out_objs[i])) {
+ return 1;
}
}
- return -1;
+ return 0;
}
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index 2ed1c626f..fd1ee2135 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -12,4 +12,7 @@
NPY_NO_EXPORT int
PyUFunc_WithOverride(PyObject *args, PyObject *kwds,
PyObject **with_override, PyObject **methods);
+
+NPY_NO_EXPORT int
+PyUFunc_HasOverride(PyObject *args, PyObject *kwds);
#endif