summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMarten van Kerkwijk <mhvk@astro.utoronto.ca>2017-04-06 12:27:35 -0400
committerCharles Harris <charlesr.harris@gmail.com>2017-04-27 13:37:50 -0600
commit55500b90c0d868621feb04920782109a57d40c12 (patch)
tree089eb553af32db305a8dd6d2b33c8644bcdd879e
parent36e84948a448c74efda008a9629c68e9fbb0a218 (diff)
downloadnumpy-55500b90c0d868621feb04920782109a57d40c12.tar.gz
MAINT: simplify now that __array_ufunc__ overrides ufuncs only.
In particular, use fact that we're guaranteed to have a PyUFuncObject in PyUFunc_CheckOverride.
-rw-r--r--numpy/core/src/private/ufunc_override.c17
-rw-r--r--numpy/core/src/private/ufunc_override.h3
-rw-r--r--numpy/core/src/umath/ufunc_object.c18
3 files changed, 14 insertions, 24 deletions
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c
index 14cf9b39b..d99a22a5d 100644
--- a/numpy/core/src/private/ufunc_override.c
+++ b/numpy/core/src/private/ufunc_override.c
@@ -11,12 +11,13 @@
static int
normalize___call___args(PyUFuncObject *ufunc, PyObject *args,
- PyObject **normal_args, PyObject **normal_kwds,
- int nin, int nout)
+ PyObject **normal_args, PyObject **normal_kwds)
{
/* ufunc.__call__(*args, **kwds) */
int i;
int not_all_none;
+ int nin = ufunc->nin;
+ int nout = ufunc->nout;
int nargs = PyTuple_GET_SIZE(args);
PyObject *obj;
@@ -325,17 +326,11 @@ fail:
*
* Returns 0 on success and 1 on exception. On success, *result contains the
* result of the operation, if any. If *result is NULL, there is no override.
- *
- * TODO: the ufunc really should always be a ufunc, so that we can rely on
- * using, e.g., ufunc->nin, ufunc->nout, etc. Right now, we cannot, since we
- * also use this function to override np.dot and np.matmul. This should be
- * fixed.
*/
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
- PyObject **result,
- int nin, int nout)
+ PyObject **result)
{
int i;
int j;
@@ -376,6 +371,8 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
normal_kwds = PyDict_Copy(kwds);
out = PyDict_GetItemString(normal_kwds, "out");
if (out != NULL) {
+ int nout = ufunc->nout;
+
if (PyTuple_Check(out)) {
int all_none = 1;
@@ -449,7 +446,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
/* ufunc.__call__ */
if (strcmp(method, "__call__") == 0) {
status = normalize___call___args(ufunc, args, &normal_args,
- &normal_kwds, nin, nout);
+ &normal_kwds);
}
/* ufunc.reduce and ufunc.accumulate */
else if ((strcmp(method, "reduce") == 0) ||
diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h
index 68dd0221d..92618453b 100644
--- a/numpy/core/src/private/ufunc_override.h
+++ b/numpy/core/src/private/ufunc_override.h
@@ -11,6 +11,5 @@ PyUFunc_HasOverride(PyObject *args, PyObject *kwds,
NPY_NO_EXPORT int
PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method,
PyObject *args, PyObject *kwds,
- PyObject **result,
- int nin, int nout);
+ PyObject **result);
#endif
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 605d59e61..04aee4aef 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -4370,8 +4370,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
mps[i] = NULL;
}
- errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override,
- ufunc->nin, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5088,8 +5087,7 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
}
/* Note: `nin` and `nout` are not used in the normalization */
- errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override,
- ufunc->nin, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5167,8 +5165,7 @@ ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *override = NULL;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override,
- 1, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5185,8 +5182,7 @@ ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *override = NULL;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override,
- 1, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5203,8 +5199,7 @@ ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds)
PyObject *override = NULL;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override,
- ufunc->nin, ufunc->nout);
+ errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override);
if (errval) {
return NULL;
}
@@ -5269,8 +5264,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args)
NPY_BEGIN_THREADS_DEF;
/* `nin` and `nout`, the last two arguments, are not actually used */
- errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override,
- ufunc->nin + 1, 0);
+ errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override);
if (errval) {
return NULL;
}