-rw-r--r--  numpy/core/src/common/simd/avx512/conversion.h |  8
-rw-r--r--  numpy/core/src/common/simd/avx512/operators.h  | 94
2 files changed, 55 insertions, 47 deletions
diff --git a/numpy/core/src/common/simd/avx512/conversion.h b/numpy/core/src/common/simd/avx512/conversion.h
index 7f4ae484d..0bd44179b 100644
--- a/numpy/core/src/common/simd/avx512/conversion.h
+++ b/numpy/core/src/common/simd/avx512/conversion.h
@@ -119,7 +119,13 @@ NPY_FINLINE npy_uint64 npyv_tobits_b16(npyv_b16 a)
 NPY_FINLINE npy_uint64 npyv_tobits_b32(npyv_b32 a)
 { return (npy_uint16)a; }
 NPY_FINLINE npy_uint64 npyv_tobits_b64(npyv_b64 a)
-{ return (npy_uint8)a; }
+{
+#ifdef NPY_HAVE_AVX512DQ_MASK
+    return _cvtmask8_u32(a);
+#else
+    return (npy_uint8)a;
+#endif
+}
 
 // round to nearest integer (assuming even)
 #define npyv_round_s32_f32 _mm512_cvtps_epi32
diff --git a/numpy/core/src/common/simd/avx512/operators.h b/numpy/core/src/common/simd/avx512/operators.h
index 5f1771770..d53932fa8 100644
--- a/numpy/core/src/common/simd/avx512/operators.h
+++ b/numpy/core/src/common/simd/avx512/operators.h
@@ -90,21 +90,6 @@
 NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_and_f32, _mm512_and_si512)
 NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_and_f64, _mm512_and_si512)
 #endif
 
-#ifdef NPY_HAVE_AVX512BW_MASK
-    #define npyv_and_b8  _kand_mask64
-    #define npyv_and_b16 _kand_mask32
-#elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
-    { return a & b; }
-    NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
-    { return a & b; }
-#else
-    #define npyv_and_b8  _mm512_and_si512
-    #define npyv_and_b16 _mm512_and_si512
-#endif
-#define npyv_and_b32 _mm512_kand
-#define npyv_and_b64 _mm512_kand
-
 // OR
 #define npyv_or_u8  _mm512_or_si512
 #define npyv_or_s8  _mm512_or_si512
@@ -121,20 +106,6 @@
 NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_or_f32, _mm512_or_si512)
 NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_or_f64, _mm512_or_si512)
 #endif
 
-#ifdef NPY_HAVE_AVX512BW_MASK
-    #define npyv_or_b8  _kor_mask64
-    #define npyv_or_b16 _kor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
-    { return a | b; }
-    NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
-    { return a | b; }
-#else
-    #define npyv_or_b8  _mm512_or_si512
-    #define npyv_or_b16 _mm512_or_si512
-#endif
-#define npyv_or_b32 _mm512_kor
-#define npyv_or_b64 _mm512_kor
 
 // XOR
 #define npyv_xor_u8  _mm512_xor_si512
 #define npyv_xor_s8  _mm512_xor_si512
@@ -152,21 +123,6 @@
 NPYV_IMPL_AVX512_FROM_SI512_PS_2ARG(npyv_xor_f32, _mm512_xor_si512)
 NPYV_IMPL_AVX512_FROM_SI512_PD_2ARG(npyv_xor_f64, _mm512_xor_si512)
 #endif
 
-#ifdef NPY_HAVE_AVX512BW_MASK
-    #define npyv_xor_b8  _kxor_mask64
-    #define npyv_xor_b16 _kxor_mask32
-#elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
-    { return a ^ b; }
-    NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
-    { return a ^ b; }
-#else
-    #define npyv_xor_b8  _mm512_xor_si512
-    #define npyv_xor_b16 _mm512_xor_si512
-#endif
-#define npyv_xor_b32 _mm512_kxor
-#define npyv_xor_b64 _mm512_kxor
-
 // NOT
 #define npyv_not_u8(A) _mm512_xor_si512(A, _mm512_set1_epi32(-1))
 #define npyv_not_s8 npyv_not_u8
@@ -183,21 +139,67 @@
 #define npyv_not_f32(A) _mm512_castsi512_ps(npyv_not_u32(_mm512_castps_si512(A)))
 #define npyv_not_f64(A) _mm512_castsi512_pd(npyv_not_u64(_mm512_castpd_si512(A)))
 #endif
+
+/***************************
+ * Logical (boolean)
+ ***************************/
 #ifdef NPY_HAVE_AVX512BW_MASK
+    #define npyv_and_b8  _kand_mask64
+    #define npyv_and_b16 _kand_mask32
+    #define npyv_or_b8   _kor_mask64
+    #define npyv_or_b16  _kor_mask32
+    #define npyv_xor_b8  _kxor_mask64
+    #define npyv_xor_b16 _kxor_mask32
     #define npyv_not_b8  _knot_mask64
     #define npyv_not_b16 _knot_mask32
 #elif defined(NPY_HAVE_AVX512BW)
-    NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
+    NPY_FINLINE npyv_b8 npyv_and_b8(npyv_b8 a, npyv_b8 b)
+    { return a & b; }
+    NPY_FINLINE npyv_b16 npyv_and_b16(npyv_b16 a, npyv_b16 b)
+    { return a & b; }
+    NPY_FINLINE npyv_b8 npyv_or_b8(npyv_b8 a, npyv_b8 b)
+    { return a | b; }
+    NPY_FINLINE npyv_b16 npyv_or_b16(npyv_b16 a, npyv_b16 b)
+    { return a | b; }
+    NPY_FINLINE npyv_b8 npyv_xor_b8(npyv_b8 a, npyv_b8 b)
+    { return a ^ b; }
+    NPY_FINLINE npyv_b16 npyv_xor_b16(npyv_b16 a, npyv_b16 b)
+    { return a ^ b; }
+    NPY_FINLINE npyv_b8 npyv_not_b8(npyv_b8 a)
     { return ~a; }
     NPY_FINLINE npyv_b16 npyv_not_b16(npyv_b16 a)
     { return ~a; }
 #else
+    #define npyv_and_b8  _mm512_and_si512
+    #define npyv_and_b16 _mm512_and_si512
+    #define npyv_or_b8   _mm512_or_si512
+    #define npyv_or_b16  _mm512_or_si512
+    #define npyv_xor_b8  _mm512_xor_si512
+    #define npyv_xor_b16 _mm512_xor_si512
     #define npyv_not_b8  npyv_not_u8
     #define npyv_not_b16 npyv_not_u8
 #endif
+
+#define npyv_and_b32 _mm512_kand
+#define npyv_or_b32  _mm512_kor
+#define npyv_xor_b32 _mm512_kxor
 #define npyv_not_b32 _mm512_knot
-#define npyv_not_b64 _mm512_knot
+#ifdef NPY_HAVE_AVX512DQ_MASK
+    #define npyv_and_b64 _kand_mask8
+    #define npyv_or_b64  _kor_mask8
+    #define npyv_xor_b64 _kxor_mask8
+    #define npyv_not_b64 _knot_mask8
+#else
+    NPY_FINLINE npyv_b64 npyv_and_b64(npyv_b64 a, npyv_b64 b)
+    { return (npyv_b64)_mm512_kand((npyv_b32)a, (npyv_b32)b); }
+    NPY_FINLINE npyv_b64 npyv_or_b64(npyv_b64 a, npyv_b64 b)
+    { return (npyv_b64)_mm512_kor((npyv_b32)a, (npyv_b32)b); }
+    NPY_FINLINE npyv_b64 npyv_xor_b64(npyv_b64 a, npyv_b64 b)
+    { return (npyv_b64)_mm512_kxor((npyv_b32)a, (npyv_b32)b); }
+    NPY_FINLINE npyv_b64 npyv_not_b64(npyv_b64 a)
+    { return (npyv_b64)_mm512_knot((npyv_b32)a); }
+#endif
 
 /***************************
  * Comparison
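Note on the dispatch pattern: the patch gates the 8-bit boolean mask (one bit per 64-bit lane) on NPY_HAVE_AVX512DQ_MASK, using the native AVX512DQ mask intrinsics (_kand_mask8, _knot_mask8, _cvtmask8_u32) when they are available and otherwise widening to the AVX512F 16-bit k-register ops. The standalone sketch below (not part of the patch; and_b64, tobits_b64, and b64_mask are illustrative names, and the compiler macro __AVX512DQ__ stands in for NumPy's config macro) shows the same fallback scheme.

/*
 * Minimal standalone sketch of the AVX512DQ-mask fallback used above.
 * Build on an AVX-512 machine, e.g.:
 *   gcc -O2 -mavx512f            sketch.c   (takes the fallback branch)
 *   gcc -O2 -mavx512f -mavx512dq sketch.c   (takes the DQ branch)
 */
#include <immintrin.h>
#include <stdint.h>
#include <stdio.h>

typedef __mmask8 b64_mask;  /* one bit per 64-bit lane, 8 lanes per zmm */

static inline b64_mask and_b64(b64_mask a, b64_mask b)
{
#ifdef __AVX512DQ__
    return _kand_mask8(a, b);                  /* native 8-bit mask AND */
#else
    /* pre-DQ: widen to the AVX512F 16-bit mask op, as operators.h does */
    return (b64_mask)_mm512_kand((__mmask16)a, (__mmask16)b);
#endif
}

static inline uint64_t tobits_b64(b64_mask a)
{
#ifdef __AVX512DQ__
    return _cvtmask8_u32(a);                   /* mask -> scalar via kmov */
#else
    return (uint8_t)a;                         /* plain integer cast fallback */
#endif
}

int main(void)
{
    __m512i x = _mm512_set_epi64(7, 6, 5, 4, 3, 2, 1, 0);
    /* lanes > 2 and lanes < 6 -> expect bits 3..5 set, i.e. 0x38 */
    b64_mask gt2 = _mm512_cmpgt_epi64_mask(x, _mm512_set1_epi64(2));
    b64_mask lt6 = _mm512_cmplt_epi64_mask(x, _mm512_set1_epi64(6));
    printf("bits: 0x%02llx\n", (unsigned long long)tobits_b64(and_b64(gt2, lt6)));
    return 0;
}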