diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2013-05-08 19:39:16 +0200 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2013-05-08 20:29:26 +0200 |
commit | c018cd8570cab003f81e5b15dc0d6fce6dd7925e (patch) | |
tree | 4471c78dcbf508d1f658f4b9efd5255f09fb4a8e | |
parent | 7d76c744007afa6bb3af3b9d9fb64a65201e3635 (diff) | |
download | numpy-c018cd8570cab003f81e5b15dc0d6fce6dd7925e.tar.gz |
BUG: Fix 0-d array special case from reductions.
This channels scalars through the usual reduction machinery and
modifies it slightly to correctly support scalar reductions of
identity-less ufuncs.
-rw-r--r-- | numpy/core/src/umath/reduction.c | 3 | ||||
-rw-r--r-- | numpy/core/src/umath/ufunc_object.c | 31 | ||||
-rw-r--r-- | numpy/core/tests/test_ufunc.py | 13 |
3 files changed, 23 insertions, 24 deletions
diff --git a/numpy/core/src/umath/reduction.c b/numpy/core/src/umath/reduction.c index f69aea2d0..3f2b94a4a 100644 --- a/numpy/core/src/umath/reduction.c +++ b/numpy/core/src/umath/reduction.c @@ -483,7 +483,8 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out, if (op_view == NULL) { goto fail; } - if (PyArray_SIZE(op_view) == 0) { + /* empty op_view signals no reduction; but 0-d arrays cannot be empty */ + if ((PyArray_SIZE(op_view) == 0) || (PyArray_NDIM(operand) == 0)) { Py_DECREF(op_view); op_view = NULL; goto finish; diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c index 9c499d322..3d3f63d4e 100644 --- a/numpy/core/src/umath/ufunc_object.c +++ b/numpy/core/src/umath/ufunc_object.c @@ -3684,31 +3684,16 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args, * 'prod', et al, also allow a reduction where axis=0, even * though this is technically incorrect. */ - if (operation == UFUNC_REDUCE && - (naxes == 0 || (naxes == 1 && axes[0] == 0))) { + naxes = 0; + + if (!(operation == UFUNC_REDUCE && + (naxes == 0 || (naxes == 1 && axes[0] == 0)))) { + PyErr_Format(PyExc_TypeError, "cannot %s on a scalar", + _reduce_type[operation]); Py_XDECREF(otype); - /* If there's an output parameter, copy the value */ - if (out != NULL) { - if (PyArray_CopyInto(out, mp) < 0) { - Py_DECREF(mp); - return NULL; - } - else { - Py_DECREF(mp); - Py_INCREF(out); - return (PyObject *)out; - } - } - /* Otherwise return the array unscathed */ - else { - return PyArray_Return(mp); - } + Py_DECREF(mp); + return NULL; } - PyErr_Format(PyExc_TypeError, "cannot %s on a scalar", - _reduce_type[operation]); - Py_XDECREF(otype); - Py_DECREF(mp); - return NULL; } /* diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py index dbbd15397..816b22052 100644 --- a/numpy/core/tests/test_ufunc.py +++ b/numpy/core/tests/test_ufunc.py @@ -564,12 +564,25 @@ class TestUfunc(TestCase): assert_equal(np.max(3, axis=0), 3) assert_equal(np.min(2.5, axis=0), 2.5) + # Check scalar behaviour for ufuncs without an identity + assert_equal(np.power.reduce(3), 3) + # Make sure that scalars are coming out from this operation assert_(type(np.prod(np.float32(2.5), axis=0)) is np.float32) assert_(type(np.sum(np.float32(2.5), axis=0)) is np.float32) assert_(type(np.max(np.float32(2.5), axis=0)) is np.float32) assert_(type(np.min(np.float32(2.5), axis=0)) is np.float32) + # check if scalars/0-d arrays get cast + assert_(type(np.any(0, axis=0)) is np.bool_) + + # assert that 0-d arrays get wrapped + class MyArray(np.ndarray): + pass + a = np.array(1).view(MyArray) + assert_(type(np.any(a)) is MyArray) + + def test_casting_out_param(self): # Test that it's possible to do casts on output a = np.ones((200,100), np.int64) |