summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2013-05-08 19:39:16 +0200
committerSebastian Berg <sebastian@sipsolutions.net>2013-05-08 20:29:26 +0200
commitc018cd8570cab003f81e5b15dc0d6fce6dd7925e (patch)
tree4471c78dcbf508d1f658f4b9efd5255f09fb4a8e
parent7d76c744007afa6bb3af3b9d9fb64a65201e3635 (diff)
downloadnumpy-c018cd8570cab003f81e5b15dc0d6fce6dd7925e.tar.gz
BUG: Fix 0-d array special case from reductions.
This channels scalars through the usual reduction machinery and modifies it slightly to correctly support scalar reductions of identity-less ufuncs.
-rw-r--r--numpy/core/src/umath/reduction.c3
-rw-r--r--numpy/core/src/umath/ufunc_object.c31
-rw-r--r--numpy/core/tests/test_ufunc.py13
3 files changed, 23 insertions, 24 deletions
diff --git a/numpy/core/src/umath/reduction.c b/numpy/core/src/umath/reduction.c
index f69aea2d0..3f2b94a4a 100644
--- a/numpy/core/src/umath/reduction.c
+++ b/numpy/core/src/umath/reduction.c
@@ -483,7 +483,8 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
if (op_view == NULL) {
goto fail;
}
- if (PyArray_SIZE(op_view) == 0) {
+ /* empty op_view signals no reduction; but 0-d arrays cannot be empty */
+ if ((PyArray_SIZE(op_view) == 0) || (PyArray_NDIM(operand) == 0)) {
Py_DECREF(op_view);
op_view = NULL;
goto finish;
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 9c499d322..3d3f63d4e 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -3684,31 +3684,16 @@ PyUFunc_GenericReduction(PyUFuncObject *ufunc, PyObject *args,
* 'prod', et al, also allow a reduction where axis=0, even
* though this is technically incorrect.
*/
- if (operation == UFUNC_REDUCE &&
- (naxes == 0 || (naxes == 1 && axes[0] == 0))) {
+ naxes = 0;
+
+ if (!(operation == UFUNC_REDUCE &&
+ (naxes == 0 || (naxes == 1 && axes[0] == 0)))) {
+ PyErr_Format(PyExc_TypeError, "cannot %s on a scalar",
+ _reduce_type[operation]);
Py_XDECREF(otype);
- /* If there's an output parameter, copy the value */
- if (out != NULL) {
- if (PyArray_CopyInto(out, mp) < 0) {
- Py_DECREF(mp);
- return NULL;
- }
- else {
- Py_DECREF(mp);
- Py_INCREF(out);
- return (PyObject *)out;
- }
- }
- /* Otherwise return the array unscathed */
- else {
- return PyArray_Return(mp);
- }
+ Py_DECREF(mp);
+ return NULL;
}
- PyErr_Format(PyExc_TypeError, "cannot %s on a scalar",
- _reduce_type[operation]);
- Py_XDECREF(otype);
- Py_DECREF(mp);
- return NULL;
}
/*
diff --git a/numpy/core/tests/test_ufunc.py b/numpy/core/tests/test_ufunc.py
index dbbd15397..816b22052 100644
--- a/numpy/core/tests/test_ufunc.py
+++ b/numpy/core/tests/test_ufunc.py
@@ -564,12 +564,25 @@ class TestUfunc(TestCase):
assert_equal(np.max(3, axis=0), 3)
assert_equal(np.min(2.5, axis=0), 2.5)
+ # Check scalar behaviour for ufuncs without an identity
+ assert_equal(np.power.reduce(3), 3)
+
# Make sure that scalars are coming out from this operation
assert_(type(np.prod(np.float32(2.5), axis=0)) is np.float32)
assert_(type(np.sum(np.float32(2.5), axis=0)) is np.float32)
assert_(type(np.max(np.float32(2.5), axis=0)) is np.float32)
assert_(type(np.min(np.float32(2.5), axis=0)) is np.float32)
+ # check if scalars/0-d arrays get cast
+ assert_(type(np.any(0, axis=0)) is np.bool_)
+
+ # assert that 0-d arrays get wrapped
+ class MyArray(np.ndarray):
+ pass
+ a = np.array(1).view(MyArray)
+ assert_(type(np.any(a)) is MyArray)
+
+
def test_casting_out_param(self):
# Test that it's possible to do casts on output
a = np.ones((200,100), np.int64)