-rw-r--r--  numpy/core/src/umath/simd.inc.src      | 106
-rw-r--r--  numpy/core/tests/test_umath_complex.py |  12
2 files changed, 92 insertions, 26 deletions
diff --git a/numpy/core/src/umath/simd.inc.src b/numpy/core/src/umath/simd.inc.src
index e3c0ee3cc..7ec90f9c8 100644
--- a/numpy/core/src/umath/simd.inc.src
+++ b/numpy/core/src/umath/simd.inc.src
@@ -1725,13 +1725,17 @@ avx512_hsub_@vsub@(const @vtype@ x)
 }
 
 static NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F @vtype@
-avx512_cabsolute_@vsub@(const @vtype@ x)
+avx512_cabsolute_@vsub@(const @vtype@ x1,
+                        const @vtype@ x2,
+                        const __m512i re_indices,
+                        const __m512i im_indices)
 {
     @vtype@ inf = _mm512_set1_@vsub@(@INF@);
     @vtype@ nan = _mm512_set1_@vsub@(@NAN@);
-    @vtype@ x_abs = avx512_abs_@vsub@(x);
-    @vtype@ re = _mm512_maskz_compress_@vsub@(@cmpx_re_mask@, x_abs);
-    @vtype@ im = _mm512_maskz_compress_@vsub@(@cmpx_img_mask@, x_abs);
+    @vtype@ x1_abs = avx512_abs_@vsub@(x1);
+    @vtype@ x2_abs = avx512_abs_@vsub@(x2);
+    @vtype@ re = _mm512_permutex2var_@vsub@(x1_abs, re_indices, x2_abs);
+    @vtype@ im = _mm512_permutex2var_@vsub@(x1_abs, im_indices , x2_abs);
     /*
      * If real or imag = INF, then convert it to inf + j*inf
      * Handles: inf + j*nan, nan + j*inf
@@ -2621,12 +2625,14 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
  * #type = npy_float, npy_double#
  * #num_lanes = 16, 8#
  * #vsuffix = ps, pd#
+ * #epi_vsub = epi32, epi64#
  * #mask = __mmask16, __mmask8#
  * #vtype = __m512, __m512d#
  * #scale = 4, 8#
  * #vindextype = __m512i, __m256i#
  * #vindexload = _mm512_loadu_si512, _mm256_loadu_si256#
  * #storemask = 0xFF, 0xF#
+ * #IS_FLOAT = 1, 0#
  */
 
 /**begin repeat1
@@ -2669,9 +2675,8 @@ AVX512F_@func@_@TYPE@(char **args, const npy_intp *dimensions, const npy_intp *s
 /**end repeat1**/
 
 /**begin repeat1
- * #func = absolute, square, conjugate#
- * #vectorf = avx512_cabsolute, avx512_csquare, avx512_conjugate#
- * #is_out_real = 1, 0, 0#
+ * #func = square, conjugate#
+ * #vectorf = avx512_csquare, avx512_conjugate#
  */
 
 #if defined HAVE_ATTRIBUTE_TARGET_AVX512F_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
@@ -2695,19 +2700,12 @@ AVX512F_@func@_@TYPE@(@type@ * op,
     }
     @vindextype@ vindex = @vindexload@((@vindextype@*)index_ip1);
     @mask@ load_mask = avx512_get_full_load_mask_@vsuffix@();
-#if @is_out_real@
-    @mask@ store_mask = _mm512_kand(avx512_get_full_load_mask_@vsuffix@(), @storemask@);
-#endif
     @vtype@ zeros = _mm512_setzero_@vsuffix@();
 
     while (num_remaining_elements > 0) {
         if (num_remaining_elements < @num_lanes@) {
            load_mask = avx512_get_partial_load_mask_@vsuffix@(
                    num_remaining_elements, @num_lanes@);
-#if @is_out_real@
-           store_mask = avx512_get_partial_load_mask_@vsuffix@(
-                   num_remaining_elements/2, @num_lanes@);
-#endif
         }
         @vtype@ x1;
         if (stride_ip1 == 1) {
@@ -2719,27 +2717,85 @@ AVX512F_@func@_@TYPE@(@type@ * op,
 
         @vtype@ out = @vectorf@_@vsuffix@(x1);
 
-#if @is_out_real@
-        _mm512_mask_storeu_@vsuffix@(op, store_mask, out);
-        op += @num_lanes@/2;
-#else
         _mm512_mask_storeu_@vsuffix@(op, load_mask, out);
         op += @num_lanes@;
-#endif
-
        ip += @num_lanes@*stride_ip1;
        num_remaining_elements -= @num_lanes@;
     }
-#if @is_out_real@
+}
+#endif
+/**end repeat1**/
+
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512F_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
+static NPY_GCC_OPT_3 NPY_INLINE NPY_GCC_TARGET_AVX512F void
+AVX512F_absolute_@TYPE@(@type@ * op,
+                        @type@ * ip,
+                        const npy_intp array_size,
+                        const npy_intp steps)
+{
+    npy_intp num_remaining_elements = 2*array_size;
+    const npy_intp stride_ip1 = steps/(npy_intp)sizeof(@type@)/2;
+
     /*
-     * Ignore invalid exception for cabsolute generated by vmaxps/vmaxpd
-     * and vminps/vminpd instructions
+     * Note: while generally indices are npy_intp, we ensure that our maximum index
+     * will fit in an int32 as a precondition for this function via max_stride
      */
-    npy_clear_floatstatus_barrier((char*)op);
+    npy_int32 index_ip[32];
+    for (npy_int32 ii = 0; ii < 2*@num_lanes@; ii=ii+2) {
+        index_ip[ii] = ii*stride_ip1;
+        index_ip[ii+1] = ii*stride_ip1 + 1;
+    }
+    @vindextype@ vindex1 = @vindexload@((@vindextype@*)index_ip);
+    @vindextype@ vindex2 = @vindexload@((@vindextype@*)(index_ip+@num_lanes@));
+
+    @mask@ load_mask1 = avx512_get_full_load_mask_@vsuffix@();
+    @mask@ load_mask2 = avx512_get_full_load_mask_@vsuffix@();
+    @mask@ store_mask = avx512_get_full_load_mask_@vsuffix@();
+    @vtype@ zeros = _mm512_setzero_@vsuffix@();
+
+#if @IS_FLOAT@
+    __m512i re_index = _mm512_set_epi32(30,28,26,24,22,20,18,16,14,12,10,8,6,4,2,0);
+    __m512i im_index = _mm512_set_epi32(31,29,27,25,23,21,19,17,15,13,11,9,7,5,3,1);
+#else
+    __m512i re_index = _mm512_set_epi64(14,12,10,8,6,4,2,0);
+    __m512i im_index = _mm512_set_epi64(15,13,11,9,7,5,3,1);
 #endif
+
+    while (num_remaining_elements > 0) {
+        if (num_remaining_elements < @num_lanes@) {
+            load_mask1 = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements, @num_lanes@);
+            load_mask2 = 0x0000;
+            store_mask = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements/2, @num_lanes@);
+        } else if (num_remaining_elements < 2*@num_lanes@) {
+            load_mask1 = avx512_get_full_load_mask_@vsuffix@();
+            load_mask2 = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements - @num_lanes@, @num_lanes@);
+            store_mask = avx512_get_partial_load_mask_@vsuffix@(
+                    num_remaining_elements/2, @num_lanes@);
+        }
+        @vtype@ x1, x2;
+        if (stride_ip1 == 1) {
+            x1 = avx512_masked_load_@vsuffix@(load_mask1, ip);
+            x2 = avx512_masked_load_@vsuffix@(load_mask2, ip+@num_lanes@);
+        }
+        else {
+            x1 = avx512_masked_gather_@vsuffix@(zeros, ip, vindex1, load_mask1);
+            x2 = avx512_masked_gather_@vsuffix@(zeros, ip, vindex2, load_mask2);
+        }
+
+        @vtype@ out = avx512_cabsolute_@vsuffix@(x1, x2, re_index, im_index);
+
+        _mm512_mask_storeu_@vsuffix@(op, store_mask, out);
+        op += @num_lanes@;
+        ip += 2*@num_lanes@*stride_ip1;
+        num_remaining_elements -= 2*@num_lanes@;
+    }
+    npy_clear_floatstatus_barrier((char*)op);
 }
