summaryrefslogtreecommitdiff
path: root/numpy/core
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/core')
-rw-r--r--numpy/core/src/umath/ufunc_object.c42
-rw-r--r--numpy/core/tests/test_umath.py3
2 files changed, 28 insertions, 17 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index fd5ca3904..9348f13dd 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -94,6 +94,10 @@ _get_wrap_prepare_args(ufunc_full_args full_args) {
}
}
+static PyObject *
+_apply_array_wrap(
+ PyObject *wrap, PyArrayObject *obj, _ufunc_context const *context);
+
/* ---------------------------------------------------------------- */
static int
@@ -4148,7 +4152,7 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
int axes[NPY_MAXDIMS];
PyObject *axes_in = NULL;
PyArrayObject *mp = NULL, *ret = NULL;
- PyObject *op, *res = NULL;
+ PyObject *op;
PyObject *obj_ind, *context;
PyArrayObject *indices = NULL;
PyArray_Descr *otype = NULL;
@@ -4394,25 +4398,31 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
return NULL;
}
- /* If an output parameter was provided, don't wrap it */
- if (out != NULL) {
- return (PyObject *)ret;
- }
-
- if (Py_TYPE(op) != Py_TYPE(ret)) {
- res = PyObject_CallMethod(op, "__array_wrap__", "O", ret);
- if (res == NULL) {
- PyErr_Clear();
- }
- else if (res == Py_None) {
- Py_DECREF(res);
+ /* Wrap and return the output */
+ {
+ /* Find __array_wrap__ - note that these rules are different to the
+ * normal ufunc path
+ */
+ PyObject *wrap;
+ if (out != NULL) {
+ wrap = Py_None;
+ Py_INCREF(wrap);
+ }
+ else if (Py_TYPE(op) != Py_TYPE(ret)) {
+ wrap = PyObject_GetAttr(op, npy_um_str_array_wrap);
+ if (wrap == NULL) {
+ PyErr_Clear();
+ }
+ else if (!PyCallable_Check(wrap)) {
+ Py_DECREF(wrap);
+ wrap = NULL;
+ }
}
else {
- Py_DECREF(ret);
- return res;
+ wrap = NULL;
}
+ return _apply_array_wrap(wrap, ret, NULL);
}
- return PyArray_Return(ret);
fail:
Py_XDECREF(otype);
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 4772913be..e71e679fe 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -1568,13 +1568,14 @@ class TestSpecialMethods(object):
class A(object):
def __array__(self):
- return np.zeros(1)
+ return np.zeros(2)
def __array_wrap__(self, arr, context):
raise RuntimeError
a = A()
assert_raises(RuntimeError, ncu.maximum, a, a)
+ assert_raises(RuntimeError, ncu.maximum.reduce, a)
def test_failing_out_wrap(self):