diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/methods.c | 15 | ||||
-rw-r--r-- | numpy/core/src/private/ufunc_override.c | 202 | ||||
-rw-r--r-- | numpy/core/src/private/ufunc_override.h | 3 |
3 files changed, 144 insertions, 76 deletions
diff --git a/numpy/core/src/multiarray/methods.c b/numpy/core/src/multiarray/methods.c index d6f2577a3..2e836d1d0 100644 --- a/numpy/core/src/multiarray/methods.c +++ b/numpy/core/src/multiarray/methods.c @@ -976,9 +976,12 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds) { PyObject *ufunc, *method_name, *normal_args, *ufunc_method; PyObject *result = NULL; - int num_override_args; + int has_override; - if (PyTuple_Size(args) < 2) { + assert(PyTuple_CheckExact(args)); + assert(kwds == NULL || PyDict_CheckExact(kwds)); + + if (PyTuple_GET_SIZE(args) < 2) { PyErr_SetString(PyExc_TypeError, "__array_ufunc__ requires at least 2 arguments"); return NULL; @@ -988,11 +991,11 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds) return NULL; } /* ndarray cannot handle overrides itself */ - num_override_args = PyUFunc_WithOverride(normal_args, kwds, NULL, NULL); - if (num_override_args == -1) { - return NULL; + has_override = PyUFunc_HasOverride(normal_args, kwds); + if (has_override < 0) { + goto cleanup; } - if (num_override_args) { + else if (has_override) { result = Py_NotImplemented; Py_INCREF(Py_NotImplemented); goto cleanup; diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c index 522b1744d..69c3cc56c 100644 --- a/numpy/core/src/private/ufunc_override.c +++ b/numpy/core/src/private/ufunc_override.c @@ -54,11 +54,71 @@ get_non_default_array_ufunc(PyObject *obj) } /* - * Check whether a set of input and output args have a non-default - * `__array_ufunc__` method. Return the number of overrides, setting - * corresponding objects in PyObject array with_override and the corresponding - * __array_ufunc__ methods in methods (both only if not NULL, and both using - * new references). + * Check whether an object has __array_ufunc__ defined on its class and it + * is not the default, i.e., the object is not an ndarray, and its + * __array_ufunc__ is not the same as that of ndarray. + * + * Returns 1 if this is the case, 0 if not. + */ + +static int +has_non_default_array_ufunc(PyObject * obj) +{ + PyObject *method = get_non_default_array_ufunc(obj); + if (method) { + Py_DECREF(method); + return 1; + } + else { + return 0; + } +} + +/* + * Get possible out argument from kwds, and returns the number of outputs + * contained within it: if a tuple, the number of elements in it, 1 otherwise. + * The out argument itself is returned in out_kwd_obj, and the outputs + * in the out_obj array (all as borrowed references). + * + * Returns -1 if kwds is not a dict, 0 if no outputs found. + */ +static int +get_out_objects(PyObject *kwds, PyObject **out_kwd_obj, PyObject ***out_objs) +{ + if (kwds == NULL) { + return 0; + } + if (!PyDict_CheckExact(kwds)) { + PyErr_SetString(PyExc_TypeError, + "Internal Numpy error: call to PyUFunc_WithOverride " + "with non-dict kwds"); + return -1; + } + /* borrowed reference */ + *out_kwd_obj = PyDict_GetItemString(kwds, "out"); + if (*out_kwd_obj == NULL) { + return 0; + } + if (PyTuple_CheckExact(*out_kwd_obj)) { + *out_objs = PySequence_Fast_ITEMS(*out_kwd_obj); + return PySequence_Fast_GET_SIZE(*out_kwd_obj); + } + else { + *out_objs = out_kwd_obj; + return 1; + } +} + +/* + * For each positional argument and each argument in a possible "out" + * keyword, look for overrides of the standard ufunc behaviour, i.e., + * non-default __array_ufunc__ methods. + * + * Returns the number of overrides, setting corresponding objects + * in PyObject array ``with_override`` and the corresponding + * __array_ufunc__ methods in ``methods`` (both using new references). + * + * Only the first override for a given class is returned. * * returns -1 on failure. */ @@ -67,68 +127,40 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds, PyObject **with_override, PyObject **methods) { int i; - - int nargs; - int nout_kwd = 0; - int out_kwd_is_tuple = 0; int num_override_args = 0; + int narg, nout = 0; + PyObject *out_kwd_obj; + PyObject **arg_objs, **out_objs; - PyObject *out_kwd_obj = NULL; - /* - * Check inputs - */ - if (!PyTuple_Check(args)) { - PyErr_SetString(PyExc_TypeError, - "Internal Numpy error: call to PyUFunc_HasOverride " - "with non-tuple"); - goto fail; + narg = PyTuple_Size(args); + if (narg < 0) { + return -1; } - nargs = PyTuple_GET_SIZE(args); - if (nargs > NPY_MAXARGS) { - PyErr_SetString(PyExc_TypeError, - "Internal Numpy error: too many arguments in call " - "to PyUFunc_HasOverride"); - goto fail; - } - /* be sure to include possible 'out' keyword argument. */ - if (kwds && PyDict_CheckExact(kwds)) { - out_kwd_obj = PyDict_GetItemString(kwds, "out"); - if (out_kwd_obj != NULL) { - out_kwd_is_tuple = PyTuple_CheckExact(out_kwd_obj); - if (out_kwd_is_tuple) { - nout_kwd = PyTuple_GET_SIZE(out_kwd_obj); - } - else { - nout_kwd = 1; - } - } + arg_objs = PySequence_Fast_ITEMS(args); + + nout = get_out_objects(kwds, &out_kwd_obj, &out_objs); + if (nout < 0) { + return -1; } - for (i = 0; i < nargs + nout_kwd; ++i) { + for (i = 0; i < narg + nout; ++i) { PyObject *obj; + int j; int new_class = 1; - if (i < nargs) { - obj = PyTuple_GET_ITEM(args, i); + if (i < narg) { + obj = arg_objs[i]; } else { - if (out_kwd_is_tuple) { - obj = PyTuple_GET_ITEM(out_kwd_obj, i - nargs); - } - else { - obj = out_kwd_obj; - } + obj = out_objs[i - narg]; } /* * Have we seen this class before? If so, ignore. */ - if (with_override != NULL) { - int j; - for (j = 0; j < num_override_args; j++) { - new_class = (Py_TYPE(obj) != Py_TYPE(with_override[j])); - if (!new_class) { - break; - } + for (j = 0; j < num_override_args; j++) { + new_class = (Py_TYPE(obj) != Py_TYPE(with_override[j])); + if (!new_class) { + break; } } if (new_class) { @@ -149,31 +181,61 @@ PyUFunc_WithOverride(PyObject *args, PyObject *kwds, Py_DECREF(method); goto fail; } - if (with_override != NULL) { - Py_INCREF(obj); - with_override[num_override_args] = obj; - } - if (methods != NULL) { - methods[num_override_args] = method; - } - else { - Py_DECREF(method); - } + Py_INCREF(obj); + with_override[num_override_args] = obj; + methods[num_override_args] = method; ++num_override_args; } } return num_override_args; fail: - if (methods != NULL) { - for (i = 0; i < num_override_args; i++) { - Py_DECREF(methods[i]); + for (i = 0; i < num_override_args; i++) { + Py_DECREF(with_override[i]); + Py_DECREF(methods[i]); + } + return -1; +} + +/* + * Check whether any of a set of input and output args have a non-default + * __array_ufunc__ method. Return 1 if so, 0 if not. + * + * This function primarily exists to help ndarray.__array_ufunc__ determine + * whether it can support a ufunc (which is the case only if none of the + * operands have an override). Thus, unlike in PyUFunc_CheckOverride, the + * actual overrides are not needed and one can stop looking once one is found. + * + * TODO: move this function and has_non_default_array_ufunc closer to ndarray. + */ +NPY_NO_EXPORT int +PyUFunc_HasOverride(PyObject *args, PyObject *kwds) +{ + int i; + int nin, nout; + PyObject *out_kwd_obj; + PyObject **in_objs, **out_objs; + + /* check inputs */ + nin = PyTuple_Size(args); + if (nin < 0) { + return -1; + } + in_objs = PySequence_Fast_ITEMS(args); + for (i = 0; i < nin; ++i) { + if (has_non_default_array_ufunc(in_objs[i])) { + return 1; } } - if (with_override != NULL) { - for (i = 0; i < num_override_args; i++) { - Py_DECREF(with_override[i]); + /* check outputs, if any */ + nout = get_out_objects(kwds, &out_kwd_obj, &out_objs); + if (nout < 0) { + return -1; + } + for (i = 0; i < nout; i++) { + if (has_non_default_array_ufunc(out_objs[i])) { + return 1; } } - return -1; + return 0; } diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h index 2ed1c626f..fd1ee2135 100644 --- a/numpy/core/src/private/ufunc_override.h +++ b/numpy/core/src/private/ufunc_override.h @@ -12,4 +12,7 @@ NPY_NO_EXPORT int PyUFunc_WithOverride(PyObject *args, PyObject *kwds, PyObject **with_override, PyObject **methods); + +NPY_NO_EXPORT int +PyUFunc_HasOverride(PyObject *args, PyObject *kwds); #endif |