summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
author: Chunlin <fangchunlin@huawei.com> 2020-06-01 14:17:57 +0800
committer: GitHub <noreply@github.com> 2020-06-01 09:17:57 +0300
commit: bdd4e2e29c4b011cf455a72f85c92010656d2722 (patch)
tree: 79dbdfff31b8e3720d2aeeaba46174e74303d182 /numpy
parent: 1b212bda1c2833467177eb0d62c6de6c224c6a80 (diff)
download: numpy-bdd4e2e29c4b011cf455a72f85c92010656d2722.tar.gz
ENH: ARM Neon implementation with intrinsic for np.argmax. (#16375)
* Neon implementation with intrinsic for bool argmax
Diffstat (limited to 'numpy')
-rw-r--r-- numpy/core/src/multiarray/arraytypes.c.src 26
1 files changed, 25 insertions, 1 deletions
diff --git a/numpy/core/src/multiarray/arraytypes.c.src b/numpy/core/src/multiarray/arraytypes.c.src
index 552c56349..2048b5898 100644
--- a/numpy/core/src/multiarray/arraytypes.c.src
+++ b/numpy/core/src/multiarray/arraytypes.c.src
@@ -27,6 +27,9 @@
#include "arrayobject.h"
#include "alloc.h"
#include "typeinfo.h"
+#if defined(__ARM_NEON__) || defined (__ARM_NEON)
+#include <arm_neon.h>
+#endif
#ifdef NPY_HAVE_SSE2_INTRINSICS
#include <emmintrin.h>
#endif
@@ -3070,7 +3073,15 @@ finish:
** ARGFUNC **
*****************************************************************************
*/
-
+#if defined(__ARM_NEON__) || defined (__ARM_NEON)
+ /*
+ * Build a 16-bit mask from the most significant bit of each byte of
+ * `input`: bit i of the result is the top bit of byte lane i. The name
+ * mirrors the SSE2 intrinsic _mm_movemask_epi8.
+ *
+ * input: 16 unsigned byte lanes.
+ * returns: value in the range 0..0xFFFF.
+ */
+ int32_t _mm_movemask_epi8_neon(uint8x16_t input)
+ {
+ /* Per-lane left-shift counts 0..7, repeated for both 8-byte halves. */
+ const int8x8_t lane_index = vcreate_s8(0x0706050403020100ULL);
+ const int8x16_t shift_counts = vcombine_s8(lane_index, lane_index);
+ /* Reduce every byte to 0 or 1 (its top bit). */
+ const uint8x16_t top_bits = vshrq_n_u8(input, 7);
+ /* Move each 0/1 to the bit position given by its lane index. */
+ const uint8x16_t weighted = vshlq_u8(top_bits, shift_counts);
+ /* Widening pairwise adds fold each 8-byte half into one 64-bit lane. */
+ const uint64x2_t halves = vpaddlq_u32(vpaddlq_u16(vpaddlq_u8(weighted)));
+ const int low_byte = (int)vgetq_lane_u64(halves, 0);
+ const int high_byte = (int)vgetq_lane_u64(halves, 1);
+ /* Lanes 0-7 form bits 0-7, lanes 8-15 form bits 8-15. */
+ return low_byte + (high_byte << 8);
+ }
+#endif
#define _LESS_THAN_OR_EQUAL(a,b) ((a) <= (b))
static int
@@ -3091,6 +3102,19 @@ BOOL_argmax(npy_bool *ip, npy_intp n, npy_intp *max_ind,
break;
}
}
+#else
+ #if defined(__ARM_NEON__) || defined (__ARM_NEON)
+ uint8x16_t zero = vdupq_n_u8(0);
+ for(; i < n - (n % 32); i+=32) {
+ uint8x16_t d1 = vld1q_u8((char *)&ip[i]);
+ uint8x16_t d2 = vld1q_u8((char *)&ip[i + 16]);
+ d1 = vceqq_u8(d1, zero);
+ d2 = vceqq_u8(d2, zero);
+ if(_mm_movemask_epi8_neon(vminq_u8(d1, d2)) != 0xFFFF) {
+ break;
+ }
+ }
+ #endif
#endif
for (; i < n; i++) {
if (ip[i]) {