summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorCharles Harris <charlesr.harris@gmail.com>2022-03-13 17:20:18 -0600
committerGitHub <noreply@github.com>2022-03-13 17:20:18 -0600
commit780799bb223c4467fb7a8ab1dc8252396813b960 (patch)
tree01fe83e9f617c6840d39dd850a1f4f71569b1c13 /numpy
parent55f31b2bd84898eaf33e70bd5251fbefeac44ff3 (diff)
parent8077036a57b2c55bd509c0a0b03c10b39cf70f33 (diff)
downloadnumpy-780799bb223c4467fb7a8ab1dc8252396813b960.tar.gz
Merge pull request #21130 from zephyr111/faster-where
ENH: improve the speed of numpy.where using a branchless code
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/item_selection.c30
1 files changed, 25 insertions, 5 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index 086b674c8..9fad153a3 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -2641,13 +2641,33 @@ PyArray_Nonzero(PyArrayObject *self)
*multi_index++ = j++;
}
}
+ /*
+ * Fallback to a branchless strategy to avoid branch misprediction
+ * stalls that are very expensive on most modern processors.
+ */
else {
- npy_intp j;
- for (j = 0; j < count; ++j) {
- if (*data != 0) {
- *multi_index++ = j;
- }
+ npy_intp *multi_index_end = multi_index + nonzero_count;
+ npy_intp j = 0;
+
+ /* Manually unroll for GCC and maybe other compilers */
+ while (multi_index + 4 < multi_index_end) {
+ *multi_index = j;
+ multi_index += data[0] != 0;
+ *multi_index = j + 1;
+ multi_index += data[stride] != 0;
+ *multi_index = j + 2;
+ multi_index += data[stride * 2] != 0;
+ *multi_index = j + 3;
+ multi_index += data[stride * 3] != 0;
+ data += stride * 4;
+ j += 4;
+ }
+
+ while (multi_index < multi_index_end) {
+ *multi_index = j;
+ multi_index += *data != 0;
data += stride;
+ ++j;
}
}
}