diff options
-rw-r--r-- | numpy/core/src/common/simd/avx512/utils.h | 22 |
1 files changed, 21 insertions, 1 deletions
diff --git a/numpy/core/src/common/simd/avx512/utils.h b/numpy/core/src/common/simd/avx512/utils.h index 8066283c6..c3079283f 100644 --- a/numpy/core/src/common/simd/avx512/utils.h +++ b/numpy/core/src/common/simd/avx512/utils.h @@ -26,7 +26,7 @@ #define npyv512_combine_ps256(A, B) _mm512_insertf32x8(_mm512_castps256_ps512(A), B, 1) #else #define npyv512_combine_ps256(A, B) \ - _mm512_castsi512_ps(npyv512_combine_si256(_mm512_castps_si512(A), _mm512_castps_si512(B))) + _mm512_castsi512_ps(npyv512_combine_si256(_mm256_castps_si256(A), _mm256_castps_si256(B))) #endif #define NPYV_IMPL_AVX512_FROM_AVX2_1ARG(FN_NAME, INTRIN) \ @@ -39,6 +39,26 @@ return npyv512_combine_si256(l_a, h_a); \ } +#define NPYV_IMPL_AVX512_FROM_AVX2_PS_1ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512 FN_NAME(__m512 a) \ + { \ + __m256 l_a = npyv512_lower_ps256(a); \ + __m256 h_a = npyv512_higher_ps256(a); \ + l_a = INTRIN(l_a); \ + h_a = INTRIN(h_a); \ + return npyv512_combine_ps256(l_a, h_a); \ + } + +#define NPYV_IMPL_AVX512_FROM_AVX2_PD_1ARG(FN_NAME, INTRIN) \ + NPY_FINLINE __m512d FN_NAME(__m512d a) \ + { \ + __m256d l_a = npyv512_lower_pd256(a); \ + __m256d h_a = npyv512_higher_pd256(a); \ + l_a = INTRIN(l_a); \ + h_a = INTRIN(h_a); \ + return npyv512_combine_pd256(l_a, h_a); \ + } + #define NPYV_IMPL_AVX512_FROM_AVX2_2ARG(FN_NAME, INTRIN) \ NPY_FINLINE __m512i FN_NAME(__m512i a, __m512i b) \ { \ |