summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSebastian Berg <sebastian@sipsolutions.net>2020-05-11 13:47:22 -0500
committerSebastian Berg <sebastian@sipsolutions.net>2020-05-11 18:42:14 -0500
commit407f62667cba244ab3d67efbebe52fd3e04b2924 (patch)
tree6185fc88bb588bd5b573dab5dd894dd169a7af5f
parentd3bfaf80c4feb0254f9089e757734e8feea05425 (diff)
downloadnumpy-407f62667cba244ab3d67efbebe52fd3e04b2924.tar.gz
MAINT: (nditer) move duplicated reduce axis check into helper func
-rw-r--r--numpy/core/src/multiarray/nditer_constr.c97
1 files changed, 43 insertions, 54 deletions
diff --git a/numpy/core/src/multiarray/nditer_constr.c b/numpy/core/src/multiarray/nditer_constr.c
index 4701e74d4..83ea52ba9 100644
--- a/numpy/core/src/multiarray/nditer_constr.c
+++ b/numpy/core/src/multiarray/nditer_constr.c
@@ -1411,6 +1411,38 @@ check_mask_for_writemasked_reduction(NpyIter *iter, int iop)
}
/*
+ * Check whether a reduction is OK based on the flags and the operand being
+ * readwrite. This path is deprecated, since usually only specific axes
+ * should be reduced. If axes are specified explicitely, the flag is
+ * unnecessary.
+ */
+static int
+npyiter_check_reduce_ok_and_set_flags(
+ NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itflags) {
+ /* If it's writeable, this means a reduction */
+ if (*op_itflags & NPY_OP_ITFLAG_WRITE) {
+ if (!(flags & NPY_ITER_REDUCE_OK)) {
+ PyErr_SetString(PyExc_ValueError,
+ "output operand requires a reduction, but reduction is"
+ "not enabled");
+ return 0;
+ }
+ if (!(*op_itflags & NPY_OP_ITFLAG_READ)) {
+ PyErr_SetString(PyExc_ValueError,
+ "output operand requires a reduction, but is flagged as "
+ "write-only, not read-write");
+ return 0;
+ }
+ NPY_IT_DBG_PRINT("Iterator: Indicating that a reduction is"
+ "occurring\n");
+
+ NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE;
+ *op_itflags |= NPY_OP_ITFLAG_REDUCE;
+ }
+ return 1;
+}
+
+/*
* Fills in the AXISDATA for the 'nop' operands, broadcasting
* the dimensionas as necessary. Also fills
* in the ITERSIZE data member.
@@ -1621,51 +1653,20 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itf
if (op_flags[iop] & NPY_ITER_NO_BROADCAST) {
goto operand_different_than_broadcast;
}
- /* If it's writeable, this means a reduction */
- if (op_itflags[iop] & NPY_OP_ITFLAG_WRITE) {
- if (!(flags & NPY_ITER_REDUCE_OK)) {
- PyErr_SetString(PyExc_ValueError,
- "output operand requires a reduction, but "
- "reduction is not enabled");
- return 0;
- }
- if (!(op_itflags[iop] & NPY_OP_ITFLAG_READ)) {
- PyErr_SetString(PyExc_ValueError,
- "output operand requires a reduction, but "
- "is flagged as write-only, not "
- "read-write");
- return 0;
- }
- NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE;
- op_itflags[iop] |= NPY_OP_ITFLAG_REDUCE;
+ if (!npyiter_check_reduce_ok_and_set_flags(
+ iter, flags, &op_itflags[iop])) {
+ return 0;
}
}
else {
strides[iop] = PyArray_STRIDE(op_cur, i);
}
}
- else if (bshape == 1) {
- strides[iop] = 0;
- }
else {
strides[iop] = 0;
- /* If it's writeable, this means a reduction */
- if (op_itflags[iop] & NPY_OP_ITFLAG_WRITE) {
- if (!(flags & NPY_ITER_REDUCE_OK)) {
- PyErr_SetString(PyExc_ValueError,
- "output operand requires a reduction, but "
- "reduction is not enabled");
- return 0;
- }
- if (!(op_itflags[iop] & NPY_OP_ITFLAG_READ)) {
- PyErr_SetString(PyExc_ValueError,
- "output operand requires a reduction, but "
- "is flagged as write-only, not "
- "read-write");
- return 0;
- }
- NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE;
- op_itflags[iop] |= NPY_OP_ITFLAG_REDUCE;
+ if (!npyiter_check_reduce_ok_and_set_flags(
+ iter, flags, &op_itflags[iop])) {
+ return 0;
}
}
}
@@ -2536,27 +2537,15 @@ npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype,
if (shape == NULL) {
/*
* If deleting this axis produces a reduction, but
- * reduction wasn't enabled, throw an error
+ * reduction wasn't enabled, throw an error.
+ * NOTE: We currently always allow new-axis if the iteration
+ * size is 1 (thus allowing broadcasting sometimes).
*/
if (NAD_SHAPE(axisdata) != 1) {
- if (!(flags & NPY_ITER_REDUCE_OK)) {
- PyErr_SetString(PyExc_ValueError,
- "output requires a reduction, but "
- "reduction is not enabled");
- return NULL;
- }
- if (!((*op_itflags) & NPY_OP_ITFLAG_READ)) {
- PyErr_SetString(PyExc_ValueError,
- "output requires a reduction, but "
- "is flagged as write-only, not read-write");
+ if (!npyiter_check_reduce_ok_and_set_flags(
+ iter, flags, op_itflags)) {
return NULL;
}
-
- NPY_IT_DBG_PRINT("Iterator: Indicating that a "
- "reduction is occurring\n");
- /* Indicate that a reduction is occurring */
- NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE;
- (*op_itflags) |= NPY_OP_ITFLAG_REDUCE;
}
}
}