diff options
| author | Chunlin <fangchunlin@huawei.com> | 2020-06-01 14:17:57 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-06-01 09:17:57 +0300 |
| commit | bdd4e2e29c4b011cf455a72f85c92010656d2722 (patch) | |
| tree | 79dbdfff31b8e3720d2aeeaba46174e74303d182 /numpy | |
| parent | 1b212bda1c2833467177eb0d62c6de6c224c6a80 (diff) | |
| download | numpy-bdd4e2e29c4b011cf455a72f85c92010656d2722.tar.gz | |
ENH: ARM Neon implementation with intrinsic for np.argmax. (#16375)
* Neon implementation with intrinsic for bool argmax
Diffstat (limited to 'numpy')
| -rw-r--r-- | numpy/core/src/multiarray/arraytypes.c.src | 26 |
1 files changed, 25 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src index 552c56349..2048b5898 100644 --- a/numpy/core/src/multiarray/arraytypes.c.src +++ b/numpy/core/src/multiarray/arraytypes.c.src @@ -27,6 +27,9 @@ #include "arrayobject.h" #include "alloc.h" #include "typeinfo.h" +#if defined(__ARM_NEON__) || defined (__ARM_NEON) +#include <arm_neon.h> +#endif #ifdef NPY_HAVE_SSE2_INTRINSICS #include <emmintrin.h> #endif @@ -3070,7 +3073,15 @@ finish: ** ARGFUNC ** ***************************************************************************** */ - +#if defined(__ARM_NEON__) || defined (__ARM_NEON) + int32_t _mm_movemask_epi8_neon(uint8x16_t input) + { + int8x8_t m0 = vcreate_s8(0x0706050403020100ULL); + uint8x16_t v0 = vshlq_u8(vshrq_n_u8(input, 7), vcombine_s8(m0, m0)); + uint64x2_t v1 = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(v0))); + return (int)vgetq_lane_u64(v1, 0) + ((int)vgetq_lane_u64(v1, 1) << 8); + } +#endif #define _LESS_THAN_OR_EQUAL(a,b) ((a) <= (b)) static int @@ -3091,6 +3102,19 @@ BOOL_argmax(npy_bool *ip, npy_intp n, npy_intp *max_ind, break; } } +#else + #if defined(__ARM_NEON__) || defined (__ARM_NEON) + uint8x16_t zero = vdupq_n_u8(0); + for(; i < n - (n % 32); i+=32) { + uint8x16_t d1 = vld1q_u8((char *)&ip[i]); + uint8x16_t d2 = vld1q_u8((char *)&ip[i + 16]); + d1 = vceqq_u8(d1, zero); + d2 = vceqq_u8(d2, zero); + if(_mm_movemask_epi8_neon(vminq_u8(d1, d2)) != 0xFFFF) { + break; + } + } + #endif #endif for (; i < n; i++) { if (ip[i]) { |
