summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/multiarray/item_selection.c9
-rw-r--r--numpy/core/src/multiarray/lowlevel_strided_loops.h87
2 files changed, 90 insertions, 6 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 96e19fe28..68dbb4293 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -15,6 +15,7 @@
#include "common.h"
#include "ctors.h"
+#include "lowlevel_strided_loops.h"
#define PyAO PyArrayObject
#define _check_axis PyArray_CheckAxis
@@ -1701,12 +1702,8 @@ PyArray_CountNonzero(PyArrayObject *self)
npy_intp *strideptr, *innersizeptr;
/* If it's a trivial one-dimensional loop, don't use an iterator */
- if (ndim <= 1 || PyArray_CHKFLAGS(self, NPY_CONTIGUOUS) ||
- PyArray_CHKFLAGS(self, NPY_FORTRAN)) {
- data = PyArray_BYTES(self);
- stride = (ndim == 0) ? 0 : (PyArray_CHKFLAGS(self, NPY_FORTRAN) ?
- PyArray_STRIDE(self, 0) : PyArray_STRIDE(self, ndim-1));
- count = PyArray_SIZE(self);
+ if (PyArray_TRIVIALLY_ITERABLE(self)) {
+ PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride);
while (count--) {
if (nonzero(data, self)) {
diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h
index 5ae6d2868..ebe841e50 100644
--- a/numpy/core/src/multiarray/lowlevel_strided_loops.h
+++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h
@@ -207,4 +207,91 @@ PyArray_TransferStridedToNDim(npy_intp ndim,
PyArray_StridedTransferFn *stransfer,
void *transferdata);
+/*
+ * TRIVIAL ITERATION
+ *
+ * In some cases when the iteration order isn't important, iteration over
+ * arrays is trivial. This is the case when:
+ * * The array has 0 or 1 dimensions.
+ * * The array is C or Fortran contiguous.
+ * Use of an iterator can be skipped when this occurs. These macros assist
+ * in detecting and taking advantage of the situation. Note that it may
+ * be worthwhile to further check if the stride is a contiguous stride
+ * and take advantage of that.
+ *
+ * Here is example code for a single array:
+ *
+ * if (PyArray_TRIVIALLY_ITERABLE(self)) {
+ * char *data;
+ * npy_intp count, stride;
+ *
+ * PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride);
+ *
+ * while (count--) {
+ * // Use the data pointer
+ *
+ * data += stride;
+ * }
+ * }
+ * else {
+ * // Create iterator, etc...
+ * }
+ *
+ * Here is example code for a pair of arrays:
+ *
+ * if (PyArray_TRIVIALLY_ITERABLE_PAIR(a1, a2)) {
+ * char *data1, *data2;
+ * npy_intp count, stride1, stride2;
+ *
+ * PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(a1, a2, count,
+ * data1, data2, stride1, stride2);
+ *
+ * while (count--) {
+ * // Use the data1 and data2 pointers
+ *
+ * data1 += stride1;
+ * data2 += stride2;
+ * }
+ * }
+ * else {
+ * // Create iterator, etc...
+ * }
+ */
+/*
+ * Evaluates to non-zero when `arr` has fewer than two dimensions or has
+ * either the NPY_CONTIGUOUS or the NPY_FORTRAN flag set.
+ */
+#define PyArray_TRIVIALLY_ITERABLE(arr) \
+    (PyArray_NDIM(arr) < 2 || \
+     PyArray_CHKFLAGS(arr, NPY_CONTIGUOUS) || \
+     PyArray_CHKFLAGS(arr, NPY_FORTRAN))
+/*
+ * Assigns the three loop variables for a flat walk over `arr`:
+ *   count  - PyArray_SIZE(arr)
+ *   data   - PyArray_BYTES(arr)
+ *   stride - 0 for a 0-d array; otherwise the stride of axis 0 when
+ *            NPY_FORTRAN is set, else the stride of the last axis.
+ * Intended for use after PyArray_TRIVIALLY_ITERABLE(arr) evaluated true.
+ * Expands to a comma expression, so the caller supplies the semicolon.
+ *
+ * The last axis is computed from PyArray_NDIM(arr) so the expansion does
+ * not rely on a variable named `ndim` being in scope at the call site.
+ */
+#define PyArray_PREPARE_TRIVIAL_ITERATION(arr, count, data, stride) \
+    count = PyArray_SIZE(arr), \
+    data = PyArray_BYTES(arr), \
+    stride = ((PyArray_NDIM(arr) == 0) ? 0 : \
+                (PyArray_CHKFLAGS(arr, NPY_FORTRAN) ? \
+                    PyArray_STRIDE(arr, 0) : \
+                    PyArray_STRIDE(arr, PyArray_NDIM(arr) - 1)))
+
+/*
+ * Evaluates to non-zero when `arr1` and `arr2` can be walked together with
+ * one flat loop: each is trivially iterable on its own, they have the same
+ * number of dimensions and the same shape, and their NPY_FORTRAN flags
+ * agree.
+ *
+ * Both arrays are tested for trivial iterability. Checking only `arr1`
+ * would accept an N-d `arr2` that is neither C nor Fortran contiguous
+ * (both NPY_FORTRAN flags clear), which a single stride cannot traverse.
+ */
+#define PyArray_TRIVIALLY_ITERABLE_PAIR(arr1, arr2) (\
+    PyArray_TRIVIALLY_ITERABLE(arr1) && \
+    PyArray_TRIVIALLY_ITERABLE(arr2) && \
+    PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \
+    PyArray_CompareLists(PyArray_DIMS(arr1), \
+                         PyArray_DIMS(arr2), \
+                         PyArray_NDIM(arr1)) && \
+    PyArray_CHKFLAGS(arr1, NPY_FORTRAN) == \
+                    PyArray_CHKFLAGS(arr2, NPY_FORTRAN) \
+    )
+/*
+ * Assigns the loop variables for a flat walk over `arr1` and `arr2`
+ * together:
+ *   count            - PyArray_SIZE(arr1)
+ *   data1, data2     - PyArray_BYTES() of each array
+ *   stride1, stride2 - per array: 0 for a 0-d array; otherwise the stride
+ *                      of axis 0 when NPY_FORTRAN is set, else the stride
+ *                      of the last axis.
+ * Intended for use after PyArray_TRIVIALLY_ITERABLE_PAIR(arr1, arr2)
+ * evaluated true. Expands to a comma expression, so the caller supplies
+ * the semicolon.
+ *
+ * Each last axis is computed from that array's own PyArray_NDIM() so the
+ * expansion does not rely on a variable named `ndim` at the call site.
+ */
+#define PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(arr1, arr2, \
+                                        count, \
+                                        data1, data2, \
+                                        stride1, stride2) \
+    count = PyArray_SIZE(arr1), \
+    data1 = PyArray_BYTES(arr1), \
+    data2 = PyArray_BYTES(arr2), \
+    stride1 = ((PyArray_NDIM(arr1) == 0) ? 0 : \
+                (PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \
+                    PyArray_STRIDE(arr1, 0) : \
+                    PyArray_STRIDE(arr1, PyArray_NDIM(arr1) - 1))), \
+    stride2 = ((PyArray_NDIM(arr2) == 0) ? 0 : \
+                (PyArray_CHKFLAGS(arr2, NPY_FORTRAN) ? \
+                    PyArray_STRIDE(arr2, 0) : \
+                    PyArray_STRIDE(arr2, PyArray_NDIM(arr2) - 1)))
#endif