-rw-r--r--  numpy/core/src/common/simd/avx512/conversion.h |  8
-rw-r--r--  numpy/core/src/common/simd/avx512/operators.h  | 94
-rw-r--r--  numpy/distutils/ccompiler_opt.py               |  2
-rw-r--r--  numpy/distutils/checks/extra_avx512dq_mask.c   | 16
4 files changed, 72 insertions(+), 48 deletions(-)
diff --git a/numpy/core/src/common/simd/avx512/conversion.h b/numpy/core/src/common/simd/avx512/conversion.h
index 7f4ae484d..0bd44179b 100644
--- a/numpy/core/src/common/simd/avx512/conversion.h
+++ b/numpy/core/src/common/simd/avx512/conversion.h
@@ -119,7 +119,13 @@ NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a)
 NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a)
 { return (npy_uint16)a; }
 NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a)
-{ return (npy_uint8)a; }
+{
+#ifdef NPY_HAVE_AVX512DQ_MASK
+    return _cvtmask8_u32(a);
+#else
+    return (npy_uint8)a;
+#endif
+}

 // round to nearest integer (assuming even)
 #define npyv_round_s32_f32 _mm512_cvtps_epi32
diff --git a/numpy/core/src/common/simd/avx512/operators.h b/numpy/core/src/common/simd/avx512/operators.h
index 5f1771770..d53932fa8 100644
--- a/numpy/core/src/common/simd/avx512/operators.h
+++ b/numpy/core/src/common/simd/avx512/operators.h
@@ -90,21 +90,6 @@
     NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_and_f32, _mm512_and_si512)
     NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_and_f64, _mm512_and_si512)
 #endif
-#ifdef NPY_HAVE_AVX512BW_MASK
-    #define npyv_and_b8 _kand_mask64
-    #define npyv_and_b16 _kand_mask32
-#elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
-    { return a & b; }
-    NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
-    { return a & b; }
-#else
-    #define npyv_and_b8 _mm512_and_si512
-    #define npyv_and_b16 _mm512_and_si512
-#endif
-#define npyv_and_b32 _mm512_kand
-#define npyv_and_b64 _mm512_kand
-
 // OR
 #define npyv_or_u8 _mm512_or_si512
 #define npyv_or_s8 _mm512_or_si512
@@ -121,20 +106,6 @@
     NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_or_f32, _mm512_or_si512)
     NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_or_f64, _mm512_or_si512)
 #endif
-#ifdef NPY_HAVE_AVX512BW_MASK
-    #define npyv_or_b8 _kor_mask64
-    #define npyv_or_b16 _kor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
-    { return a | b; }
-    NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
-    { return a | b; }
-#else
-    #define npyv_or_b8 _mm512_or_si512
-    #define npyv_or_b16 _mm512_or_si512
-#endif
-#define npyv_or_b32 _mm512_kor
-#define npyv_or_b64 _mm512_kor

 // XOR
 #define npyv_xor_u8 _mm512_xor_si512
@@ -152,21 +123,6 @@
     NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_xor_f32, _mm512_xor_si512)
     NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_xor_f64, _mm512_xor_si512)
 #endif
-#ifdef NPY_HAVE_AVX512BW_MASK
-    #define npyv_xor_b8 _kxor_mask64
-    #define npyv_xor_b16 _kxor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
-    { return a ^ b; }
-    NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
-    { return a ^ b; }
-#else
-    #define npyv_xor_b8 _mm512_xor_si512
-    #define npyv_xor_b16 _mm512_xor_si512
-#endif
-#define npyv_xor_b32 _mm512_kxor
-#define npyv_xor_b64 _mm512_kxor
-
 // NOT
 #define npyv_not_u8(A) _mm512_xor_si512(A, _mm512_set1_epi32(-1))
 #define npyv_not_s8 npyv_not_u8
@@ -183,21 +139,67 @@
     #define npyv_not_f32(A) _mm512_castsi512_ps(npyv_not_u32(_mm512_castps_si512(A)))
     #define npyv_not_f64(A) _mm512_castsi512_pd(npyv_not_u64(_mm512_castpd_si512(A)))
 #endif
+
+/***************************
+ * Logical (boolean)
+ ***************************/
 #ifdef NPY_HAVE_AVX512BW_MASK
+    #define npyv_and_b8 _kand_mask64
+    #define npyv_and_b16 _kand_mask32
+    #define npyv_or_b8 _kor_mask64
+    #define npyv_or_b16 _kor_mask32
+    #define npyv_xor_b8 _kxor_mask64
+    #define npyv_xor_b16 _kxor_mask32
     #define npyv_not_b8 _knot_mask64
     #define npyv_not_b16 _knot_mask32
 #elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
+    NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
+    { return a & b; }
+    NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
+    { return a & b; }
+    NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
+    { return a | b; }
+    NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
+    { return a | b; }
+    NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
+    { return a ^ b; }
+    NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
+    { return a ^ b; }
+    NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
     { return ~a; }
     NPY_FINLINE npyv_b16 npyv_not_b16(npyv_b16 a)
     { return ~a; }
 #else
+    #define npyv_and_b8 _mm512_and_si512
+    #define npyv_and_b16 _mm512_and_si512
+    #define npyv_or_b8 _mm512_or_si512
+    #define npyv_or_b16 _mm512_or_si512
+    #define npyv_xor_b8 _mm512_xor_si512
+    #define npyv_xor_b16 _mm512_xor_si512
     #define npyv_not_b8 npyv_not_u8
     #define npyv_not_b16 npyv_not_u8
 #endif
+
+#define npyv_and_b32 _mm512_kand
+#define npyv_or_b32 _mm512_kor
+#define npyv_xor_b32 _mm512_kxor
 #define npyv_not_b32 _mm512_knot
-#define npyv_not_b64 _mm512_knot
+#ifdef NPY_HAVE_AVX512DQ_MASK
+    #define npyv_and_b64 _kand_mask8
+    #define npyv_or_b64 _kor_mask8
+    #define npyv_xor_b64 _kxor_mask8
+    #define npyv_not_b64 _knot_mask8
+#else
+    NPY_FINLINE npyv_b64 npyv_and_b64(npyv_b64 a, npyv_b64 b)
+    { return (npyv_b64)_mm512_kand((npyv_b32)a, (npyv_b32)b); }
+    NPY_FINLINE npyv_b64 npyv_or_b64(npyv_b64 a, npyv_b64 b)
+    { return (npyv_b64)_mm512_kor((npyv_b32)a, (npyv_b32)b); }
+    NPY_FINLINE npyv_b64 npyv_xor_b64(npyv_b64 a, npyv_b64 b)
+    { return (npyv_b64)_mm512_kxor((npyv_b32)a, (npyv_b32)b); }
+    NPY_FINLINE npyv_b64 npyv_not_b64(npyv_b64 a)
+    { return (npyv_b64)_mm512_knot((npyv_b32)a); }
+#endif

 /***************************
  * Comparison
diff --git a/numpy/distutils/ccompiler_opt.py b/numpy/distutils/ccompiler_opt.py
index 5fa17b2ee..e6c720399 100644
--- a/numpy/distutils/ccompiler_opt.py
+++ b/numpy/distutils/ccompiler_opt.py
@@ -259,7 +259,7 @@ class _Config:
         AVX512_SKX = dict(
             interest=42, implies="AVX512CD", group="AVX512VL AVX512BW AVX512DQ",
             detect="AVX512_SKX", implies_detect=False,
-            extra_checks="AVX512BW_MASK"
+            extra_checks="AVX512BW_MASK AVX512DQ_MASK"
         ),
         AVX512_CLX = dict(
             interest=43, implies="AVX512_SKX", group="AVX512VNNI",
diff --git a/numpy/distutils/checks/extra_avx512dq_mask.c b/numpy/distutils/checks/extra_avx512dq_mask.c
new file mode 100644
index 000000000..f0dc88bdd
--- /dev/null
+++ b/numpy/distutils/checks/extra_avx512dq_mask.c
@@ -0,0 +1,16 @@
+#include <immintrin.h>
+/**
+ * Test DQ mask operations due to:
+ *  - MSVC has supported it since vs2019 see,
+ *    https://developercommunity.visualstudio.com/content/problem/518298/missing-avx512bw-mask-intrinsics.html
+ *  - Clang >= v8.0
+ *  - GCC >= v7.1
+ */
+int main(void)
+{
+    __mmask8 m8 = _mm512_cmpeq_epi64_mask(_mm512_set1_epi64(1), _mm512_set1_epi64(1));
+    m8 = _kor_mask8(m8, m8);
+    m8 = _kxor_mask8(m8, m8);
+    m8 = _cvtu32_mask8(_cvtmask8_u32(m8));
+    return (int)_cvtmask8_u32(m8);
+}
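
The `npyv_*_b64` fallback in operators.h works because an AVX-512 `__mmask8` can be widened to `__mmask16` and run through the baseline AVX512F `_mm512_kand`/`_mm512_kor`/`_mm512_kxor` mask intrinsics, while the AVX512DQ_MASK path uses the dedicated 8-bit `_k*_mask8` forms instead. The standalone sketch below is not part of the patch; it assumes a toolchain and CPU with AVX512F and AVX512DQ (e.g. `gcc -O2 -mavx512f -mavx512dq`) and shows the two paths producing the same bitmask:

    #include <immintrin.h>
    #include <stdio.h>

    int main(void)
    {
        /* two 8-lane 64-bit comparisons, each yielding an 8-bit __mmask8 */
        __m512i ones = _mm512_set1_epi64(1);
        __m512i mix  = _mm512_set_epi64(1, 0, 1, 0, 1, 0, 1, 0); /* lanes e7..e0 */
        __mmask8 a = _mm512_cmpeq_epi64_mask(ones, ones); /* 0xff: every lane equal */
        __mmask8 b = _mm512_cmpeq_epi64_mask(ones, mix);  /* 0xaa: odd lanes equal  */

        /* AVX512DQ path, as used when NPY_HAVE_AVX512DQ_MASK is defined */
        unsigned dq = _cvtmask8_u32(_kand_mask8(a, b));

        /* generic fallback: widen to __mmask16, apply the AVX512F k-register op,
         * then narrow again -- the same trick as the npyv_*_b64 #else branch */
        unsigned gen = (unsigned char)(__mmask8)_mm512_kand((__mmask16)a, (__mmask16)b);

        printf("dq=0x%02x gen=0x%02x\n", dq, gen); /* both print 0xaa */
        return dq == gen ? 0 : 1;
    }

The extra_avx512dq_mask.c probe added above is what lets the build decide at configure time which of the two branches the `#ifdef NPY_HAVE_AVX512DQ_MASK` guards select.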