diff options
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 9 | ||||
-rw-r--r-- | numpy/core/src/multiarray/lowlevel_strided_loops.h | 87 |
2 files changed, 90 insertions, 6 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index 96e19fe28..68dbb4293 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -15,6 +15,7 @@ #include "common.h" #include "ctors.h" +#include "lowlevel_strided_loops.h" #define PyAO PyArrayObject #define _check_axis PyArray_CheckAxis @@ -1701,12 +1702,8 @@ PyArray_CountNonzero(PyArrayObject *self) npy_intp *strideptr, *innersizeptr; /* If it's a trivial one-dimensional loop, don't use an iterator */ - if (ndim <= 1 || PyArray_CHKFLAGS(self, NPY_CONTIGUOUS) || - PyArray_CHKFLAGS(self, NPY_FORTRAN)) { - data = PyArray_BYTES(self); - stride = (ndim == 0) ? 0 : (PyArray_CHKFLAGS(self, NPY_FORTRAN) ? - PyArray_STRIDE(self, 0) : PyArray_STRIDE(self, ndim-1)); - count = PyArray_SIZE(self); + if (PyArray_TRIVIALLY_ITERABLE(self)) { + PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride); while (count--) { if (nonzero(data, self)) { diff --git a/numpy/core/src/multiarray/lowlevel_strided_loops.h b/numpy/core/src/multiarray/lowlevel_strided_loops.h index 5ae6d2868..ebe841e50 100644 --- a/numpy/core/src/multiarray/lowlevel_strided_loops.h +++ b/numpy/core/src/multiarray/lowlevel_strided_loops.h @@ -207,4 +207,91 @@ PyArray_TransferStridedToNDim(npy_intp ndim, PyArray_StridedTransferFn *stransfer, void *transferdata); +/* + * TRIVIAL ITERATION + * + * In some cases when the iteration order isn't important, iteration over + * arrays is trivial. This is the case when: + * * The array has 0 or 1 dimensions. + * * The array is C or Fortran contiguous. + * Use of an iterator can be skipped when this occurs. These macros assist + * in detecting and taking advantage of the situation. Note that it may + * be worthwhile to further check if the stride is a contiguous stride + * and take advantage of that. + * + * Here is example code for a single array: + * + * if (PyArray_TRIVIALLY_ITERABLE(self) { + * char *data; + * npy_intp count, stride; + * + * PyArray_PREPARE_TRIVIAL_ITERATION(self, count, data, stride); + * + * while (count--) { + * // Use the data pointer + * + * data += stride; + * } + * } + * else { + * // Create iterator, etc... + * } + * + * Here is example code for a pair of arrays: + * + * if (PyArray_TRIVIALLY_ITERABLE_PAIR(a1, a2) { + * char *data1, *data2; + * npy_intp count, stride1, stride2; + * + * PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(a1, a2, count, + * data1, data2, stride1, stride2); + * + * while (count--) { + * // Use the data1 and data2 pointers + * + * data1 += stride1; + * data2 += stride2; + * } + * } + * else { + * // Create iterator, etc... + * } + */ +#define PyArray_TRIVIALLY_ITERABLE(arr) ( \ + PyArray_NDIM(arr) <= 1 || \ + PyArray_CHKFLAGS(arr, NPY_CONTIGUOUS) || \ + PyArray_CHKFLAGS(arr, NPY_FORTRAN) \ + ) +#define PyArray_PREPARE_TRIVIAL_ITERATION(arr, count, data, stride) \ + count = PyArray_SIZE(arr), \ + data = PyArray_BYTES(arr), \ + stride = ((PyArray_NDIM(arr) == 0) ? 0 : \ + (PyArray_CHKFLAGS(arr, NPY_FORTRAN) ? \ + PyArray_STRIDE(arr, 0) : \ + PyArray_STRIDE(arr, ndim-1))) + +#define PyArray_TRIVIALLY_ITERABLE_PAIR(arr1, arr2) (\ + PyArray_TRIVIALLY_ITERABLE(arr1) && \ + PyArray_NDIM(arr1) == PyArray_NDIM(arr2) && \ + PyArray_CompareLists(PyArray_DIMS(arr1), \ + PyArray_DIMS(arr2), \ + PyArray_NDIM(arr1)) && \ + PyArray_CHKFLAGS(arr1, NPY_FORTRAN) == \ + PyArray_CHKFLAGS(arr2, NPY_FORTRAN) \ + ) +#define PyArray_PREPARE_TRIVIAL_PAIR_ITERATION(arr1, arr2, \ + count, \ + data1, data2, \ + stride1, stride2) \ + count = PyArray_SIZE(arr1), \ + data1 = PyArray_BYTES(arr1), \ + data2 = PyArray_BYTES(arr2), \ + stride1 = ((PyArray_NDIM(arr1) == 0) ? 0 : \ + (PyArray_CHKFLAGS(arr1, NPY_FORTRAN) ? \ + PyArray_STRIDE(arr1, 0) : \ + PyArray_STRIDE(arr1, ndim-1))), \ + stride2 = ((PyArray_NDIM(arr2) == 0) ? 0 : \ + (PyArray_CHKFLAGS(arr2, NPY_FORTRAN) ? \ + PyArray_STRIDE(arr2, 0) : \ + PyArray_STRIDE(arr2, ndim-1))) #endif |