author     Sayed Adel <seiko@imavr.com>   2021-01-05 06:46:02 +0000
committer  Sayed Adel <seiko@imavr.com>   2021-01-05 07:15:55 +0000
commit     998ca7c883c3c3bcedfae4f21f1a7a9f32b19301 (patch)
tree       6ce53f270b521b3da7b02b71907ab8635a24c369
parent     172311808c0857465cd09e9f7d295bfcd0179b1e (diff)
download   numpy-998ca7c883c3c3bcedfae4f21f1a7a9f32b19301.tar.gz
BUG, MAINT: improve avx512 mask logical operations
It also fixes a conversion warning between `__mmask16` and `__mmask8` on MSVC 2019 when the logical intrinsics of AVX512DQ are available.
-rw-r--r--  numpy/core/src/common/simd/avx512/conversion.h    8
-rw-r--r--  numpy/core/src/common/simd/avx512/operators.h    94
2 files changed, 55 insertions, 47 deletions
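
For context, a minimal standalone sketch of the warning the message refers to (not part of the patch; it uses the GCC/Clang feature macro __AVX512DQ__ rather than NumPy's NPY_HAVE_AVX512DQ_MASK guard). A 64-bit-lane comparison yields an 8-bit __mmask8; pushing it through the 16-bit _mm512_kand() and plain integer casts triggers an implicit __mmask16 -> __mmask8 narrowing that MSVC 2019 flags, while the AVX512DQ mask intrinsics keep the mask at its native width:

#include <immintrin.h>

unsigned int and_lanes_to_bits(__m512i a, __m512i b, __m512i c)
{
    __mmask8 ab = _mm512_cmpeq_epi64_mask(a, b);  /* one bit per 64-bit lane */
    __mmask8 ac = _mm512_cmpeq_epi64_mask(a, c);
#ifdef __AVX512DQ__
    /* 8-bit mask stays 8-bit: no width conversion, no warning */
    __mmask8 both = _kand_mask8(ab, ac);
    return _cvtmask8_u32(both);
#else
    /* _mm512_kand() works on __mmask16; without the explicit cast, assigning
     * its result back to a __mmask8 is the narrowing MSVC 2019 warns about. */
    __mmask8 both = (__mmask8)_mm512_kand(ab, ac);
    return (unsigned char)both;
#endif
}
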
diff --git a/numpy/core/src/common/simd/avx512/conversion.h b/numpy/core/src/common/simd/avx512/conversion.h
index 7f4ae484d..0bd44179b 100644
--- a/numpy/core/src/common/simd/avx512/conversion.h
+++ b/numpy/core/src/common/simd/avx512/conversion.h
@@ -119,7 +119,13 @@ NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a)
NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a)
{ return (npy_uint16)a; }
NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a)
-{ return (npy_uint8)a; }
+{
+#ifdef NPY_HAVE_AVX512DQ_MASK
+ return _cvtmask8_u32(a);
+#else
+ return (npy_uint8)a;
+#endif
+}
// round to nearest integer (assuming even)
#define npyv_round_s32_f32 _mm512_cvtps_epi32
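
A hedged usage sketch of the converted boolean (not from this commit; it assumes the usual universal-intrinsics helpers npyv_load_u64 and npyv_cmpeq_u64 from NumPy's "simd/simd.h"): npyv_tobits_b64 packs one bit per 64-bit lane into a scalar, so popcounting the result counts the matching lanes.

#include "simd/simd.h"   /* NumPy's universal intrinsics umbrella header */

static int count_equal_u64(const npy_uint64 *a, const npy_uint64 *b)
{
    npyv_u64 va = npyv_load_u64(a);
    npyv_u64 vb = npyv_load_u64(b);
    npyv_b64 eq = npyv_cmpeq_u64(va, vb);
    npy_uint64 bits = npyv_tobits_b64(eq);   /* 8 bits per vector on AVX512 */
    int count = 0;
    for (; bits != 0; bits >>= 1) {
        count += (int)(bits & 1);
    }
    return count;
}
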
diff --git a/numpy/core/src/common/simd/avx512/operators.h b/numpy/core/src/common/simd/avx512/operators.h
index 5f1771770..d53932fa8 100644
--- a/numpy/core/src/common/simd/avx512/operators.h
+++ b/numpy/core/src/common/simd/avx512/operators.h
@@ -90,21 +90,6 @@
NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_and_f32, _mm512_and_si512)
NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_and_f64, _mm512_and_si512)
#endif
-#ifdef NPY_HAVE_AVX512BW_MASK
- #define npyv_and_b8 _kand_mask64
- #define npyv_and_b16 _kand_mask32
-#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
- { return a & b; }
- NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
- { return a & b; }
-#else
- #define npyv_and_b8 _mm512_and_si512
- #define npyv_and_b16 _mm512_and_si512
-#endif
-#define npyv_and_b32 _mm512_kand
-#define npyv_and_b64 _mm512_kand
-
// OR
#define npyv_or_u8 _mm512_or_si512
#define npyv_or_s8 _mm512_or_si512
@@ -121,20 +106,6 @@
NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_or_f32, _mm512_or_si512)
NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_or_f64, _mm512_or_si512)
#endif
-#ifdef NPY_HAVE_AVX512BW_MASK
- #define npyv_or_b8 _kor_mask64
- #define npyv_or_b16 _kor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
- { return a | b; }
- NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
- { return a | b; }
-#else
- #define npyv_or_b8 _mm512_or_si512
- #define npyv_or_b16 _mm512_or_si512
-#endif
-#define npyv_or_b32 _mm512_kor
-#define npyv_or_b64 _mm512_kor
// XOR
#define npyv_xor_u8 _mm512_xor_si512
@@ -152,21 +123,6 @@
NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_xor_f32, _mm512_xor_si512)
NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_xor_f64, _mm512_xor_si512)
#endif
-#ifdef NPY_HAVE_AVX512BW_MASK
- #define npyv_xor_b8 _kxor_mask64
- #define npyv_xor_b16 _kxor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
- { return a ^ b; }
- NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
- { return a ^ b; }
-#else
- #define npyv_xor_b8 _mm512_xor_si512
- #define npyv_xor_b16 _mm512_xor_si512
-#endif
-#define npyv_xor_b32 _mm512_kxor
-#define npyv_xor_b64 _mm512_kxor
-
// NOT
#define npyv_not_u8(A) _mm512_xor_si512(A, _mm512_set1_epi32(-1))
#define npyv_not_s8 npyv_not_u8
@@ -183,21 +139,67 @@
#define npyv_not_f32(A) _mm512_castsi512_ps(npyv_not_u32(_mm512_castps_si512(A)))
#define npyv_not_f64(A) _mm512_castsi512_pd(npyv_not_u64(_mm512_castpd_si512(A)))
#endif
+
+/***************************
+ * Logical (boolean)
+ ***************************/
#ifdef NPY_HAVE_AVX512BW_MASK
+ #define npyv_and_b8 _kand_mask64
+ #define npyv_and_b16 _kand_mask32
+ #define npyv_or_b8 _kor_mask64
+ #define npyv_or_b16 _kor_mask32
+ #define npyv_xor_b8 _kxor_mask64
+ #define npyv_xor_b16 _kxor_mask32
#define npyv_not_b8 _knot_mask64
#define npyv_not_b16 _knot_mask32
#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
+ NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
+ { return a & b; }
+ NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
+ { return a & b; }
+ NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
+ { return a | b; }
+ NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
+ { return a | b; }
+ NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
+ { return a ^ b; }
+ NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
+ { return a ^ b; }
+ NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
{ return ~a; }
NPY_FINLINE npyv_b16 npyv_not_b16(npyv_b16 a)
{ return ~a; }
#else
+ #define npyv_and_b8 _mm512_and_si512
+ #define npyv_and_b16 _mm512_and_si512
+ #define npyv_or_b8 _mm512_or_si512
+ #define npyv_or_b16 _mm512_or_si512
+ #define npyv_xor_b8 _mm512_xor_si512
+ #define npyv_xor_b16 _mm512_xor_si512
#define npyv_not_b8 npyv_not_u8
#define npyv_not_b16 npyv_not_u8
#endif
+
+#define npyv_and_b32 _mm512_kand
+#define npyv_or_b32 _mm512_kor
+#define npyv_xor_b32 _mm512_kxor
#define npyv_not_b32 _mm512_knot
-#define npyv_not_b64 _mm512_knot
+#ifdef NPY_HAVE_AVX512DQ_MASK
+ #define npyv_and_b64 _kand_mask8
+ #define npyv_or_b64 _kor_mask8
+ #define npyv_xor_b64 _kxor_mask8
+ #define npyv_not_b64 _knot_mask8
+#else
+ NPY_FINLINE npyv_b64 npyv_and_b64(npyv_b64 a, npyv_b64 b)
+ { return (npyv_b64)_mm512_kand((npyv_b32)a, (npyv_b32)b); }
+ NPY_FINLINE npyv_b64 npyv_or_b64(npyv_b64 a, npyv_b64 b)
+ { return (npyv_b64)_mm512_kor((npyv_b32)a, (npyv_b32)b); }
+ NPY_FINLINE npyv_b64 npyv_xor_b64(npyv_b64 a, npyv_b64 b)
+ { return (npyv_b64)_mm512_kxor((npyv_b32)a, (npyv_b32)b); }
+ NPY_FINLINE npyv_b64 npyv_not_b64(npyv_b64 a)
+ { return (npyv_b64)_mm512_knot((npyv_b32)a); }
+#endif
/***************************
* Comparison
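
A short hedged check of the non-DQ b64 fallback above (not part of the patch): only the low 8 mask bits are meaningful for 64-bit lanes, so widening to __mmask16 for _mm512_kand and truncating back is lossless, and the explicit casts keep MSVC 2019 quiet about the narrowing.

#include <assert.h>
#include <immintrin.h>

int main(void)
{
    __mmask8 a = 0xA5, b = 0x3C;
    /* widen, combine with the 16-bit intrinsic, truncate back */
    __mmask8 via16 = (__mmask8)_mm512_kand((__mmask16)a, (__mmask16)b);
    assert(via16 == (__mmask8)(a & b));  /* matches a plain 8-bit AND */
    return 0;
}
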