author | Matti Picus <matti.picus@gmail.com> | 2021-11-16 19:12:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-16 19:12:14 +0200 |
commit | 38558de9ecb8b0ec7982a956106713f921e2d3d5 (patch) | |
tree | c6878d006ee85b2a474269a27437ba0acb26ebd0 | |
parent | f146ec16eed3f464f152ac9be4d51e33602f4e80 (diff) | |
parent | 9b1bd0d60b976e3e130cbb6d1eac84c9c6835adb (diff) | |
download | numpy-38558de9ecb8b0ec7982a956106713f921e2d3d5.tar.gz |
Merge pull request #20367 from HowJMay/simd-trunc
ENH, SIMD: add new universal intrinsics for trunc
-rw-r--r-- | numpy/core/src/_simd/_simd.dispatch.c.src | 4
-rw-r--r-- | numpy/core/src/common/simd/avx2/math.h    | 4
-rw-r--r-- | numpy/core/src/common/simd/avx512/math.h  | 4
-rw-r--r-- | numpy/core/src/common/simd/neon/math.h    | 33
-rw-r--r-- | numpy/core/src/common/simd/sse/math.h     | 28
-rw-r--r-- | numpy/core/src/common/simd/vsx/math.h     | 4
-rw-r--r-- | numpy/core/tests/test_simd.py             | 16
7 files changed, 85 insertions, 8 deletions
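
The new `npyv_trunc_##SFX` intrinsic rounds every lane toward zero (2.7 -> 2.0, -2.7 -> -2.0, -0.5 -> -0.0). On SSE4.1/AVX2/AVX512, ASIMD and VSX it maps to a single native rounding instruction; on plain SSE2 and ARMv7 NEON it uses the hand-written fallbacks in the diff below. As a rough illustration of how a consumer might call it, here is a sketch of a contiguous truncation loop; the kernel name and signature are hypothetical, and only the `npyv_*` names and the `simd/simd.h` header are assumed from NumPy's universal SIMD layer.

```c
/*
 * Illustrative sketch only, not code from this PR: a contiguous float32
 * truncation kernel built on the new universal intrinsic, assuming numpy's
 * internal universal-SIMD header and its load/store helpers.
 */
#include <math.h>
#include <stddef.h>
#include "simd/simd.h"   // numpy's universal intrinsics layer (build-internal)

static void trunc_f32_contig(const float *src, float *dst, size_t len)
{
    size_t i = 0;
#if NPY_SIMD
    const size_t vstep = npyv_nlanes_f32;
    // vector body: round each lane toward zero with the new intrinsic
    for (; i + vstep <= len; i += vstep) {
        npyv_f32 v = npyv_load_f32(src + i);
        npyv_store_f32(dst + i, npyv_trunc_f32(v));
    }
#endif
    // scalar tail (and fallback when SIMD is disabled at build time)
    for (; i < len; ++i) {
        dst[i] = truncf(src[i]);
    }
}
```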
diff --git a/numpy/core/src/_simd/_simd.dispatch.c.src b/numpy/core/src/_simd/_simd.dispatch.c.src
index 5c494ae7a..84de9a059 100644
--- a/numpy/core/src/_simd/_simd.dispatch.c.src
+++ b/numpy/core/src/_simd/_simd.dispatch.c.src
@@ -381,7 +381,7 @@ SIMD_IMPL_INTRIN_1(sumup_@sfx@, @esfx@, v@sfx@)
  ***************************/
 #if @fp_only@
 /**begin repeat1
- * #intrin = sqrt, recip, abs, square, ceil#
+ * #intrin = sqrt, recip, abs, square, ceil, trunc#
  */
 SIMD_IMPL_INTRIN_1(@intrin@_@sfx@, v@sfx@, v@sfx@)
 /**end repeat1**/
@@ -615,7 +615,7 @@ SIMD_INTRIN_DEF(sumup_@sfx@)
  ***************************/
 #if @fp_only@
 /**begin repeat1
- * #intrin = sqrt, recip, abs, square, ceil#
+ * #intrin = sqrt, recip, abs, square, ceil, trunc#
  */
 SIMD_INTRIN_DEF(@intrin@_@sfx@)
 /**end repeat1**/
diff --git a/numpy/core/src/common/simd/avx2/math.h b/numpy/core/src/common/simd/avx2/math.h
index b1f3915a6..ec15e50e1 100644
--- a/numpy/core/src/common/simd/avx2/math.h
+++ b/numpy/core/src/common/simd/avx2/math.h
@@ -109,4 +109,8 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
 #define npyv_ceil_f32 _mm256_ceil_ps
 #define npyv_ceil_f64 _mm256_ceil_pd
 
+// trunc
+#define npyv_trunc_f32(A) _mm256_round_ps(A, _MM_FROUND_TO_ZERO)
+#define npyv_trunc_f64(A) _mm256_round_pd(A, _MM_FROUND_TO_ZERO)
+
 #endif // _NPY_SIMD_AVX2_MATH_H
diff --git a/numpy/core/src/common/simd/avx512/math.h b/numpy/core/src/common/simd/avx512/math.h
index c4f8d3410..f30e50ad0 100644
--- a/numpy/core/src/common/simd/avx512/math.h
+++ b/numpy/core/src/common/simd/avx512/math.h
@@ -116,4 +116,8 @@ NPY_FINLINE npyv_f64 npyv_minp_f64(npyv_f64 a, npyv_f64 b)
 #define npyv_ceil_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_POS_INF)
 #define npyv_ceil_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_POS_INF)
 
+// trunc
+#define npyv_trunc_f32(A) _mm512_roundscale_ps(A, _MM_FROUND_TO_ZERO)
+#define npyv_trunc_f64(A) _mm512_roundscale_pd(A, _MM_FROUND_TO_ZERO)
+
 #endif // _NPY_SIMD_AVX512_MATH_H
diff --git a/numpy/core/src/common/simd/neon/math.h b/numpy/core/src/common/simd/neon/math.h
index 38c3899e4..19e5cd846 100644
--- a/numpy/core/src/common/simd/neon/math.h
+++ b/numpy/core/src/common/simd/neon/math.h
@@ -190,4 +190,37 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
     #define npyv_ceil_f64 vrndpq_f64
 #endif // NPY_SIMD_F64
 
+// trunc
+#ifdef NPY_HAVE_ASIMD
+    #define npyv_trunc_f32 vrndq_f32
+#else
+    NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a)
+    {
+        const npyv_s32 szero = vreinterpretq_s32_f32(vdupq_n_f32(-0.0f));
+        const npyv_s32 max_int = vdupq_n_s32(0x7fffffff);
+        /**
+         * On armv7, vcvtq.f32 handles special cases as follows:
+         *  NaN                return 0
+         * +inf or +outrange   return 0x80000000(-0.0f)
+         * -inf or -outrange   return 0x7fffffff(nan)
+         */
+        npyv_s32 roundi = vcvtq_s32_f32(a);
+        npyv_f32 round = vcvtq_f32_s32(roundi);
+        // respect signed zero, e.g. -0.5 -> -0.0
+        npyv_f32 rzero = vreinterpretq_f32_s32(vorrq_s32(
+            vreinterpretq_s32_f32(round),
+            vandq_s32(vreinterpretq_s32_f32(a), szero)
+        ));
+        // if nan or overflow return a
+        npyv_u32 nnan = npyv_notnan_f32(a);
+        npyv_u32 overflow = vorrq_u32(
+            vceqq_s32(roundi, szero), vceqq_s32(roundi, max_int)
+        );
+        return vbslq_f32(vbicq_u32(nnan, overflow), rzero, a);
+    }
+#endif
+#if NPY_SIMD_F64
+    #define npyv_trunc_f64 vrndq_f64
+#endif // NPY_SIMD_F64
+
 #endif // _NPY_SIMD_NEON_MATH_H
diff --git a/numpy/core/src/common/simd/sse/math.h b/numpy/core/src/common/simd/sse/math.h
index 02eb06a29..5daf7711e 100644
--- a/numpy/core/src/common/simd/sse/math.h
+++ b/numpy/core/src/common/simd/sse/math.h
@@ -174,4 +174,32 @@ NPY_FINLINE npyv_s64 npyv_min_s64(npyv_s64 a, npyv_s64 b)
     }
 #endif
 
+// trunc
+#ifdef NPY_HAVE_SSE41
+    #define npyv_trunc_f32(A) _mm_round_ps(A, _MM_FROUND_TO_ZERO)
+    #define npyv_trunc_f64(A) _mm_round_pd(A, _MM_FROUND_TO_ZERO)
+#else
+    NPY_FINLINE npyv_f32 npyv_trunc_f32(npyv_f32 a)
+    {
+        const npyv_f32 szero = _mm_set1_ps(-0.0f);
+        npyv_s32 roundi = _mm_cvttps_epi32(a);
+        npyv_f32 trunc = _mm_cvtepi32_ps(roundi);
+        // respect signed zero, e.g. -0.5 -> -0.0
+        npyv_f32 rzero = _mm_or_ps(trunc, _mm_and_ps(a, szero));
+        // if overflow return a
+        return npyv_select_f32(_mm_cmpeq_epi32(roundi, _mm_castps_si128(szero)), a, rzero);
+    }
+    NPY_FINLINE npyv_f64 npyv_trunc_f64(npyv_f64 a)
+    {
+        const npyv_f64 szero = _mm_set1_pd(-0.0);
+        const npyv_f64 one = _mm_set1_pd(1.0);
+        const npyv_f64 two_power_52 = _mm_set1_pd(0x10000000000000);
+        npyv_f64 abs_a = npyv_abs_f64(a);
+        // round by add magic number 2^52
+        npyv_f64 abs_round = _mm_sub_pd(_mm_add_pd(abs_a, two_power_52), two_power_52);
+        npyv_f64 subtrahend = _mm_and_pd(_mm_cmpgt_pd(abs_round, abs_a), one);
+        return _mm_or_pd(_mm_sub_pd(abs_round, subtrahend), _mm_and_pd(a, szero));
+    }
+#endif
+
 #endif // _NPY_SIMD_SSE_MATH_H
diff --git a/numpy/core/src/common/simd/vsx/math.h b/numpy/core/src/common/simd/vsx/math.h
index f387dac4d..d138cae8a 100644
--- a/numpy/core/src/common/simd/vsx/math.h
+++ b/numpy/core/src/common/simd/vsx/math.h
@@ -73,4 +73,8 @@ NPY_FINLINE npyv_f64 npyv_square_f64(npyv_f64 a)
 #define npyv_ceil_f32 vec_ceil
 #define npyv_ceil_f64 vec_ceil
 
+// trunc
+#define npyv_trunc_f32 vec_trunc
+#define npyv_trunc_f64 vec_trunc
+
 #endif // _NPY_SIMD_VSX_MATH_H
diff --git a/numpy/core/tests/test_simd.py b/numpy/core/tests/test_simd.py
index 379fef8af..12a67c44d 100644
--- a/numpy/core/tests/test_simd.py
+++ b/numpy/core/tests/test_simd.py
@@ -330,12 +330,15 @@ class _SIMD_FP(_Test_Utility):
         square = self.square(vdata)
         assert square == data_square
 
-    @pytest.mark.parametrize("intrin, func", [("self.ceil", math.ceil)])
+    @pytest.mark.parametrize("intrin, func", [("self.ceil", math.ceil),
+                                              ("self.trunc", math.trunc)])
     def test_rounding(self, intrin, func):
         """
         Test intrinsics:
             npyv_ceil_##SFX
+            npyv_trunc_##SFX
         """
+        intrin_name = intrin
         intrin = eval(intrin)
         pinf, ninf, nan = self._pinfinity(), self._ninfinity(), self._nan()
         # special cases
@@ -352,11 +355,12 @@ class _SIMD_FP(_Test_Utility):
            _round = intrin(vdata)
            assert _round == data_round
         # signed zero
-        for w in (-0.25, -0.30, -0.45):
-            _round = self._to_unsigned(intrin(self.setall(w)))
-            data_round = self._to_unsigned(self.setall(-0.0))
-            assert _round == data_round
-
+        if "ceil" in intrin_name or "trunc" in intrin_name:
+            for w in (-0.25, -0.30, -0.45):
+                _round = self._to_unsigned(intrin(self.setall(w)))
+                data_round = self._to_unsigned(self.setall(-0.0))
+                assert _round == data_round
+
     def test_max(self):
         """
         Test intrinsics:
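
A note on the SSE2 `npyv_trunc_f64` fallback above: lacking `_mm_round_pd`, it snaps |a| to the nearest integer by adding and then subtracting the magic constant 2^52 (exact for |a| < 2^52 under the default round-to-nearest mode), subtracts 1 whenever that nearest-integer result overshot |a|, and ORs the original sign bit back in so that, for example, -0.5 truncates to -0.0. The scalar model below mirrors that sequence; it is an illustration added here, not code from the PR.

```c
/*
 * Scalar model of the magic-number truncation used by the SSE2
 * npyv_trunc_f64 fallback (illustrative sketch only, not part of the PR).
 * Valid for finite |a| < 2^52.
 */
#include <math.h>
#include <stdio.h>

static double trunc_via_two_pow_52(double a)
{
    const double two_pow_52 = 4503599627370496.0;       /* 2^52 */
    double abs_a = fabs(a);
    /* adding then subtracting 2^52 leaves the nearest integer to |a| */
    double abs_round = (abs_a + two_pow_52) - two_pow_52;
    if (abs_round > abs_a) {
        abs_round -= 1.0;   /* rounding went up: step back toward zero */
    }
    return copysign(abs_round, a);  /* re-attach the sign, so -0.5 -> -0.0 */
}

int main(void)
{
    const double samples[] = { 2.7, -2.7, -0.5, 1000000.75 };
    for (int i = 0; i < 4; ++i) {
        printf("trunc(%g) = %g\n", samples[i], trunc_via_two_pow_52(samples[i]));
    }
    return 0;
}
```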