diff options
Diffstat (limited to 'numpy')
-rw-r--r-- | numpy/core/src/multiarray/item_selection.c | 39 |
1 file changed, 26 insertions, 13 deletions
diff --git a/numpy/core/src/multiarray/item_selection.c b/numpy/core/src/multiarray/item_selection.c index c71c4325a..0bab1eff9 100644 --- a/numpy/core/src/multiarray/item_selection.c +++ b/numpy/core/src/multiarray/item_selection.c @@ -2133,14 +2133,14 @@ count_nonzero_bytes_384(const npy_uint64 * w) /* Count the zero bytes between `*d` and `end`, updating `*d` to point to where to keep counting from. */ static NPY_INLINE NPY_GCC_OPT_3 npyv_u8 -count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end) +count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end, npy_uint8 max_count) { const npyv_u8 vone = npyv_setall_u8(1); const npyv_u8 vzero = npyv_zero_u8(); npy_intp lane_max = 0; npyv_u8 vsum8 = npyv_zero_u8(); - while (*d < end && lane_max <= 0xFE) { + while (*d < end && lane_max <= max_count - 1) { npyv_u8 vt = npyv_cvt_u8_b8(npyv_cmpeq_u8(npyv_load_u8(*d), vzero)); vt = npyv_and_u8(vt, vone); vsum8 = npyv_add_u8(vsum8, vt); @@ -2151,41 +2151,54 @@ count_zero_bytes_u8(const npy_uint8 **d, const npy_uint8 *end) } static NPY_INLINE NPY_GCC_OPT_3 npyv_u16 -count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end) +count_zero_bytes_u16(const npy_uint8 **d, const npy_uint8 *end, npy_uint16 max_count) { npyv_u16 vsum16 = npyv_zero_u16(); npy_intp lane_max = 0; - while (*d < end && lane_max <= 0xFF00-0xFF) { - npyv_u8 vsum8 = count_zero_bytes_u8(d, end); + while (*d < end && lane_max <= max_count - 2*NPY_MAX_UINT8) { + npyv_u8 vsum8 = count_zero_bytes_u8(d, end, NPY_MAX_UINT8); npyv_u16x2 part = npyv_expand_u16_u8(vsum8); vsum16 = npyv_add_u16(vsum16, npyv_add_u16(part.val[0], part.val[1])); - lane_max += 0xFF*2; + lane_max += 2*NPY_MAX_UINT8; } return vsum16; } static NPY_INLINE NPY_GCC_OPT_3 npyv_u32 -count_zero_bytes_u32(const npy_uint8 **d, const npy_uint8 *end) +count_zero_bytes_u32(const npy_uint8 **d, const npy_uint8 *end, npy_uint32 max_count) { npyv_u32 vsum32 = npyv_zero_u32(); npy_intp lane_max = 0; - // The last accumulation needs to 
adjustment (2**32-1)/nlanes to avoid overflow. - while (*d < end && lane_max <= (0xFFFF0000-0xFFFF)/npyv_nlanes_u32) { - npyv_u16 vsum16 = count_zero_bytes_u16(d, end); + while (*d < end && lane_max <= max_count - 2*NPY_MAX_UINT16) { + npyv_u16 vsum16 = count_zero_bytes_u16(d, end, NPY_MAX_UINT16); npyv_u32x2 part = npyv_expand_u32_u16(vsum16); vsum32 = npyv_add_u32(vsum32, npyv_add_u32(part.val[0], part.val[1])); - lane_max += 0xFFFF*2; + lane_max += 2*NPY_MAX_UINT16; } return vsum32; } - +/* + * Counts the number of non-zero values in a raw array. + * One pass of the loop proceeds as follows (taking SSE2 with 128-bit vectors as an example): + * |------------16 lanes---------| + * [vsum8] 255 255 255 ... 255 255 255 255 count_zero_bytes_u8: counting 255*16 elements + * !! + * |------------8 lanes---------| + * [vsum16] 65535 65535 65535 ... 65535 count_zero_bytes_u16: counting 65535*8 elements + * !! + * |------------4 lanes---------| + * [vsum32] 1073741823 ... 1073741823 count_zero_bytes_u32(overflow control): counting 2**32-1 elements + * !! + * 2**32-1 count_zero_bytes +*/ static NPY_INLINE NPY_GCC_OPT_3 npy_intp count_nonzero_bytes(const npy_uint8 *d, npy_uintp unrollx) { npy_intp zero_count = 0; const npy_uint8 *end = d + unrollx; while (d < end) { - npyv_u32 vsum32 = count_zero_bytes_u32(&d, end); + // The npyv_nlanes_u32 factor ensures that the sum of all lanes still fits in a uint32 + npyv_u32 vsum32 = count_zero_bytes_u32(&d, end, NPY_MAX_UINT32 / npyv_nlanes_u32); zero_count += npyv_sum_u32(vsum32); } return unrollx - zero_count; |