author     Raghuveer Devulapalli <raghuveer.devulapalli@intel.com>    2019-03-05 09:13:55 -0800
committer  Raghuveer Devulapalli <raghuveer.devulapalli@intel.com>    2019-04-19 10:47:15 -0700
commit     9754a207828f377654c79873e38d475bb87d98de (patch)
tree       6512d0febf26593ac946722d9c38ca57a4bbbdfb /numpy/core/src
parent     31e71d7ce8d447cb74b9fb83875361cf7dba4579 (diff)
download   numpy-9754a207828f377654c79873e38d475bb87d98de.tar.gz
ENH: vectorizing float32 implementation of np.exp & np.log
This commit implements vectorized single-precision exponential and natural log using AVX2 and AVX512.

Accuracy:

| Function | Max ULP Error | Max Relative Error |
|----------|---------------|--------------------|
| np.exp   | 2.52          | 2.1E-07            |
| np.log   | 3.83          | 2.4E-07            |

Performance:

(1) Micro-benchmarks: we measured the execution time of np.exp and np.log using the timeit package in Python. Each function is executed 1000 times, and this is repeated 100 times. The standard deviation for all runs was less than 2% of their mean value and is hence not included in the data. The vectorized implementation was up to 7.6x faster than the scalar version.

| Function | NumPy 1.16 | AVX2   | AVX512 | AVX2 speedup | AVX512 speedup |
|----------|------------|--------|--------|--------------|----------------|
| np.exp   | 0.395s     | 0.112s | 0.055s | 3.56x        | 7.25x          |
| np.log   | 0.456s     | 0.147s | 0.059s | 3.10x        | 7.64x          |

(2) Logistic regression: exp and log are heavily used in training neural networks (as part of the sigmoid activation function and the loss function, respectively). This patch significantly speeds up training a logistic regression model. As an example, we measured how much time it takes to train a model with 15 features using 1000 training data points. We observed a 2x speedup in training the model to a loss function error < 10E-04.

| Function       | NumPy 1.16 | AVX2   | AVX512 | AVX2 speedup | AVX512 speedup |
|----------------|------------|--------|--------|--------------|----------------|
| logistic.train | 121.0s     | 75.02s | 60.60s | 1.61x        | 2.02x          |
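The Max ULP Error figures come from sweeping 32-bit float inputs and comparing each result against a higher-precision reference. As an illustration of how such a measurement works, here is a minimal ULP-distance helper in C (the function names are illustrative, not part of this patch); it maps each float's bit pattern onto a monotonic integer scale so that adjacent representable floats differ by exactly 1:

    #include <stdint.h>
    #include <string.h>

    /* Map a float's bit pattern onto a monotonic integer line so that
     * adjacent representable floats differ by exactly 1. */
    static int64_t float_to_ordered(float x)
    {
        int32_t i;
        memcpy(&i, &x, sizeof(i));
        return (i < 0) ? (int64_t)INT32_MIN - i : (int64_t)i;
    }

    /* ULP distance between a computed value and a reference value */
    static int64_t ulp_dist(float computed, float reference)
    {
        int64_t d = float_to_ordered(computed) - float_to_ordered(reference);
        return (d < 0) ? -d : d;
    }

Sweeping all bit patterns and taking the maximum of ulp_dist against a float-rounded double-precision reference gives an integer bound; fractional figures such as 2.52 come from measuring against the unrounded reference, i.e. |computed - exact| / ulp(exact).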
Diffstat (limited to 'numpy/core/src')
-rw-r--r--  numpy/core/src/umath/cpuid.c       19
-rw-r--r--  numpy/core/src/umath/loops.c.src   37
-rw-r--r--  numpy/core/src/umath/loops.h.src   16
-rw-r--r--  numpy/core/src/umath/simd.inc.src  401
4 files changed, 472 insertions(+), 1 deletion(-)
diff --git a/numpy/core/src/umath/cpuid.c b/numpy/core/src/umath/cpuid.c
index 6744ceb05..ab97e7afc 100644
--- a/numpy/core/src/umath/cpuid.c
+++ b/numpy/core/src/umath/cpuid.c
@@ -11,6 +11,7 @@
#define XCR_XFEATURE_ENABLED_MASK 0x0
#define XSTATE_SSE 0x2
#define XSTATE_YMM 0x4
+#define XSTATE_ZMM 0x70
/*
* verify the OS supports avx instructions
@@ -33,6 +34,19 @@ int os_avx_support(void)
#endif
}
+static NPY_INLINE
+int os_avx512_support(void)
+{
+#if HAVE_XGETBV
+ unsigned int eax, edx;
+ unsigned int ecx = XCR_XFEATURE_ENABLED_MASK;
+ unsigned int xcr0 = XSTATE_ZMM | XSTATE_YMM | XSTATE_SSE;
+ __asm__("xgetbv" : "=a" (eax), "=d" (edx) : "c" (ecx));
+ return (eax & xcr0) == xcr0;
+#else
+ return 0;
+#endif
+}
/*
* Primitive cpu feature detect function
@@ -42,7 +56,10 @@ NPY_NO_EXPORT int
npy_cpu_supports(const char * feature)
{
#ifdef HAVE___BUILTIN_CPU_SUPPORTS
- if (strcmp(feature, "avx2") == 0) {
+ if (strcmp(feature, "avx512f") == 0) {
+ return __builtin_cpu_supports("avx512f") && os_avx512_support();
+ }
+ else if (strcmp(feature, "avx2") == 0) {
return __builtin_cpu_supports("avx2") && os_avx_support();
}
else if (strcmp(feature, "avx") == 0) {
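The new XSTATE_ZMM mask 0x70 covers XCR0 bits 5 through 7 (the opmask registers k0-k7, the upper 256 bits of ZMM0-ZMM15, and ZMM16-ZMM31). A CPU can support AVX-512 instructions while the OS never enables saving of this state, which is why the xgetbv check is needed on top of __builtin_cpu_supports. A minimal standalone version of the same check (a sketch assuming a GCC-compatible compiler on an x86 CPU with XSAVE enabled):

    #include <stdio.h>

    int main(void)
    {
        unsigned int eax, edx;
        unsigned int ecx = 0;                          /* XCR_XFEATURE_ENABLED_MASK */
        unsigned int avx512_state = 0x70 | 0x4 | 0x2;  /* ZMM | YMM | SSE */
        /* xgetbv reads the extended control register selected by ecx */
        __asm__("xgetbv" : "=a" (eax), "=d" (edx) : "c" (ecx));
        printf("OS %s AVX-512 state\n",
               ((eax & avx512_state) == avx512_state) ? "saves" : "does not save");
        return 0;
    }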
diff --git a/numpy/core/src/umath/loops.c.src b/numpy/core/src/umath/loops.c.src
index 290a87a33..024d495cd 100644
--- a/numpy/core/src/umath/loops.c.src
+++ b/numpy/core/src/umath/loops.c.src
@@ -1569,6 +1569,43 @@ NPY_NO_EXPORT void
/**end repeat**/
+/**begin repeat
+ * #func = exp, log#
+ * #scalarf = npy_expf, npy_logf#
+ */
+
+NPY_NO_EXPORT NPY_GCC_OPT_3 void
+FLOAT_@func@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(data))
+{
+ UNARY_LOOP {
+ const npy_float in1 = *(npy_float *)ip1;
+ *(npy_float *)op1 = @scalarf@(in1);
+ }
+}
+
+/**end repeat**/
+
+/**begin repeat
+ * #isa = avx512f, avx2#
+ * #ISA = AVX512F, AVX2#
+ * #CHK = HAVE_ATTRIBUTE_TARGET_AVX512F, HAVE_ATTRIBUTE_TARGET_AVX2#
+ * #ATTR = NPY_GCC_TARGET_AVX512F, NPY_GCC_TARGET_AVX2#
+ */
+
+/**begin repeat1
+ * #func = exp, log#
+ */
+
+#if @CHK@
+NPY_NO_EXPORT NPY_GCC_OPT_3 @ATTR@ void
+FLOAT_@func@_@isa@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(data))
+{
+ @ISA@_@func@_FLOAT((npy_float*)args[1], (npy_float*)args[0], dimensions[0]);
+}
+#endif
+
+/**end repeat1**/
+/**end repeat**/
/**begin repeat
* Float types
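Each @token@ in the block above is substituted per repeat iteration by NumPy's .c.src template processor, so it generates four concrete loops (exp/log crossed with avx512f/avx2). For illustration, the avx2 exp instantiation expands to roughly:

    #if HAVE_ATTRIBUTE_TARGET_AVX2
    NPY_NO_EXPORT NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 void
    FLOAT_exp_avx2(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(data))
    {
        /* args[0] is the input array, args[1] the output, dimensions[0] the length */
        AVX2_exp_FLOAT((npy_float*)args[1], (npy_float*)args[0], dimensions[0]);
    }
    #endif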
diff --git a/numpy/core/src/umath/loops.h.src b/numpy/core/src/umath/loops.h.src
index 9dc1b7016..8dd3170e3 100644
--- a/numpy/core/src/umath/loops.h.src
+++ b/numpy/core/src/umath/loops.h.src
@@ -178,6 +178,22 @@ NPY_NO_EXPORT void
/**end repeat**/
/**begin repeat
+ * #func = exp, log#
+ */
+NPY_NO_EXPORT void
+FLOAT_@func@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+
+/**begin repeat1
+ * #isa = avx512f, avx2#
+ */
+
+NPY_NO_EXPORT void
+FLOAT_@func@_@isa@(char **args, npy_intp *dimensions, npy_intp *steps, void *NPY_UNUSED(func));
+
+/**end repeat1**/
+/**end repeat**/
+
+/**begin repeat
* Float types
* #TYPE = HALF, FLOAT, DOUBLE, LONGDOUBLE#
* #c = f, f, , l#
diff --git a/numpy/core/src/umath/simd.inc.src b/numpy/core/src/umath/simd.inc.src
index 4bb8569be..f5684c30b 100644
--- a/numpy/core/src/umath/simd.inc.src
+++ b/numpy/core/src/umath/simd.inc.src
@@ -122,6 +122,27 @@ abs_ptrdiff(char *a, char *b)
*/
/**begin repeat
+ * #ISA = AVX2, AVX512F#
+ */
+
+/* prototypes */
+#if defined NPY_HAVE_@ISA@_INTRINSICS
+
+/**begin repeat1
+ * #func = exp, log#
+ */
+
+static void
+@ISA@_@func@_FLOAT(npy_float *, npy_float *, const npy_int n);
+
+/**end repeat1**/
+#endif
+
+/**end repeat**/
+
+
+
+/**begin repeat
* Float types
* #type = npy_float, npy_double, npy_longdouble#
* #TYPE = FLOAT, DOUBLE, LONGDOUBLE#
@@ -1075,6 +1096,386 @@ sse2_@kind@_@TYPE@(@type@ * ip, @type@ * op, const npy_intp n)
/**end repeat**/
+/* bunch of helper functions used in ISA_exp/log_FLOAT */
+
+#if HAVE_ATTRIBUTE_TARGET_AVX2
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_fmadd(__m256 a, __m256 b, __m256 c)
+{
+ return _mm256_add_ps(_mm256_mul_ps(a, b), c);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_get_full_load_mask(void)
+{
+ return _mm256_set1_ps(-1.0);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_get_partial_load_mask(const npy_int num_elem, const npy_int total_elem)
+{
+ float maskint[16] = {-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,
+ 1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0};
+ float* addr = maskint + total_elem - num_elem;
+ return _mm256_loadu_ps(addr);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_masked_load(__m256 mask, npy_float* addr)
+{
+ return _mm256_maskload_ps(addr, _mm256_cvtps_epi32(mask));
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_set_masked_lanes(__m256 x, __m256 val, __m256 mask)
+{
+ return _mm256_blendv_ps(x, val, mask);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_blend(__m256 x, __m256 y, __m256 ymask)
+{
+ return _mm256_blendv_ps(x, y, ymask);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_get_exponent(__m256 x)
+{
+ /*
+ * Special handling of denormals:
+ * 1) Multiply denormal elements with 2**100 (0x71800000)
+ * 2) Get the 8 bits of unbiased exponent
+ * 3) Subtract 100 from exponent of denormals
+ */
+
+ __m256 two_power_100 = _mm256_castsi256_ps(_mm256_set1_epi32(0x71800000));
+ __m256 denormal_mask = _mm256_cmp_ps(x, _mm256_set1_ps(FLT_MIN), _CMP_LT_OQ);
+ __m256 temp = _mm256_mul_ps(x, two_power_100);
+ x = _mm256_blendv_ps(x, temp, denormal_mask);
+
+ __m256 exp = _mm256_cvtepi32_ps(
+ _mm256_sub_epi32(
+ _mm256_srli_epi32(
+ _mm256_castps_si256(x), 23),_mm256_set1_epi32(0x7E)));
+
+ __m256 denorm_exp = _mm256_sub_ps(exp, _mm256_set1_ps(100.0f));
+ return _mm256_blendv_ps(exp, denorm_exp, denormal_mask);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX2 __m256
+avx2_get_mantissa(__m256 x)
+{
+ /*
+ * Special handling of denormals:
+ * 1) Multiply denormal elements with 2**100 (0x71800000)
+ * 2) Get the 23 bits of mantissa
+ * 3) Mantissa for denormals is not affected by the multiplication
+ */
+
+ __m256 two_power_100 = _mm256_castsi256_ps(_mm256_set1_epi32(0x71800000));
+ __m256 denormal_mask = _mm256_cmp_ps(x, _mm256_set1_ps(FLT_MIN), _CMP_LT_OQ);
+ __m256 temp = _mm256_mul_ps(x, two_power_100);
+ x = _mm256_blendv_ps(x, temp, denormal_mask);
+
+ __m256i mantissa_bits = _mm256_set1_epi32(0x7fffff);
+ __m256i exp_126_bits = _mm256_set1_epi32(126 << 23);
+ return _mm256_castsi256_ps(
+ _mm256_or_si256(
+ _mm256_and_si256(
+ _mm256_castps_si256(x), mantissa_bits), exp_126_bits));
+}
+#endif
+
+#if HAVE_ATTRIBUTE_TARGET_AVX512F
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __mmask16
+avx512_get_full_load_mask(void)
+{
+ return 0xFFFF;
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __mmask16
+avx512_get_partial_load_mask(const npy_int num_elem, const npy_int total_elem)
+{
+ return (0x0001 << num_elem) - 0x0001;
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __m512
+avx512_masked_load(__mmask16 mask, npy_float* addr)
+{
+ return _mm512_maskz_loadu_ps(mask, (__m512 *)addr);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __m512
+avx512_set_masked_lanes(__m512 x, __m512 val, __mmask16 mask)
+{
+ return _mm512_mask_blend_ps(mask, x, val);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __m512
+avx512_blend(__m512 x, __m512 y, __mmask16 ymask)
+{
+ return _mm512_mask_mov_ps(x, ymask, y);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __m512
+avx512_get_exponent(__m512 x)
+{
+ return _mm512_add_ps(_mm512_getexp_ps(x), _mm512_set1_ps(1.0f));
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F __m512
+avx512_get_mantissa(__m512 x)
+{
+ return _mm512_getmant_ps(x, _MM_MANT_NORM_p5_1, _MM_MANT_SIGN_src);
+}
+#endif
+
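Note the asymmetry between the two helper sets: the AVX2 versions extract exponent and mantissa with integer bit manipulation plus explicit denormal rescaling, while AVX512F's _mm512_getexp_ps/_mm512_getmant_ps do the same in one instruction each, with denormals handled in hardware. Both produce the frexp convention, x = m * 2^k with m in [0.5, 1); a scalar illustration:

    #include <math.h>
    #include <stdio.h>

    int main(void)
    {
        int k;
        float x = 6.0f;
        float m = frexpf(x, &k);               /* x == m * 2^k, 0.5 <= m < 1 */
        printf("%g = %g * 2^%d\n", x, m, k);   /* prints: 6 = 0.75 * 2^3 */
        return 0;
    }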
+/**begin repeat
+ * #ISA = AVX2, AVX512F#
+ * #isa = avx2, avx512#
+ * #vtype = __m256, __m512#
+ * #vsize = 256, 512#
+ * #or = or_ps, kor#
+ * #vsub = , _mask#
+ * #mask = __m256, __mmask16#
+ * #fmadd = avx2_fmadd,_mm512_fmadd_ps#
+ **/
+
+#if HAVE_ATTRIBUTE_TARGET_@ISA@
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ @mask@
+@isa@_cmp_mask(@vtype@ x, npy_float fnum, int sign)
+{
+ return _mm@vsize@_cmp_ps@vsub@(x, _mm@vsize@_set1_ps(fnum), sign);
+}
+
+NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ @vtype@
+@isa@_range_reduction(@vtype@ x, @vtype@ y, @vtype@ c1, @vtype@ c2, @vtype@ c3)
+{
+ @vtype@ reduced_x = @fmadd@(y, c1, x);
+ reduced_x = @fmadd@(y, c2, reduced_x);
+ reduced_x = @fmadd@(y, c3, reduced_x);
+ return reduced_x;
+}
+#endif
+/**end repeat**/
+
+/**begin repeat
+ * #ISA = AVX2, AVX512F#
+ * #isa = avx2, avx512#
+ * #vtype = __m256, __m512#
+ * #vsize = 256, 512#
+ * #BYTES = 32, 64#
+ * #mask = __m256, __mmask16#
+ * #and_masks =_mm256_and_ps, _mm512_kand#
+ * #fmadd = avx2_fmadd,_mm512_fmadd_ps#
+ * #mask_to_int = _mm256_movemask_ps, #
+ * #full_mask= 0xFF, 0xFFFF#
+ * #masked_store = _mm256_maskstore_ps, _mm512_mask_storeu_ps#
+ * #cvtps_epi32 = _mm256_cvtps_epi32, #
+ */
+
+#if HAVE_ATTRIBUTE_TARGET_@ISA@
+
+/*
+ * Vectorized implementation of exp using AVX2 and AVX512:
+ * 1) if x >= xmax; return INF (overflow)
+ * 2) if x <= xmin; return 0.0f (underflow)
+ * 3) Range reduction (using Cody-Waite):
+ * a) y = x - k*ln(2); k = rint(x/ln(2)); y \in [0, ln(2)]
+ * 4) Compute exp(y) = P/Q, ratio of 2 polynomials P and Q
+ * b) P = 5th order and Q = 2nd order polynomials obtained from Remez's
+ * algorithm (mini-max polynomial approximation)
+ * 5) Compute exp(x) = exp(y) * 2^k
+ * 6) Max ULP error measured across all 32-bit FP's = 2.52 (x = 0xc2781e37)
+ * 7) Max relative error measured across all 32-bit FP's = 2.1264E-07 (for the
+ * same x = 0xc2781e37)
+ */
+
+NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
+@ISA@_exp_FLOAT(npy_float * op, npy_float * ip, const npy_int array_size)
+{
+ const npy_int num_lanes = @BYTES@/sizeof(npy_float);
+ npy_float xmax = 88.72283935546875f;
+ npy_float xmin = -87.3365478515625f;
+
+ /* Load up frequently used constants */
+ @vtype@ codyw_c1 = _mm@vsize@_set1_ps(NPY_CODY_WAITE_LOGE_2_HIGHf);
+ @vtype@ codyw_c2 = _mm@vsize@_set1_ps(NPY_CODY_WAITE_LOGE_2_LOWf);
+ @vtype@ exp_p0 = _mm@vsize@_set1_ps(NPY_COEFF_P0_EXPf);
+ @vtype@ exp_p1 = _mm@vsize@_set1_ps(NPY_COEFF_P1_EXPf);
+ @vtype@ exp_p2 = _mm@vsize@_set1_ps(NPY_COEFF_P2_EXPf);
+ @vtype@ exp_p3 = _mm@vsize@_set1_ps(NPY_COEFF_P3_EXPf);
+ @vtype@ exp_p4 = _mm@vsize@_set1_ps(NPY_COEFF_P4_EXPf);
+ @vtype@ exp_p5 = _mm@vsize@_set1_ps(NPY_COEFF_P5_EXPf);
+ @vtype@ exp_q0 = _mm@vsize@_set1_ps(NPY_COEFF_Q0_EXPf);
+ @vtype@ exp_q1 = _mm@vsize@_set1_ps(NPY_COEFF_Q1_EXPf);
+ @vtype@ exp_q2 = _mm@vsize@_set1_ps(NPY_COEFF_Q2_EXPf);
+ @vtype@ cvt_magic = _mm@vsize@_set1_ps(NPY_RINT_CVT_MAGICf);
+ @vtype@ log2e = _mm@vsize@_set1_ps(NPY_LOG2Ef);
+ @vtype@ inf = _mm@vsize@_set1_ps(NPY_INFINITYF);
+ @vtype@ zeros_f = _mm@vsize@_set1_ps(0.0f);
+ @vtype@ poly, num_poly, denom_poly, quadrant;
+ @vtype@i exponent;
+
+ @mask@ xmax_mask, xmin_mask;
+ @mask@ load_mask = @isa@_get_full_load_mask();
+ npy_int num_remaining_elements = array_size;
+
+ while (num_remaining_elements > 0) {
+
+ if (num_remaining_elements < num_lanes)
+ load_mask = @isa@_get_partial_load_mask(num_remaining_elements,
+ num_lanes);
+ @vtype@ x = @isa@_masked_load(load_mask, ip);
+ xmax_mask = @isa@_cmp_mask(x, xmax, _CMP_GE_OQ);
+ xmin_mask = @isa@_cmp_mask(x, xmin, _CMP_LE_OQ);
+
+ x = @isa@_set_masked_lanes(x, zeros_f,
+ @and_masks@(xmax_mask,xmin_mask));
+
+ quadrant = _mm@vsize@_mul_ps(x, log2e);
+
+ /* round to nearest */
+ quadrant = _mm@vsize@_add_ps(quadrant, cvt_magic);
+ quadrant = _mm@vsize@_sub_ps(quadrant, cvt_magic);
+
+ /* Cody-Waite's range reduction algorithm */
+ x = @isa@_range_reduction(x, quadrant,
+ codyw_c1, codyw_c2, zeros_f);
+
+ num_poly = @fmadd@(exp_p5, x, exp_p4);
+ num_poly = @fmadd@(num_poly, x, exp_p3);
+ num_poly = @fmadd@(num_poly, x, exp_p2);
+ num_poly = @fmadd@(num_poly, x, exp_p1);
+ num_poly = @fmadd@(num_poly, x, exp_p0);
+ denom_poly = @fmadd@(exp_q2, x, exp_q1);
+ denom_poly = @fmadd@(denom_poly, x, exp_q0);
+ poly = _mm@vsize@_div_ps(num_poly, denom_poly);
+
+ /*
+ * compute val = poly * 2^quadrant; which is same as adding the
+ * exponent of quadrant to the exponent of poly. quadrant is an int,
+ * so extracting exponent is simply extracting 8 bits.
+ */
+ exponent = _mm@vsize@_slli_epi32(_mm@vsize@_cvtps_epi32(quadrant), 23);
+ poly = _mm@vsize@_castsi@vsize@_ps(
+ _mm@vsize@_add_epi32(
+ _mm@vsize@_castps_si@vsize@(poly), exponent));
+
+ /* elem > xmax; return inf, elem < xmin; return 0.0f */
+ poly = @isa@_set_masked_lanes(poly, inf, xmax_mask);
+ poly = @isa@_set_masked_lanes(poly, zeros_f, xmin_mask);
+
+ @masked_store@(op, @cvtps_epi32@(load_mask), poly);
+
+ ip += num_lanes;
+ op += num_lanes;
+ num_remaining_elements -= num_lanes;
+ }
+}
+
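In the loop above, adding and then subtracting cvt_magic rounds quadrant to the nearest integer without a dedicated rounding instruction, and the Cody-Waite split of ln(2) into a high and a low part keeps x - k*ln(2) accurate even though k*ln(2) is not exactly representable. A scalar sketch of the same reduction for one input (the constants below are a standard float split of ln(2); the actual NPY_CODY_WAITE_LOGE_2_* values are defined elsewhere in the tree):

    #include <math.h>
    #include <stdio.h>

    int main(void)
    {
        const float log2e  = 1.44269504f;      /* 1/ln(2) */
        const float ln2_hi = 6.93145752e-1f;   /* trailing mantissa bits are zero,
                                                  so k*ln2_hi is exact */
        const float ln2_lo = 1.42860682e-6f;   /* ln(2) - ln2_hi */
        float x = 10.0f;
        float k = rintf(x * log2e);                 /* k = 14 */
        float r = (x - k * ln2_hi) - k * ln2_lo;    /* reduced argument */
        /* exp(x) == exp(r) * 2^k; the 2^k scaling is the integer exponent add */
        printf("exp(10) = %g (reference %g)\n",
               expf(r) * ldexpf(1.0f, (int)k), expf(10.0f));
        return 0;
    }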
+/*
+ * Vectorized implementation of log using AVX2 and AVX512
+ * 1) if x < 0.0f; return -NAN (invalid input)
+ * 2) Range reduction: y = x/2^k;
+ * a) y = normalized mantissa, k is the exponent (0.5 <= y < 1)
+ * 3) Compute log(y) = P/Q, ratio of 2 polynomials P and Q
+ * b) P = 5th order and Q = 5th order polynomials obtained from Remez's
+ * algorithm (mini-max polynomial approximation)
+ * 4) Compute log(x) = log(y) + k*ln(2)
+ * 5) Max ULP error measured across all 32-bit FP's = 3.83 (x = 0x3f486945)
+ * 6) Max relative error measured across all 32-bit FP's = 2.359E-07 (for the
+ *    same x = 0x3f486945)
+ */
+
+NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
+@ISA@_log_FLOAT(npy_float * op, npy_float * ip, const npy_int array_size)
+{
+ const npy_int num_lanes = @BYTES@/sizeof(npy_float);
+
+ /* Load up frequently used constants */
+ @vtype@ log_p0 = _mm@vsize@_set1_ps(NPY_COEFF_P0_LOGf);
+ @vtype@ log_p1 = _mm@vsize@_set1_ps(NPY_COEFF_P1_LOGf);
+ @vtype@ log_p2 = _mm@vsize@_set1_ps(NPY_COEFF_P2_LOGf);
+ @vtype@ log_p3 = _mm@vsize@_set1_ps(NPY_COEFF_P3_LOGf);
+ @vtype@ log_p4 = _mm@vsize@_set1_ps(NPY_COEFF_P4_LOGf);
+ @vtype@ log_p5 = _mm@vsize@_set1_ps(NPY_COEFF_P5_LOGf);
+ @vtype@ log_q0 = _mm@vsize@_set1_ps(NPY_COEFF_Q0_LOGf);
+ @vtype@ log_q1 = _mm@vsize@_set1_ps(NPY_COEFF_Q1_LOGf);
+ @vtype@ log_q2 = _mm@vsize@_set1_ps(NPY_COEFF_Q2_LOGf);
+ @vtype@ log_q3 = _mm@vsize@_set1_ps(NPY_COEFF_Q3_LOGf);
+ @vtype@ log_q4 = _mm@vsize@_set1_ps(NPY_COEFF_Q4_LOGf);
+ @vtype@ log_q5 = _mm@vsize@_set1_ps(NPY_COEFF_Q5_LOGf);
+ @vtype@ loge2 = _mm@vsize@_set1_ps(NPY_LOGE2f);
+ @vtype@ neg_nan = _mm@vsize@_set1_ps(-NPY_NANF);
+ @vtype@ neg_inf = _mm@vsize@_set1_ps(-NPY_INFINITYF);
+ @vtype@ zeros_f = _mm@vsize@_set1_ps(0.0f);
+ @vtype@ ones_f = _mm@vsize@_set1_ps(1.0f);
+ @vtype@ poly, num_poly, denom_poly, exponent;
+
+ @mask@ inf_nan_mask, sqrt2_mask, zero_mask, negx_mask;
+ @mask@ load_mask = @isa@_get_full_load_mask();
+ npy_int num_remaining_elements = array_size;
+
+ while (num_remaining_elements > 0) {
+
+ if (num_remaining_elements < num_lanes)
+ load_mask = @isa@_get_partial_load_mask(num_remaining_elements,
+ num_lanes);
+ @vtype@ x_in = @isa@_masked_load(load_mask, ip);
+
+ negx_mask = @isa@_cmp_mask(x_in, 0.0f, _CMP_LT_OQ);
+ zero_mask = @isa@_cmp_mask(x_in, 0.0f, _CMP_EQ_OQ);
+ inf_nan_mask = @isa@_cmp_mask(x_in, FLT_MAX, _CMP_GT_OQ);
+
+ @vtype@ x = @isa@_set_masked_lanes(x_in, zeros_f, negx_mask);
+
+ /* set x = normalized mantissa */
+ exponent = @isa@_get_exponent(x);
+ x = @isa@_get_mantissa(x);
+
+ /* if x < sqrt(2) {exp = exp-1; x = 2*x} */
+ sqrt2_mask = @isa@_cmp_mask(x, NPY_SQRT1_2f, _CMP_LE_OQ);
+ x = @isa@_blend(x, _mm@vsize@_add_ps(x,x), sqrt2_mask);
+ exponent = @isa@_blend(exponent,
+ _mm@vsize@_sub_ps(exponent,ones_f), sqrt2_mask);
+
+ /* x = x - 1 */
+ x = _mm@vsize@_sub_ps(x, ones_f);
+
+ /* Polynomial approximation for log(1+x) */
+ num_poly = @fmadd@(log_p5, x, log_p4);
+ num_poly = @fmadd@(num_poly, x, log_p3);
+ num_poly = @fmadd@(num_poly, x, log_p2);
+ num_poly = @fmadd@(num_poly, x, log_p1);
+ num_poly = @fmadd@(num_poly, x, log_p0);
+ denom_poly = @fmadd@(log_q5, x, log_q4);
+ denom_poly = @fmadd@(denom_poly, x, log_q3);
+ denom_poly = @fmadd@(denom_poly, x, log_q2);
+ denom_poly = @fmadd@(denom_poly, x, log_q1);
+ denom_poly = @fmadd@(denom_poly, x, log_q0);
+ poly = _mm@vsize@_div_ps(num_poly, denom_poly);
+ poly = @fmadd@(exponent, loge2, poly);
+
+ /*
+ * x < 0.0f; return -NAN
+ * x = 0.0f; return -INF
+ * x > FLT_MAX; return x
+ */
+ poly = @isa@_set_masked_lanes(poly, neg_nan, negx_mask);
+ poly = @isa@_set_masked_lanes(poly, neg_inf, zero_mask);
+ poly = @isa@_set_masked_lanes(poly, x_in, inf_nan_mask);
+
+ @masked_store@(op, @cvtps_epi32@(load_mask), poly);
+
+ ip += num_lanes;
+ op += num_lanes;
+ num_remaining_elements -= num_lanes;
+ }
+}
+#endif
+/**end repeat**/
+
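Both kernels process the array tail with a partial mask instead of a scalar fallback loop: in the final iteration only the first num_remaining_elements lanes are loaded and stored. A small sketch of the AVX512 mask arithmetic for a 20-element array processed 16 floats at a time:

    #include <stdio.h>

    int main(void)
    {
        int array_size = 20, num_lanes = 16;
        int remaining = array_size;
        while (remaining > 0) {
            unsigned int mask = 0xFFFF;            /* all 16 lanes active */
            if (remaining < num_lanes)
                mask = (1u << remaining) - 1u;     /* low `remaining` lanes only */
            printf("mask = 0x%04X\n", mask);       /* prints 0xFFFF, then 0x000F */
            remaining -= num_lanes;
        }
        return 0;
    }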
/*
*****************************************************************************
** BOOL LOOPS