diff options
author | Eric Wieser <wieser.eric@gmail.com> | 2018-04-26 22:06:34 -0700 |
---|---|---|
committer | Eric Wieser <wieser.eric@gmail.com> | 2018-06-10 23:16:16 -0700 |
commit | 8eeab6d3805bd35d87d4d7506451ad6eb6a51aed (patch) | |
tree | 2828839a4daf7ae466e3cabeec3ee916c2024245 /numpy/core | |
parent | 3bb11d6f1758e4d3d6ca812749cc5145f82713b3 (diff) | |
download | numpy-8eeab6d3805bd35d87d4d7506451ad6eb6a51aed.tar.gz |
MAINT: Extract a helper function to apply __array_wrap__
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 112 |
1 files changed, 79 insertions, 33 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 5e92bc991..fd5ca3904 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -71,6 +71,13 @@ typedef struct { provided, then this is NULL. */ } ufunc_full_args; +/* C representation of the context argument to __array_wrap__ */ +typedef struct { + PyUFuncObject *ufunc; + ufunc_full_args args; + int out_i; +} _ufunc_context; + /* Get the arg tuple to pass in the context argument to __array_wrap__ and * __array_prepare__. * @@ -4485,6 +4492,67 @@ handle_out: return; } +/* + * Apply the __array_wrap__ function with the given array and content. + * + * Interprets wrap=None and wrap=NULL as intended by _find_array_wrap + * + * Steals a reference to obj and wrap. + * Pass context=NULL to indicate there is no context. + */ +static PyObject * +_apply_array_wrap( + PyObject *wrap, PyArrayObject *obj, _ufunc_context const *context) { + if (wrap == NULL) { + /* default behavior */ + return PyArray_Return(obj); + } + else if (wrap == Py_None) { + Py_DECREF(wrap); + return (PyObject *)obj; + } + else { + PyObject *res; + PyObject *py_context = NULL; + + /* Convert the context object to a tuple, if present */ + if (context == NULL) { + py_context = Py_None; + Py_INCREF(py_context); + } + else { + PyObject *args_tup; + /* Call the method with appropriate context */ + args_tup = _get_wrap_prepare_args(context->args); + if (args_tup == NULL) { + goto fail; + } + py_context = Py_BuildValue("OOi", + context->ufunc, args_tup, context->out_i); + Py_DECREF(args_tup); + if (py_context == NULL) { + goto fail; + } + } + /* try __array_wrap__(obj, context) */ + res = PyObject_CallFunctionObjArgs(wrap, obj, py_context, NULL); + Py_DECREF(py_context); + + /* try __array_wrap__(obj) if the context argument is not accepted */ + if (res == NULL && PyErr_ExceptionMatches(PyExc_TypeError)) { + PyErr_Clear(); + res = PyObject_CallFunctionObjArgs(wrap, obj, NULL); + } + Py_DECREF(wrap); + Py_DECREF(obj); + return res; + fail: + Py_DECREF(wrap); + Py_DECREF(obj); + return NULL; + } +} + static PyObject * ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) @@ -4555,42 +4623,20 @@ ufunc_generic_call(PyUFuncObject *ufunc, PyObject *args, PyObject *kwds) /* wrap outputs */ for (i = 0; i < ufunc->nout; i++) { int j = ufunc->nin+i; - PyObject *wrap = wraparr[i]; - - if (wrap == NULL) { - /* default behavior */ - retobj[i] = PyArray_Return(mps[j]); - } - else if (wrap == Py_None) { - Py_DECREF(wrap); - retobj[i] = (PyObject *)mps[j]; - } - else { - PyObject *res; - PyObject *args_tup; + _ufunc_context context; + PyObject *wrapped; - /* Call the method with appropriate context */ - args_tup = _get_wrap_prepare_args(full_args); - if (args_tup == NULL) { - goto fail; - } - res = PyObject_CallFunction( - wrap, "O(OOi)", mps[j], ufunc, args_tup, i); - Py_DECREF(args_tup); + context.ufunc = ufunc; + context.args = full_args; + context.out_i = i; - /* Handle __array_wrap__ that does not accept a context argument */ - if (res == NULL && PyErr_ExceptionMatches(PyExc_TypeError)) { - PyErr_Clear(); - res = PyObject_CallFunctionObjArgs(wrap, mps[j], NULL); - } - Py_DECREF(wrap); - Py_DECREF(mps[j]); - mps[j] = NULL; /* Prevent fail double-freeing this */ - if (res == NULL) { - goto fail; - } - retobj[i] = res; + wrapped = _apply_array_wrap(wraparr[i], mps[j], &context); + mps[j] = NULL; /* Prevent fail double-freeing this */ + if (wrapped == NULL) { + goto fail; } + + retobj[i] = wrapped; } Py_XDECREF(full_args.in); |