summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/multiarray/item_selection.c39
1 files changed, 26 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c
index c71c4325a..0bab1eff9 100644
--- a/numpy/core/src/multiarray/item_selection.c
+++ b/numpy/core/src/multiarray/item_selection.c
@@ -2133,14 +2133,14 @@ count_nonzero_bytes_384(const npy_uint64 * w)
/* Count the zero bytes between `*d` and `end`, updating `*d` to point to where to keep counting from. */
static NPY_INLINE NPY_GCC_OPT_3 npyv_u8
-count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end)
+count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end, npy_uint8 max_count)
{
const npyv_u8 vone = npyv_setall_u8(1);
const npyv_u8 vzero = npyv_zero_u8();
npy_intp lane_max = 0;
npyv_u8 vsum8 = npyv_zero_u8();
- while (*d < end && lane_max <= 0xFE) {
+ while (*d < end && lane_max <= max_count - 1) {
npyv_u8 vt = npyv_cvt_u8_b8(npyv_cmpeq_u8(npyv_load_u8(*d), vzero));
vt = npyv_and_u8(vt, vone);
vsum8 = npyv_add_u8(vsum8, vt);
@@ -2151,41 +2151,54 @@ count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end)
}
static NPY_INLINE NPY_GCC_OPT_3 npyv_u16
-count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end)
+count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end, npy_uint16 max_count)
{
npyv_u16 vsum16 = npyv_zero_u16();
npy_intp lane_max = 0;
- while (*d < end && lane_max <= 0xFF00-0xFF) {
- npyv_u8 vsum8 = count_zero_bytes_u8(d, end);
+ while (*d < end && lane_max <= max_count - 2*NPY_MAX_UINT8) {
+ npyv_u8 vsum8 = count_zero_bytes_u8(d, end, NPY_MAX_UINT8);
npyv_u16x2 part = npyv_expand_u16_u8(vsum8);
vsum16 = npyv_add_u16(vsum16, npyv_add_u16(part.val[0], part.val[1]));
- lane_max += 0xFF*2;
+ lane_max += 2*NPY_MAX_UINT8;
}
return vsum16;
}
static NPY_INLINE NPY_GCC_OPT_3 npyv_u32
-count_zero_bytes_u32(const npy_uint8 **d, const npy_uint8 *end)
+count_zero_bytes_u32(const npy_uint8 **d, const npy_uint8 *end, npy_uint32 max_count)
{
npyv_u32 vsum32 = npyv_zero_u32();
npy_intp lane_max = 0;
- // The last accumulation needs to adjustment (2**32-1)/nlanes to avoid overflow.
- while (*d < end && lane_max <= (0xFFFF0000-0xFFFF)/npyv_nlanes_u32) {
- npyv_u16 vsum16 = count_zero_bytes_u16(d, end);
+ while (*d < end && lane_max <= max_count - 2*NPY_MAX_UINT16) {
+ npyv_u16 vsum16 = count_zero_bytes_u16(d, end, NPY_MAX_UINT16);
npyv_u32x2 part = npyv_expand_u32_u16(vsum16);
vsum32 = npyv_add_u32(vsum32, npyv_add_u32(part.val[0], part.val[1]));
- lane_max += 0xFFFF*2;
+ lane_max += 2*NPY_MAX_UINT16;
}
return vsum32;
}
-
+/*
+ * Counts the number of non-zero values in a raw array.
+ * One pass of the outer loop proceeds as follows (taking SSE2 with 128-bit vectors as an example):
+ * |------------16 lanes---------|
+ * [vsum8] 255 255 255 ... 255 255 255 255 count_zero_bytes_u8: counting 255*16 elements
+ * !!
+ * |------------8 lanes---------|
+ * [vsum16] 65535 65535 65535 ... 65535 count_zero_bytes_u16: counting 65535*8 elements
+ * !!
+ * |------------4 lanes---------|
+ * [vsum32] 1073741823 ... 1073741823 count_zero_bytes_u32 (overflow control, per-lane cap is (2**32-1)/4): counting at most 2**32-1 elements
+ * !!
+ * 2**32-1 count_zero_bytes
+*/
static NPY_INLINE NPY_GCC_OPT_3 npy_intp
count_nonzero_bytes(const npy_uint8 *d, npy_uintp unrollx)
{
npy_intp zero_count = 0;
const npy_uint8 *end = d + unrollx;
while (d < end) {
- npyv_u32 vsum32 = count_zero_bytes_u32(&d, end);
+ // Dividing the per-lane limit by npyv_nlanes_u32 ensures that the sum of all lanes still fits in a uint32
+ npyv_u32 vsum32 = count_zero_bytes_u32(&d, end, NPY_MAX_UINT32 / npyv_nlanes_u32);
zero_count += npyv_sum_u32(vsum32);
}
return unrollx - zero_count;