summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--numpy/core/src/common/simd/avx512/conversion.h8
-rw-r--r--numpy/core/src/common/simd/avx512/operators.h94
-rw-r--r--numpy/distutils/ccompiler_opt.py2
-rw-r--r--numpy/distutils/checks/extra_avx512dq_mask.c16
4 files changed, 72 insertions, 48 deletions
diff --git a/numpy/core/src/common/simd/avx512/conversion.h b/numpy/core/src/common/simd/avx512/conversion.h
index 7f4ae484d..0bd44179b 100644
--- a/numpy/core/src/common/simd/avx512/conversion.h
+++ b/numpy/core/src/common/simd/avx512/conversion.h
@@ -119,7 +119,13 @@ NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a)
NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a)
{ return (npy_uint16)a; }
NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a)
-{ return (npy_uint8)a; }
+{
+#ifdef NPY_HAVE_AVX512DQ_MASK
+ return _cvtmask8_u32(a);
+#else
+ return (npy_uint8)a;
+#endif
+}
// round to nearest integer (assuming even)
#define npyv_round_s32_f32 _mm512_cvtps_epi32
diff --git a/numpy/core/src/common/simd/avx512/operators.h b/numpy/core/src/common/simd/avx512/operators.h
index 5f1771770..d53932fa8 100644
--- a/numpy/core/src/common/simd/avx512/operators.h
+++ b/numpy/core/src/common/simd/avx512/operators.h
@@ -90,21 +90,6 @@
NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_and_f32, _mm512_and_si512)
NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_and_f64, _mm512_and_si512)
#endif
-#ifdef NPY_HAVE_AVX512BW_MASK
- #define npyv_and_b8 _kand_mask64
- #define npyv_and_b16 _kand_mask32
-#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
- { return a & b; }
- NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
- { return a & b; }
-#else
- #define npyv_and_b8 _mm512_and_si512
- #define npyv_and_b16 _mm512_and_si512
-#endif
-#define npyv_and_b32 _mm512_kand
-#define npyv_and_b64 _mm512_kand
-
// OR
#define npyv_or_u8 _mm512_or_si512
#define npyv_or_s8 _mm512_or_si512
@@ -121,20 +106,6 @@
NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_or_f32, _mm512_or_si512)
NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_or_f64, _mm512_or_si512)
#endif
-#ifdef NPY_HAVE_AVX512BW_MASK
- #define npyv_or_b8 _kor_mask64
- #define npyv_or_b16 _kor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
- { return a | b; }
- NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
- { return a | b; }
-#else
- #define npyv_or_b8 _mm512_or_si512
- #define npyv_or_b16 _mm512_or_si512
-#endif
-#define npyv_or_b32 _mm512_kor
-#define npyv_or_b64 _mm512_kor
// XOR
#define npyv_xor_u8 _mm512_xor_si512
@@ -152,21 +123,6 @@
NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_xor_f32, _mm512_xor_si512)
NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_xor_f64, _mm512_xor_si512)
#endif
-#ifdef NPY_HAVE_AVX512BW_MASK
- #define npyv_xor_b8 _kxor_mask64
- #define npyv_xor_b16 _kxor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
- { return a ^ b; }
- NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
- { return a ^ b; }
-#else
- #define npyv_xor_b8 _mm512_xor_si512
- #define npyv_xor_b16 _mm512_xor_si512
-#endif
-#define npyv_xor_b32 _mm512_kxor
-#define npyv_xor_b64 _mm512_kxor
-
// NOT
#define npyv_not_u8(A) _mm512_xor_si512(A, _mm512_set1_epi32(-1))
#define npyv_not_s8 npyv_not_u8
@@ -183,21 +139,67 @@
#define npyv_not_f32(A) _mm512_castsi512_ps(npyv_not_u32(_mm512_castps_si512(A)))
#define npyv_not_f64(A) _mm512_castsi512_pd(npyv_not_u64(_mm512_castpd_si512(A)))
#endif
+
+/***************************
+ * Logical (boolean)
+ ***************************/
#ifdef NPY_HAVE_AVX512BW_MASK
+ #define npyv_and_b8 _kand_mask64
+ #define npyv_and_b16 _kand_mask32
+ #define npyv_or_b8 _kor_mask64
+ #define npyv_or_b16 _kor_mask32
+ #define npyv_xor_b8 _kxor_mask64
+ #define npyv_xor_b16 _kxor_mask32
#define npyv_not_b8 _knot_mask64
#define npyv_not_b16 _knot_mask32
#elif defined(NPY_HAVE_AVX512BW)
- NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
+ NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
+ { return a & b; }
+ NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
+ { return a & b; }
+ NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
+ { return a | b; }
+ NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
+ { return a | b; }
+ NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
+ { return a ^ b; }
+ NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
+ { return a ^ b; }
+ NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
{ return ~a; }
NPY_FINLINE npyv_b16 npyv_not_b16(npyv_b16 a)
{ return ~a; }
#else
+ #define npyv_and_b8 _mm512_and_si512
+ #define npyv_and_b16 _mm512_and_si512
+ #define npyv_or_b8 _mm512_or_si512
+ #define npyv_or_b16 _mm512_or_si512
+ #define npyv_xor_b8 _mm512_xor_si512
+ #define npyv_xor_b16 _mm512_xor_si512
#define npyv_not_b8 npyv_not_u8
#define npyv_not_b16 npyv_not_u8
#endif
+
+#define npyv_and_b32 _mm512_kand
+#define npyv_or_b32 _mm512_kor
+#define npyv_xor_b32 _mm512_kxor
#define npyv_not_b32 _mm512_knot
-#define npyv_not_b64 _mm512_knot
+#ifdef NPY_HAVE_AVX512DQ_MASK
+ #define npyv_and_b64 _kand_mask8
+ #define npyv_or_b64 _kor_mask8
+ #define npyv_xor_b64 _kxor_mask8
+ #define npyv_not_b64 _knot_mask8
+#else
+ NPY_FINLINE npyv_b64 npyv_and_b64(npyv_b64 a, npyv_b64 b)
+ { return (npyv_b64)_mm512_kand((npyv_b32)a, (npyv_b32)b); }
+ NPY_FINLINE npyv_b64 npyv_or_b64(npyv_b64 a, npyv_b64 b)
+ { return (npyv_b64)_mm512_kor((npyv_b32)a, (npyv_b32)b); }
+ NPY_FINLINE npyv_b64 npyv_xor_b64(npyv_b64 a, npyv_b64 b)
+ { return (npyv_b64)_mm512_kxor((npyv_b32)a, (npyv_b32)b); }
+ NPY_FINLINE npyv_b64 npyv_not_b64(npyv_b64 a)
+ { return (npyv_b64)_mm512_knot((npyv_b32)a); }
+#endif
/***************************
* Comparison
diff --git a/numpy/distutils/ccompiler_opt.py b/numpy/distutils/ccompiler_opt.py
index 5fa17b2ee..e6c720399 100644
--- a/numpy/distutils/ccompiler_opt.py
+++ b/numpy/distutils/ccompiler_opt.py
@@ -259,7 +259,7 @@ class _Config:
AVX512_SKX = dict(
interest=42, implies="AVX512CD", group="AVX512VL AVX512BW AVX512DQ",
detect="AVX512_SKX", implies_detect=False,
- extra_checks="AVX512BW_MASK"
+ extra_checks="AVX512BW_MASK AVX512DQ_MASK"
),
AVX512_CLX = dict(
interest=43, implies="AVX512_SKX", group="AVX512VNNI",
diff --git a/numpy/distutils/checks/extra_avx512dq_mask.c b/numpy/distutils/checks/extra_avx512dq_mask.c
new file mode 100644
index 000000000..f0dc88bdd
--- /dev/null
+++ b/numpy/distutils/checks/extra_avx512dq_mask.c
@@ -0,0 +1,16 @@
+#include <immintrin.h>
+/**
+ * Test DQ mask operations due to:
+ * - MSVC has supported it since vs2019 see,
+ * https://developercommunity.visualstudio.com/content/problem/518298/missing-avx512bw-mask-intrinsics.html
+ * - Clang >= v8.0
+ * - GCC >= v7.1
+ */
+int main(void)
+{
+ __mmask8 m8 = _mm512_cmpeq_epi64_mask(_mm512_set1_epi64(1), _mm512_set1_epi64(1));
+ m8 = _kor_mask8(m8, m8);
+ m8 = _kxor_mask8(m8, m8);
+ m8 = _cvtu32_mask8(_cvtmask8_u32(m8));
+ return (int)_cvtmask8_u32(m8);
+}