summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/private/ufunc_override.c17
-rw-r--r--numpy/core/src/private/ufunc_override.h3
-rw-r--r--numpy/core/src/umath/ufunc_object.c18
3 files changed, 14 insertions, 24 deletions
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c
index 14cf9b39b..d99a22a5d 100644
--- a/numpy/core/src/private/ufunc_override.c
+++ b/numpy/core/src/private/ufunc_override.c
@@ -11,12 +11,13 @@
static int
normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
- PyObject **normal_args, PyObject **normal_kwds,
- int nin, int nout)
+ PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.__call__(*args, **kwds) */
int i;
int not_all_none;
+ int nin = ufunc->nin;
+ int nout = ufunc->nout;
int nargs = PyTuple_GET_SIZE(args);
PyObject *obj;
@@ -325,17 +326,11 @@ fail:
*
* Returns 0 on success and 1 on exception. On success, *result contains the
* result of the operation, if any. If *result is NULL, there is no override.
- *
- * TODO: the ufunc really should always be a ufunc, so that we can rely on
- * using, e.g., ufunc->nin, ufunc->nout, etc. Right now, we cannot, since we
- * also use this function to override np.dot and np.matmul. This should be
- * fixed.
*/
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
- PyObject **result,
- int nin, int nout)
+ PyObject **result)
{
int i;
int j;
@@ -376,6 +371,8 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
normal_kwds = PyDict_Copy(kwds);
out = PyDict_GetItemString(normal_kwds, "out");
if (out != NULL) {
+ int nout = ufunc->nout;
+
if (PyTuple_Check(out)) {
int all_none = 1;
@@ -449,7 +446,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
/* ufunc.__call__ */
if (strcmp(method, "__call__") == 0) {
status = normalize___call___args(ufunc, args, &normal_args,
- &normal_kwds, nin, nout);
+ &normal_kwds);
}
/* ufunc.reduce and ufunc.accumulate */
else if ((strcmp(method, "reduce") == 0) ||
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index 68dd0221d..92618453b 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -11,6 +11,5 @@ PyUFunc_HasOverride(PyObject *args, PyObject *kwds,
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
- PyObject **result,
- int nin, int nout);
+ PyObject **result);
#endif
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 605d59e61..04aee4aef 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4370,8 +4370,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
mps[i] = NULL;
}
- errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override,
- ufunc->nin, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5088,8 +5087,7 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
}
/* Note: `nin` and `nout` are not used in the normalization */
- errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override,
- ufunc->nin, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5167,8 +5165,7 @@ ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *override = NULL;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override,
- 1, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5185,8 +5182,7 @@ ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *override = NULL;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override,
- 1, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5203,8 +5199,7 @@ ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *override = NULL;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override,
- ufunc->nin, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5269,8 +5264,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
NPY_BEGIN_THREADS_DEF;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override,
- ufunc->nin + 1, 0);
+ errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override);
if (errval) {
return NULL;
}