summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/reduction.c46
1 files changed, 16 insertions, 30 deletions
diff --git a/numpy/core/src/umath/reduction.c b/numpy/core/src/umath/reduction.c
index f1423d8b9..86cc20eb1 100644
--- a/numpy/core/src/umath/reduction.c
+++ b/numpy/core/src/umath/reduction.c
@@ -166,7 +166,7 @@ PyArray_CopyInitialReduceValues(
* identity : If Py_None, PyArray_CopyInitialReduceValues is used, otherwise
* this value is used to initialize the result to
* the reduction's unit.
- * loop : The loop which does the reduction.
+ * loop : `reduce_loop` from `ufunc_object.c`. TODO: Refactor
* data : Data which is passed to the inner loop.
* buffersize : Buffer size for the iterator. For the default, pass in 0.
* funcname : The name of the reduction function, for error messages.
@@ -182,18 +182,15 @@ PyArray_CopyInitialReduceValues(
* generalized ufuncs!)
*/
NPY_NO_EXPORT PyArrayObject *
-PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
- PyArrayObject *wheremask,
- PyArray_Descr *operand_dtype,
- PyArray_Descr *result_dtype,
- NPY_CASTING casting,
- npy_bool *axis_flags, int reorderable,
- int keepdims,
- PyObject *identity,
- PyArray_ReduceLoopFunc *loop,
- void *data, npy_intp buffersize, const char *funcname,
- int errormask)
+PyUFunc_ReduceWrapper(
+ PyArrayObject *operand, PyArrayObject *out, PyArrayObject *wheremask,
+ PyArray_Descr *operand_dtype, PyArray_Descr *result_dtype,
+ NPY_CASTING casting,
+ npy_bool *axis_flags, int reorderable, int keepdims,
+ PyObject *identity, PyArray_ReduceLoopFunc *loop,
+ void *data, npy_intp buffersize, const char *funcname, int errormask)
{
+ assert(loop != NULL);
PyArrayObject *result = NULL;
npy_intp skip_first_count = 0;
@@ -201,7 +198,7 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
NpyIter *iter = NULL;
PyArrayObject *op[3];
PyArray_Descr *op_dtypes[3];
- npy_uint32 flags, op_flags[3];
+ npy_uint32 it_flags, op_flags[3];
/* More than one axis means multiple orders are possible */
if (!reorderable && count_axes(PyArray_NDIM(operand), axis_flags) > 1) {
@@ -227,7 +224,7 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
op_dtypes[0] = result_dtype;
op_dtypes[1] = operand_dtype;
- flags = NPY_ITER_BUFFERED |
+ it_flags = NPY_ITER_BUFFERED |
NPY_ITER_EXTERNAL_LOOP |
NPY_ITER_GROWINNER |
NPY_ITER_DONT_NEGATE_STRIDES |
@@ -293,7 +290,7 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
}
}
- iter = NpyIter_AdvancedNew(wheremask == NULL ? 2 : 3, op, flags,
+ iter = NpyIter_AdvancedNew(wheremask == NULL ? 2 : 3, op, it_flags,
NPY_KEEPORDER, casting,
op_flags,
op_dtypes,
@@ -304,11 +301,14 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
result = NpyIter_GetOperandArray(iter)[0];
+ int needs_api = NpyIter_IterationNeedsAPI(iter);
+ /* Start with the floating-point exception flags cleared */
+ npy_clear_floatstatus_barrier((char*)&iter);
+
/*
* Initialize the result to the reduction unit if possible,
* otherwise copy the initial values and get a view to the rest.
*/
-
if (identity != Py_None) {
if (PyArray_FillWithScalar(result, identity) < 0) {
goto fail;
@@ -331,15 +331,11 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
goto fail;
}
- /* Start with the floating-point exception flags cleared */
- npy_clear_floatstatus_barrier((char*)&iter);
-
if (NpyIter_GetIterSize(iter) != 0) {
NpyIter_IterNextFunc *iternext;
char **dataptr;
npy_intp *strideptr;
npy_intp *countptr;
- int needs_api;
iternext = NpyIter_GetIterNext(iter, NULL);
if (iternext == NULL) {
@@ -349,16 +345,6 @@ PyUFunc_ReduceWrapper(PyArrayObject *operand, PyArrayObject *out,
strideptr = NpyIter_GetInnerStrideArray(iter);
countptr = NpyIter_GetInnerLoopSizePtr(iter);
- needs_api = NpyIter_IterationNeedsAPI(iter);
-
- /* Straightforward reduction */
- if (loop == NULL) {
- PyErr_Format(PyExc_RuntimeError,
- "reduction operation %s did not supply an "
- "inner loop function", funcname);
- goto fail;
- }
-
if (loop(iter, dataptr, strideptr, countptr,
iternext, needs_api, skip_first_count, data) < 0) {
goto fail;