-rw-r--r--  numpy/core/src/common/simd/avx512/arithmetic.h | 224
1 file changed, 222 insertions(+), 2 deletions(-)
diff --git a/numpy/core/src/common/simd/avx512/arithmetic.h b/numpy/core/src/common/simd/avx512/arithmetic.h
index 450da7ea5..29e1af7e8 100644
--- a/numpy/core/src/common/simd/avx512/arithmetic.h
+++ b/numpy/core/src/common/simd/avx512/arithmetic.h
@@ -107,6 +107,226 @@ NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b)
// TODO: after implementing Packs intrinsics
/***************************
+ * Integer Division
+ ***************************/
+// See simd/intdiv.h for more clarification
+// divide each unsigned 8-bit element by divisor
+NPY_FINLINE npyv_u8 npyv_divc_u8(npyv_u8 a, const npyv_u8x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]);
+#ifdef NPY_HAVE_AVX512BW
+ const __m512i shf1b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
+ const __m512i shf2b = _mm512_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
+ // high part of unsigned multiplication
+ __m512i mulhi_odd = _mm512_mulhi_epu16(a, divisor.val[0]);
+ __m512i mulhi_even = _mm512_mulhi_epu16(_mm512_slli_epi16(a, 8), divisor.val[0]);
+ mulhi_even = _mm512_srli_epi16(mulhi_even, 8);
+ __m512i mulhi = _mm512_mask_mov_epi8(mulhi_even, 0xAAAAAAAAAAAAAAAA, mulhi_odd);
+ // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
+ __m512i q = _mm512_sub_epi8(a, mulhi);
+ q = _mm512_and_si512(_mm512_srl_epi16(q, shf1), shf1b);
+ q = _mm512_add_epi8(mulhi, q);
+ q = _mm512_and_si512(_mm512_srl_epi16(q, shf2), shf2b);
+ return q;
+#else
+ const __m256i bmask = _mm256_set1_epi32(0xFF00FF00);
+ const __m256i shf1b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf1));
+ const __m256i shf2b = _mm256_set1_epi8(0xFFU >> _mm_cvtsi128_si32(shf2));
+ const __m512i shf2bw = npyv512_combine_si256(shf2b, shf2b);
+ const __m256i mulc = npyv512_lower_si256(divisor.val[0]);
+ //// lower 256-bit
+ __m256i lo_a = npyv512_lower_si256(a);
+ // high part of unsigned multiplication
+ __m256i mulhi_odd = _mm256_mulhi_epu16(lo_a, mulc);
+ __m256i mulhi_even = _mm256_mulhi_epu16(_mm256_slli_epi16(lo_a, 8), mulc);
+ mulhi_even = _mm256_srli_epi16(mulhi_even, 8);
+ __m256i mulhi = _mm256_blendv_epi8(mulhi_even, mulhi_odd, bmask);
+ // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
+ __m256i lo_q = _mm256_sub_epi8(lo_a, mulhi);
+ lo_q = _mm256_and_si256(_mm256_srl_epi16(lo_q, shf1), shf1b);
+ lo_q = _mm256_add_epi8(mulhi, lo_q);
+ lo_q = _mm256_srl_epi16(lo_q, shf2); // stray bits from the adjacent byte are cleared below
+
+ //// higher 256-bit
+ __m256i hi_a = npyv512_higher_si256(a);
+ // high part of unsigned multiplication
+ mulhi_odd = _mm256_mulhi_epu16(hi_a, mulc);
+ mulhi_even = _mm256_mulhi_epu16(_mm256_slli_epi16(hi_a, 8), mulc);
+ mulhi_even = _mm256_srli_epi16(mulhi_even, 8);
+ mulhi = _mm256_blendv_epi8(mulhi_even, mulhi_odd, bmask);
+ // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
+ __m256i hi_q = _mm256_sub_epi8(hi_a, mulhi);
+ hi_q = _mm256_and_si256(_mm256_srl_epi16(hi_q, shf1), shf1b);
+ hi_q = _mm256_add_epi8(mulhi, hi_q);
+ hi_q = _mm256_srl_epi16(hi_q, shf2); // stray bits from the adjacent byte are cleared below
+ return _mm512_and_si512(npyv512_combine_si256(lo_q, hi_q), shf2bw); // clear the bits shifted in from the adjacent byte
+#endif
+}
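+// A minimal scalar sketch of the formula above (illustrative only, the helper name is
+// hypothetical), assuming the divisor struct carries a precomputed multiplier m and
+// shift counts sh1/sh2 as prepared by the helpers in simd/intdiv.h:
+//   NPY_FINLINE npy_uint8 divc_u8_scalar(npy_uint8 a, npy_uint8 m, int sh1, int sh2)
+//   {
+//       npy_uint8 mulhi = (npy_uint8)(((npy_uint16)a * m) >> 8);   // high half of a*m
+//       return (npy_uint8)((mulhi + ((a - mulhi) >> sh1)) >> sh2); // floor(a/d)
+//   }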
+// forward declaration of the 16-bit routine, used below by npyv_divc_s8
+NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor);
+// divide each signed 8-bit element by divisor (round towards zero)
+NPY_FINLINE npyv_s8 npyv_divc_s8(npyv_s8 a, const npyv_s8x3 divisor)
+{
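+ // even bytes are sign-extended in place via (x << 8) >> 8, odd bytes via x >> 8
+ // (arithmetic); both are divided as 16-bit lanes by npyv_divc_s16 and the two
+ // quotients are re-interleaved byte-wise below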
+ __m512i divc_even = npyv_divc_s16(npyv_shri_s16(npyv_shli_s16(a, 8), 8), divisor);
+ __m512i divc_odd = npyv_divc_s16(npyv_shri_s16(a, 8), divisor);
+ divc_odd = npyv_shli_s16(divc_odd, 8);
+#ifdef NPY_HAVE_AVX512BW
+ return _mm512_mask_mov_epi8(divc_even, 0xAAAAAAAAAAAAAAAA, divc_odd);
+#else
+ const __m512i bmask = _mm512_set1_epi32(0x00FF00FF);
+ return npyv_select_u8(bmask, divc_even, divc_odd);
+#endif
+}
+// divide each unsigned 16-bit element by divisor
+NPY_FINLINE npyv_u16 npyv_divc_u16(npyv_u16 a, const npyv_u16x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]);
+ // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
+ #define NPYV__DIVC_U16(RLEN, A, MULC, R) \
+ mulhi = _mm##RLEN##_mulhi_epu16(A, MULC); \
+ R = _mm##RLEN##_sub_epi16(A, mulhi); \
+ R = _mm##RLEN##_srl_epi16(R, shf1); \
+ R = _mm##RLEN##_add_epi16(mulhi, R); \
+ R = _mm##RLEN##_srl_epi16(R, shf2);
+
+#ifdef NPY_HAVE_AVX512BW
+ __m512i mulhi, q;
+ NPYV__DIVC_U16(512, a, divisor.val[0], q)
+ return q;
+#else
+ const __m256i m = npyv512_lower_si256(divisor.val[0]);
+ __m256i lo_a = npyv512_lower_si256(a);
+ __m256i hi_a = npyv512_higher_si256(a);
+
+ __m256i mulhi, lo_q, hi_q;
+ NPYV__DIVC_U16(256, lo_a, m, lo_q)
+ NPYV__DIVC_U16(256, hi_a, m, hi_q)
+ return npyv512_combine_si256(lo_q, hi_q);
+#endif
+ #undef NPYV__DIVC_U16
+}
+// divide each signed 16-bit element by divisor (round towards zero)
+NPY_FINLINE npyv_s16 npyv_divc_s16(npyv_s16 a, const npyv_s16x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ // q = ((a + mulhi) >> sh1) - XSIGN(a)
+ // trunc(a/d) = (q ^ dsign) - dsign
+ #define NPYV__DIVC_S16(RLEN, A, MULC, DSIGN, R) \
+ mulhi = _mm##RLEN##_mulhi_epi16(A, MULC); \
+ R = _mm##RLEN##_sra_epi16(_mm##RLEN##_add_epi16(A, mulhi), shf1); \
+ R = _mm##RLEN##_sub_epi16(R, _mm##RLEN##_srai_epi16(A, 15)); \
+ R = _mm##RLEN##_sub_epi16(_mm##RLEN##_xor_si##RLEN(R, DSIGN), DSIGN);
+
+#ifdef NPY_HAVE_AVX512BW
+ __m512i mulhi, q;
+ NPYV__DIVC_S16(512, a, divisor.val[0], divisor.val[2], q)
+ return q;
+#else
+ const __m256i m = npyv512_lower_si256(divisor.val[0]);
+ const __m256i dsign = npyv512_lower_si256(divisor.val[2]);
+ __m256i lo_a = npyv512_lower_si256(a);
+ __m256i hi_a = npyv512_higher_si256(a);
+
+ __m256i mulhi, lo_q, hi_q;
+ NPYV__DIVC_S16(256, lo_a, m, dsign, lo_q)
+ NPYV__DIVC_S16(256, hi_a, m, dsign, hi_q)
+ return npyv512_combine_si256(lo_q, hi_q);
+#endif
+ #undef NPYV__DIVC_S16
+}
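+// A scalar sketch of the signed path above (illustrative only, the helper name is
+// hypothetical); m, sh1 and dsign stand for the precomputed multiplier, shift count
+// and divisor-sign mask from simd/intdiv.h:
+//   NPY_FINLINE npy_int16 divc_s16_scalar(npy_int16 a, npy_int16 m, int sh1, npy_int16 dsign)
+//   {
+//       npy_int16 mulhi = (npy_int16)(((npy_int32)a * m) >> 16);     // high half of signed a*m
+//       npy_int16 q = (npy_int16)(((a + mulhi) >> sh1) - (a >> 15)); // a >> 15 is 0 or -1
+//       return (npy_int16)((q ^ dsign) - dsign);                     // negate if the divisor is negative
+//   }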
+// divide each unsigned 32-bit element by divisor
+NPY_FINLINE npyv_u32 npyv_divc_u32(npyv_u32 a, const npyv_u32x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]);
+ // high part of unsigned multiplication
+ __m512i mulhi_even = _mm512_srli_epi64(_mm512_mul_epu32(a, divisor.val[0]), 32);
+ __m512i mulhi_odd = _mm512_mul_epu32(_mm512_srli_epi64(a, 32), divisor.val[0]);
+ __m512i mulhi = _mm512_mask_mov_epi32(mulhi_even, 0xAAAA, mulhi_odd);
+ // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
+ __m512i q = _mm512_sub_epi32(a, mulhi);
+ q = _mm512_srl_epi32(q, shf1);
+ q = _mm512_add_epi32(mulhi, q);
+ q = _mm512_srl_epi32(q, shf2);
+ return q;
+}
+// divide each signed 32-bit element by divisor (round towards zero)
+NPY_FINLINE npyv_s32 npyv_divc_s32(npyv_s32 a, const npyv_s32x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ // high part of signed multiplication
+ __m512i mulhi_even = _mm512_srli_epi64(_mm512_mul_epi32(a, divisor.val[0]), 32);
+ __m512i mulhi_odd = _mm512_mul_epi32(_mm512_srli_epi64(a, 32), divisor.val[0]);
+ __m512i mulhi = _mm512_mask_mov_epi32(mulhi_even, 0xAAAA, mulhi_odd);
+ // q = ((a + mulhi) >> sh1) - XSIGN(a)
+ // trunc(a/d) = (q ^ dsign) - dsign
+ __m512i q = _mm512_sra_epi32(_mm512_add_epi32(a, mulhi), shf1);
+ q = _mm512_sub_epi32(q, _mm512_srai_epi32(a, 31));
+ q = _mm512_sub_epi32(_mm512_xor_si512(q, divisor.val[2]), divisor.val[2]);
+ return q;
+}
+// returns the high 64 bits of unsigned 64-bit multiplication
+// xref https://stackoverflow.com/a/28827013
+NPY_FINLINE npyv_u64 npyv__mullhi_u64(npyv_u64 a, npyv_u64 b)
+{
+ __m512i lomask = npyv_setall_s64(0xffffffff);
+ __m512i a_hi = _mm512_srli_epi64(a, 32); // a0h, 0, a1h, 0, ...
+ __m512i b_hi = _mm512_srli_epi64(b, 32); // b0h, 0, b1h, 0, ...
+ // compute partial products
+ __m512i w0 = _mm512_mul_epu32(a, b); // a0l*b0l, a1l*b1l
+ __m512i w1 = _mm512_mul_epu32(a, b_hi); // a0l*b0h, a1l*b1h
+ __m512i w2 = _mm512_mul_epu32(a_hi, b); // a0h*b0l, a1h*b1l
+ __m512i w3 = _mm512_mul_epu32(a_hi, b_hi); // a0h*b0h, a1h*b1h
+ // sum partial products
+ __m512i w0h = _mm512_srli_epi64(w0, 32);
+ __m512i s1 = _mm512_add_epi64(w1, w0h);
+ __m512i s1l = _mm512_and_si512(s1, lomask);
+ __m512i s1h = _mm512_srli_epi64(s1, 32);
+
+ __m512i s2 = _mm512_add_epi64(w2, s1l);
+ __m512i s2h = _mm512_srli_epi64(s2, 32);
+
+ __m512i hi = _mm512_add_epi64(w3, s1h);
+ hi = _mm512_add_epi64(hi, s2h);
+ return hi;
+}
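+// Scalar equivalent of the partial-product sum above (reference sketch, the helper
+// name is hypothetical):
+//   NPY_FINLINE npy_uint64 mullhi_u64_scalar(npy_uint64 a, npy_uint64 b)
+//   {
+//       npy_uint64 al = a & 0xffffffff, ah = a >> 32;
+//       npy_uint64 bl = b & 0xffffffff, bh = b >> 32;
+//       npy_uint64 w0 = al*bl, w1 = al*bh, w2 = ah*bl, w3 = ah*bh;
+//       npy_uint64 s1 = w1 + (w0 >> 32);        // cannot overflow
+//       npy_uint64 s2 = w2 + (s1 & 0xffffffff); // cannot overflow
+//       return w3 + (s1 >> 32) + (s2 >> 32);    // high 64 bits of a*b
+//   }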
+// divide each unsigned 64-bit element by a divisor
+NPY_FINLINE npyv_u64 npyv_divc_u64(npyv_u64 a, const npyv_u64x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ const __m128i shf2 = _mm512_castsi512_si128(divisor.val[2]);
+ // high part of unsigned multiplication
+ __m512i mulhi = npyv__mullhi_u64(a, divisor.val[0]);
+ // floor(a/d) = (mulhi + ((a-mulhi) >> sh1)) >> sh2
+ __m512i q = _mm512_sub_epi64(a, mulhi);
+ q = _mm512_srl_epi64(q, shf1);
+ q = _mm512_add_epi64(mulhi, q);
+ q = _mm512_srl_epi64(q, shf2);
+ return q;
+}
+// divide each signed 64-bit element by a divisor (round towards zero)
+NPY_FINLINE npyv_s64 npyv_divc_s64(npyv_s64 a, const npyv_s64x3 divisor)
+{
+ const __m128i shf1 = _mm512_castsi512_si128(divisor.val[1]);
+ // high part of unsigned multiplication
+ __m512i mulhi = npyv__mullhi_u64(a, divisor.val[0]);
+ // convert unsigned to signed high multiplication
+ // mulhi - ((a < 0) ? m : 0) - ((m < 0) ? a : 0);
+ __m512i asign = _mm512_srai_epi64(a, 63);
+ __m512i msign = _mm512_srai_epi64(divisor.val[0], 63);
+ __m512i m_asign = _mm512_and_si512(divisor.val[0], asign);
+ __m512i a_msign = _mm512_and_si512(a, msign);
+ mulhi = _mm512_sub_epi64(mulhi, m_asign);
+ mulhi = _mm512_sub_epi64(mulhi, a_msign);
+ // q = ((a + mulhi) >> sh1) - XSIGN(a)
+ // trunc(a/d) = (q ^ dsign) - dsign
+ __m512i q = _mm512_sra_epi64(_mm512_add_epi64(a, mulhi), shf1);
+ q = _mm512_sub_epi64(q, asign);
+ q = _mm512_sub_epi64(_mm512_xor_si512(q, divisor.val[2]), divisor.val[2]);
+ return q;
+}
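+// A scalar sketch of the signed 64-bit path (illustrative only, the helper names are
+// hypothetical); m, sh1 and dsign are the precomputed values from simd/intdiv.h and
+// mullhi_u64_scalar is the reference sketched above:
+//   NPY_FINLINE npy_int64 divc_s64_scalar(npy_int64 a, npy_int64 m, int sh1, npy_int64 dsign)
+//   {
+//       npy_int64 mulhi = (npy_int64)mullhi_u64_scalar((npy_uint64)a, (npy_uint64)m);
+//       mulhi -= (a < 0 ? m : 0) + (m < 0 ? a : 0);     // unsigned -> signed high product
+//       npy_int64 q = ((a + mulhi) >> sh1) - (a >> 63); // a >> 63 is 0 or -1
+//       return (q ^ dsign) - dsign;                     // negate if the divisor is negative
+//   }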
+/***************************
* Division
***************************/
// TODO: emulate integer division
@@ -136,11 +356,11 @@ NPY_FINLINE __m512i npyv_mul_u8(__m512i a, __m512i b)
* 2- shuff(cross) /add /shuff(cross) /add /shuff /add /shuff /add /extract
* 3- _mm512_reduce_add_ps/pd
* The first one has been widely used by many projects
- *
+ *
* The second one is used by the Intel Compiler, maybe because the
* latency of hadd increased (by 2-3) starting from Skylake-X, which makes the two
* extra (non-crossing) shuffles cheaper. See https://godbolt.org/z/s3G9Er for more info.
- *
+ *
* The third one is almost the same as the second one, but it only works with the
* Intel compiler, GCC >= 7.1 and Clang >= 4; we still need to support older GCC.
***************************/
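// A sketch of the extract-and-add halving style (one common way to implement the
// first option above; illustrative only, not necessarily the exact sequence used here):
//   NPY_FINLINE double npyv__sum_f64_sketch(__m512d a)
//   {
//       __m256d half    = _mm256_add_pd(_mm512_castpd512_pd256(a),
//                                       _mm512_extractf64x4_pd(a, 1));
//       __m128d quarter = _mm_add_pd(_mm256_castpd256_pd128(half),
//                                    _mm256_extractf128_pd(half, 1));
//       quarter = _mm_add_sd(quarter, _mm_unpackhi_pd(quarter, quarter));
//       return _mm_cvtsd_f64(quarter);
//   }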