+
 #endif
-/**end repeat1**/
 /**end repeat**/
 
 /*
diff --git a/numpy/core/tests/test_umath_complex.py b/numpy/core/tests/test_umath_complex.py
index b4f2ebfde..a21158420 100644
--- a/numpy/core/tests/test_umath_complex.py
+++ b/numpy/core/tests/test_umath_complex.py
@@ -541,7 +541,7 @@ def check_complex_value(f, x1, y1, x2, y2, exact=True):
     else:
         assert_almost_equal(f(z1), z2)
 
-class TestComplexAVX(object):
+class TestSpecialComplexAVX(object):
     @pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])
     @pytest.mark.parametrize("astype", [np.complex64, np.complex128])
     def test_array(self, stride, astype):
@@ -567,3 +567,13 @@ class TestComplexAVX(object):
         assert_equal(np.abs(arr[::stride]), abs_true[::stride])
         with np.errstate(invalid='ignore'):
             assert_equal(np.square(arr[::stride]), sq_true[::stride])
+
+class TestComplexAbsoluteAVX(object):
+    @pytest.mark.parametrize("arraysize", [1,2,3,4,5,6,7,8,9,10,11,13,15,17,18,19])
+    @pytest.mark.parametrize("stride", [-4,-3,-2,-1,1,2,3,4])
+    @pytest.mark.parametrize("astype", [np.complex64, np.complex128])
+    # test to ensure masking and strides work as intended in the AVX implementation
+    def test_array(self, arraysize, stride, astype):
+        arr = np.ones(arraysize, dtype=astype)
+        abs_true = np.ones(arraysize, dtype=arr.real.dtype)
+        assert_equal(np.abs(arr[::stride]), abs_true[::stride])