diff options
author | Marten van Kerkwijk <mhvk@astro.utoronto.ca> | 2017-04-06 12:27:35 -0400 |
---|---|---|
committer | Charles Harris <charlesr.harris@gmail.com> | 2017-04-27 13:37:50 -0600 |
commit | 55500b90c0d868621feb04920782109a57d40c12 (patch) | |
tree | 089eb553af32db305a8dd6d2b33c8644bcdd879e | |
parent | 36e84948a448c74efda008a9629c68e9fbb0a218 (diff) | |
download | numpy-55500b90c0d868621feb04920782109a57d40c12.tar.gz |
MAINT: simplify now that __array_ufunc__ overrides ufuncs only.
In particular, use fact that we're guaranteed to have a PyUFuncObject
in PyUFunc_CheckOverride.
-rw-r--r-- | numpy/core/src/private/ufunc_override.c | 17 | ||||
-rw-r--r-- | numpy/core/src/private/ufunc_override.h | 3 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 18 |
3 files changed, 14 insertions, 24 deletions
diff --git a/numpy/core/src/private/ufunc_override.c b/numpy/core/src/private/ufunc_override.c index 14cf9b39b..d99a22a5d 100644 --- a/numpy/core/src/private/ufunc_override.c +++ b/numpy/core/src/private/ufunc_override.c @@ -11,12 +11,13 @@ static int normalize___call___args(PyUFuncObject *ufunc, PyObject *args, - PyObject **normal_args, PyObject **normal_kwds, - int nin, int nout) + PyObject **normal_args, PyObject **normal_kwds) { /* ufunc.__call__(*args, **kwds) */ int i; int not_all_none; + int nin = ufunc->nin; + int nout = ufunc->nout; int nargs = PyTuple_GET_SIZE(args); PyObject *obj; @@ -325,17 +326,11 @@ fail: * * Returns 0 on success and 1 on exception. On success, *result contains the * result of the operation, if any. If *result is NULL, there is no override. - * - * TODO: the ufunc really should always be a ufunc, so that we can rely on - * using, e.g., ufunc->nin, ufunc->nout, etc. Right now, we cannot, since we - * also use this function to override np.dot and np.matmul. This should be - * fixed. */ NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, PyObject *args, PyObject *kwds, - PyObject **result, - int nin, int nout) + PyObject **result) { int i; int j; @@ -376,6 +371,8 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, normal_kwds = PyDict_Copy(kwds); out = PyDict_GetItemString(normal_kwds, "out"); if (out != NULL) { + int nout = ufunc->nout; + if (PyTuple_Check(out)) { int all_none = 1; @@ -449,7 +446,7 @@ PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, /* ufunc.__call__ */ if (strcmp(method, "__call__") == 0) { status = normalize___call___args(ufunc, args, &normal_args, - &normal_kwds, nin, nout); + &normal_kwds); } /* ufunc.reduce and ufunc.accumulate */ else if ((strcmp(method, "reduce") == 0) || diff --git a/numpy/core/src/private/ufunc_override.h b/numpy/core/src/private/ufunc_override.h index 68dd0221d..92618453b 100644 --- a/numpy/core/src/private/ufunc_override.h +++ b/numpy/core/src/private/ufunc_override.h @@ -11,6 +11,5 @@ PyUFunc_HasOverride(PyObject *args, PyObject *kwds, NPY_NO_EXPORT int PyUFunc_CheckOverride(PyUFuncObject *ufunc, char *method, PyObject *args, PyObject *kwds, - PyObject **result, - int nin, int nout); + PyObject **result); #endif diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 605d59e61..04aee4aef 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -4370,8 +4370,7 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) mps[i] = NULL; } - errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override, - ufunc->nin, ufunc->nout); + errval = PyUFunc_CheckOverride(ufunc, "__call__", args, kwds, &override); if (errval) { return NULL; } @@ -5088,8 +5087,7 @@ ufunc_outer(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) } /* Note: `nin` and `nout` are not used in the normalization */ - errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override, - ufunc->nin, ufunc->nout); + errval = PyUFunc_CheckOverride(ufunc, "outer", args, kwds, &override); if (errval) { return NULL; } @@ -5167,8 +5165,7 @@ ufunc_reduce(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) PyObject *override = NULL; /* `nin` and `nout`, the last two arguments, are not actually used */ - errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override, - 1, ufunc->nout); + errval = PyUFunc_CheckOverride(ufunc, "reduce", args, kwds, &override); if (errval) { return NULL; } @@ -5185,8 +5182,7 @@ ufunc_accumulate(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) PyObject *override = NULL; /* `nin` and `nout`, the last two arguments, are not actually used */ - errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override, - 1, ufunc->nout); + errval = PyUFunc_CheckOverride(ufunc, "accumulate", args, kwds, &override); if (errval) { return NULL; } @@ -5203,8 +5199,7 @@ ufunc_reduceat(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) PyObject *override = NULL; /* `nin` and `nout`, the last two arguments, are not actually used */ - errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override, - ufunc->nin, ufunc->nout); + errval = PyUFunc_CheckOverride(ufunc, "reduceat", args, kwds, &override); if (errval) { return NULL; } @@ -5269,8 +5264,7 @@ ufunc_at(PyUFuncObject *ufunc, PyObject *args) NPY_BEGIN_THREADS_DEF; /* `nin` and `nout`, the last two arguments, are not actually used */ - errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override, - ufunc->nin + 1, 0); + errval = PyUFunc_CheckOverride(ufunc, "at", args, NULL, &override); if (errval) { return NULL; } |