diff options
Diffstat (limited to 'numpy/core')
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 42 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 3 |
2 files changed, 28 insertions, 17 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index fd5ca3904..9348f13dd 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -94,6 +94,10 @@ _get_wrap_prepare_args(ufunc_full_args full_args) { } } +static PyObject * +_apply_array_wrap( + PyObject *wrap, PyArrayObject *obj, _ufunc_context const *context); + /* ---------------------------------------------------------------- */ static int @@ -4148,7 +4152,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, int axes[NPY_MAXDIMS]; PyObject *axes_in = NULL; PyArrayObject *mp = NULL, *ret = NULL; - PyObject *op, *res = NULL; + PyObject *op; PyObject *obj_ind, *context; PyArrayObject *indices = NULL; PyArray_Descr *otype = NULL; @@ -4394,25 +4398,31 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, return NULL; } - /* If an output parameter was provided, don't wrap it */ - if (out != NULL) { - return (PyObject *)ret; - } - - if (Py_TYPE(op) != Py_TYPE(ret)) { - res = PyObject_CallMethod(op, "__array_wrap__", "O", ret); - if (res == NULL) { - PyErr_Clear(); - } - else if (res == Py_None) { - Py_DECREF(res); + /* Wrap and return the output */ + { + /* Find __array_wrap__ - note that these rules are different to the + * normal ufunc path + */ + PyObject *wrap; + if (out != NULL) { + wrap = Py_None; + Py_INCREF(wrap); + } + else if (Py_TYPE(op) != Py_TYPE(ret)) { + wrap = PyObject_GetAttr(op, npy_um_str_array_wrap); + if (wrap == NULL) { + PyErr_Clear(); + } + else if (!PyCallable_Check(wrap)) { + Py_DECREF(wrap); + wrap = NULL; + } } else { - Py_DECREF(ret); - return res; + wrap = NULL; } + return _apply_array_wrap(wrap, ret, NULL); } - return PyArray_Return(ret); fail: Py_XDECREF(otype); diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 4772913be..e71e679fe 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -1568,13 +1568,14 @@ class TestSpecialMethods(object): class A(object): def __array__(self): - return np.zeros(1) + return np.zeros(2) def __array_wrap__(self, arr, context): raise RuntimeError a = A() assert_raises(RuntimeError, ncu.maximum, a, a) + assert_raises(RuntimeError, ncu.maximum.reduce, a) def test_failing_out_wrap(self): |