summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMark Wiebe <mwwiebe@gmail.com>2013-01-07 15:20:04 -0800
committerMark Wiebe <mwwiebe@gmail.com>2013-01-17 11:44:55 -0800
commitcac3de50a2cafe1114b0671dd55ec2d1f6f2601a (patch)
treec8cae8e648d46add487292fa22292d3f85740e62
parentfa9dbef4020bd242a9df6215bb3b9a10c8815848 (diff)
downloadnumpy-cac3de50a2cafe1114b0671dd55ec2d1f6f2601a.tar.gz
BUG: Fix for generalized ufunc zero-sized input case
-rw-r--r--numpy/core/src/umath/ufunc_object.c35
1 files changed, 33 insertions, 2 deletions
diff --git a/numpy/core/src/umath/ufunc_object.c b/numpy/core/src/umath/ufunc_object.c
index 1ce754921..51bd446ba 100644
--- a/numpy/core/src/umath/ufunc_object.c
+++ b/numpy/core/src/umath/ufunc_object.c
@@ -1985,7 +1985,8 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
/* Create the iterator */
iter = NpyIter_AdvancedNew(nop, op, NPY_ITER_MULTI_INDEX|
NPY_ITER_REFS_OK|
- NPY_ITER_REDUCE_OK,
+ NPY_ITER_REDUCE_OK|
+ NPY_ITER_ZEROSIZE_OK,
order, NPY_UNSAFE_CASTING, op_flags,
dtypes, iter_ndim,
op_axes, iter_shape, 0);
@@ -2074,8 +2075,8 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
NPY_UF_DBG_PRINT("Executing inner loop\n");
- /* Do the ufunc loop */
if (NpyIter_GetIterSize(iter) != 0) {
+ /* Do the ufunc loop */
NpyIter_IterNextFunc *iternext;
char **dataptr;
npy_intp *count_ptr;
@@ -2094,6 +2095,36 @@ PyUFunc_GeneralizedFunction(PyUFuncObject *ufunc,
inner_dimensions[0] = *count_ptr;
innerloop(dataptr, inner_dimensions, inner_strides, innerloopdata);
} while (iternext(iter));
+ } else {
+ /**
+ * For each output operand, check if it has non-zero size,
+ * and assign the identity if it does. For example, a dot
+ * product of two zero-length arrays will be a scalar,
+ * which has size one.
+ */
+ for (i = nin; i < nop; ++i) {
+ if (PyArray_SIZE(op[i]) != 0) {
+ switch (ufunc->identity) {
+ case PyUFunc_Zero:
+ assign_reduce_identity_zero(op[i]);
+ break;
+ case PyUFunc_One:
+ assign_reduce_identity_one(op[i]);
+ break;
+ case PyUFunc_None:
+ case PyUFunc_ReorderableNone:
+ PyErr_Format(PyExc_ValueError,
+ "ufunc %s ",
+ ufunc_name);
+ goto fail;
+ default:
+ PyErr_Format(PyExc_ValueError,
+ "ufunc %s has an invalid identity for reduction",
+ ufunc_name);
+ goto fail;
+ }
+ }
+ }
}
/* Check whether any errors occurred during the loop */