summaryrefslogtreecommitdiff
path: root/numpy
diff options
context:
space:
mode:
authorRaghuveer Devulapalli <raghuveer.devulapalli@intel.com>2019-07-16 11:56:01 -0700
committerRaghuveer Devulapalli <raghuveer.devulapalli@intel.com>2019-07-18 21:09:34 -0700
commitf316efb36a029f20bc551799916b989445e12a4d (patch)
tree3763d1ea12b2fdb17653aa765c0fa725826e7f2f /numpy
parent09cffcefbdbdb3787af91d6b0fd04367937fe642 (diff)
downloadnumpy-f316efb36a029f20bc551799916b989445e12a4d.tar.gz
BUG: fixing bug where AVX expf does not output denormals
Diffstat (limited to 'numpy')
-rw-r--r--numpy/core/src/umath/simd.inc.src54
1 files changed, 48 insertions, 6 deletions
diff --git a/numpy/core/src/umath/simd.inc.src b/numpy/core/src/umath/simd.inc.src
index ecf2a7951..d5ba4c7d4 100644
--- a/numpy/core/src/umath/simd.inc.src
+++ b/numpy/core/src/umath/simd.inc.src
@@ -1223,6 +1223,46 @@ avx2_get_mantissa(__m256 x)
_mm256_and_si256(
_mm256_castps_si256(x), mantissa_bits), exp_126_bits));
}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_scalef_ps(__m256 poly, __m256 quadrant)
+{
+ /*
+ * Handle denormals (which occur when quadrant <= -125):
+ * 1) This function computes poly*(2^quad) by adding the exponent of
+ poly to quad
+ * 2) When quad <= -125, the output is a denormal and the above logic
+ breaks down
+ * 3) To handle such cases, we split quadrant: -125 + (quadrant + 125)
+ * 4) poly*(2^-125) is computed the usual way
+ * 5) 2^(quad-125) can be computed by: 2 << abs(quad-125)
+ * 6) The final div operation generates the denormal
+ */
+ __m256 minquadrant = _mm256_set1_ps(-125.0f);
+ __m256 denormal_mask = _mm256_cmp_ps(quadrant, minquadrant, _CMP_LE_OQ);
+ if (_mm256_movemask_ps(denormal_mask) != 0x0000) {
+ __m256 quad_diff = _mm256_sub_ps(quadrant, minquadrant); // use negate
+ quad_diff = _mm256_sub_ps(_mm256_setzero_ps(), quad_diff); // make it +ve
+ quad_diff = _mm256_blendv_ps(_mm256_setzero_ps(), quad_diff, denormal_mask);
+ __m256i two_power_diff = _mm256_sllv_epi32(
+ _mm256_set1_epi32(1), _mm256_cvtps_epi32(quad_diff));
+ quadrant = _mm256_max_ps(quadrant, minquadrant); //keep quadrant >= -126
+ __m256i exponent = _mm256_slli_epi32(_mm256_cvtps_epi32(quadrant), 23);
+ poly = _mm256_castsi256_ps(
+ _mm256_add_epi32(
+ _mm256_castps_si256(poly), exponent));
+ __m256 denorm_poly = _mm256_div_ps(poly, _mm256_cvtepi32_ps(two_power_diff));
+ return _mm256_blendv_ps(poly, denorm_poly, denormal_mask);
+ }
+ else {
+ __m256i exponent = _mm256_slli_epi32(_mm256_cvtps_epi32(quadrant), 23);
+ poly = _mm256_castsi256_ps(
+ _mm256_add_epi32(
+ _mm256_castps_si256(poly), exponent));
+ return poly;
+ }
+}
+
#endif
#if defined HAVE_ATTRIBUTE_TARGET_AVX512F_WITH_INTRINSICS
@@ -1276,6 +1316,12 @@ avx512_get_mantissa(__m512 x)
{
return _mm512_getmant_ps(x, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src);
}
+
+static NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __m512
+avx512_scalef_ps(__m512 poly, __m512 quadrant)
+{
+ return _mm512_scalef_ps(poly, quadrant);
+}
#endif
/**begin repeat
@@ -1345,7 +1391,7 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
const npy_intp stride = steps/sizeof(npy_float);
const npy_int num_lanes = @BYTES@/sizeof(npy_float);
npy_float xmax = 88.72283935546875f;
- npy_float xmin = -87.3365478515625f;
+ npy_float xmin = -103.97208404541015625f;
npy_int indexarr[16];
for (npy_int ii = 0; ii < 16; ii++) {
indexarr[ii] = ii*stride;
@@ -1369,7 +1415,6 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
@vtype@ zeros_f = _mm@vsize@_set1_ps(0.0f);
@vtype@ poly, num_poly, denom_poly, quadrant;
@vtype@i vindex = _mm@vsize@_loadu_si@vsize@((@vtype@i*)&indexarr[0]);
- @vtype@i exponent;
@mask@ xmax_mask, xmin_mask, nan_mask, inf_mask;
@mask@ overflow_mask = @isa@_get_partial_load_mask(0, num_lanes);
@@ -1426,10 +1471,7 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
* exponent of quadrant to the exponent of poly. quadrant is an int,
* so extracting exponent is simply extracting 8 bits.
*/
- exponent = _mm@vsize@_slli_epi32(_mm@vsize@_cvtps_epi32(quadrant), 23);
- poly = _mm@vsize@_castsi@vsize@_ps(
- _mm@vsize@_add_epi32(
- _mm@vsize@_castps_si@vsize@(poly), exponent));
+ poly = @isa@_scalef_ps(poly, quadrant);
/*
* elem > xmax; return inf