diff options
author | Sebastian Berg <sebastian@sipsolutions.net> | 2020-05-11 13:47:22 -0500 |
---|---|---|
committer | Sebastian Berg <sebastian@sipsolutions.net> | 2020-05-11 18:42:14 -0500 |
commit | 407f62667cba244ab3d67efbebe52fd3e04b2924 (patch) | |
tree | 6185fc88bb588bd5b573dab5dd894dd169a7af5f | |
parent | d3bfaf80c4feb0254f9089e757734e8feea05425 (diff) | |
download | numpy-407f62667cba244ab3d67efbebe52fd3e04b2924.tar.gz |
MAINT: (nditer) move duplicated reduce axis check into helper func
-rw-r--r-- | numpy/core/src/multiarray/nditer_constr.c | 97 |
1 files changed, 43 insertions, 54 deletions
diff --git a/numpy/core/src/multiarray/nditer_constr.c b/numpy/core/src/multiarray/nditer_constr.c index 4701e74d4..83ea52ba9 100644 --- a/numpy/core/src/multiarray/nditer_constr.c +++ b/numpy/core/src/multiarray/nditer_constr.c @@ -1411,6 +1411,38 @@ check_mask_for_writemasked_reduction(NpyIter *iter, int iop) } /* + * Check whether a reduction is OK based on the flags and the operand being + * readwrite. This path is deprecated, since usually only specific axes + * should be reduced. If axes are specified explicitely, the flag is + * unnecessary. + */ +static int +npyiter_check_reduce_ok_and_set_flags( + NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itflags) { + /* If it's writeable, this means a reduction */ + if (*op_itflags & NPY_OP_ITFLAG_WRITE) { + if (!(flags & NPY_ITER_REDUCE_OK)) { + PyErr_SetString(PyExc_ValueError, + "output operand requires a reduction, but reduction is" + "not enabled"); + return 0; + } + if (!(*op_itflags & NPY_OP_ITFLAG_READ)) { + PyErr_SetString(PyExc_ValueError, + "output operand requires a reduction, but is flagged as " + "write-only, not read-write"); + return 0; + } + NPY_IT_DBG_PRINT("Iterator: Indicating that a reduction is" + "occurring\n"); + + NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE; + *op_itflags |= NPY_OP_ITFLAG_REDUCE; + } + return 1; +} + +/* * Fills in the AXISDATA for the 'nop' operands, broadcasting * the dimensionas as necessary. Also fills * in the ITERSIZE data member. @@ -1621,51 +1653,20 @@ npyiter_fill_axisdata(NpyIter *iter, npy_uint32 flags, npyiter_opitflags *op_itf if (op_flags[iop] & NPY_ITER_NO_BROADCAST) { goto operand_different_than_broadcast; } - /* If it's writeable, this means a reduction */ - if (op_itflags[iop] & NPY_OP_ITFLAG_WRITE) { - if (!(flags & NPY_ITER_REDUCE_OK)) { - PyErr_SetString(PyExc_ValueError, - "output operand requires a reduction, but " - "reduction is not enabled"); - return 0; - } - if (!(op_itflags[iop] & NPY_OP_ITFLAG_READ)) { - PyErr_SetString(PyExc_ValueError, - "output operand requires a reduction, but " - "is flagged as write-only, not " - "read-write"); - return 0; - } - NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE; - op_itflags[iop] |= NPY_OP_ITFLAG_REDUCE; + if (!npyiter_check_reduce_ok_and_set_flags( + iter, flags, &op_itflags[iop])) { + return 0; } } else { strides[iop] = PyArray_STRIDE(op_cur, i); } } - else if (bshape == 1) { - strides[iop] = 0; - } else { strides[iop] = 0; - /* If it's writeable, this means a reduction */ - if (op_itflags[iop] & NPY_OP_ITFLAG_WRITE) { - if (!(flags & NPY_ITER_REDUCE_OK)) { - PyErr_SetString(PyExc_ValueError, - "output operand requires a reduction, but " - "reduction is not enabled"); - return 0; - } - if (!(op_itflags[iop] & NPY_OP_ITFLAG_READ)) { - PyErr_SetString(PyExc_ValueError, - "output operand requires a reduction, but " - "is flagged as write-only, not " - "read-write"); - return 0; - } - NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE; - op_itflags[iop] |= NPY_OP_ITFLAG_REDUCE; + if (!npyiter_check_reduce_ok_and_set_flags( + iter, flags, &op_itflags[iop])) { + return 0; } } } @@ -2536,27 +2537,15 @@ npyiter_new_temp_array(NpyIter *iter, PyTypeObject *subtype, if (shape == NULL) { /* * If deleting this axis produces a reduction, but - * reduction wasn't enabled, throw an error + * reduction wasn't enabled, throw an error. + * NOTE: We currently always allow new-axis if the iteration + * size is 1 (thus allowing broadcasting sometimes). */ if (NAD_SHAPE(axisdata) != 1) { - if (!(flags & NPY_ITER_REDUCE_OK)) { - PyErr_SetString(PyExc_ValueError, - "output requires a reduction, but " - "reduction is not enabled"); - return NULL; - } - if (!((*op_itflags) & NPY_OP_ITFLAG_READ)) { - PyErr_SetString(PyExc_ValueError, - "output requires a reduction, but " - "is flagged as write-only, not read-write"); + if (!npyiter_check_reduce_ok_and_set_flags( + iter, flags, op_itflags)) { return NULL; } - - NPY_IT_DBG_PRINT("Iterator: Indicating that a " - "reduction is occurring\n"); - /* Indicate that a reduction is occurring */ - NIT_ITFLAGS(iter) |= NPY_ITFLAG_REDUCE; - (*op_itflags) |= NPY_OP_ITFLAG_REDUCE; } } } |