 numpy/core/src/umath/simd.inc.src      | 106
 numpy/core/tests/test_umath_complex.py |  12
 2 files changed, 92 insertions(+), 26 deletions(-)
diff --git a/numpy/core/src/umath/simd.inc.src b/numpy/core/src/umath/simd.inc.src
index e3c0ee3cc..7ec90f9c8 100644
--- a/numpy/core/src/umath/simd.inc.src
+++ b/numpy/core/src/umath/simd.inc.src
@@ -1725,13 +1725,17 @@ avx512_hsub_@vsub@(const @vtype@ x)
}
static NPY_INLINE NPY_GCC_OPT_3 NPY_GCC_TARGET_AVX512F @vtype@
-avx512_cabsolute_@vsub@(const @vtype@ x)
+avx512_cabsolute_@vsub@(const @vtype@ x1,
+ const @vtype@ x2,
+ const __m512i re_indices,
+ const __m512i im_indices)
{
@vtype@ inf = _mm512_set1_@vsub@(@INF@);
@vtype@ nan = _mm512_set1_@vsub@(@NAN@);
- @vtype@ x_abs = avx512_abs_@vsub@(x);
- @vtype@ re = _mm512_maskz_compress_@vsub@(@cmpx_re_mask@, x_abs);
- @vtype@ im = _mm512_maskz_compress_@vsub@(@cmpx_img_mask@, x_abs);
+ @vtype@ x1_abs = avx512_abs_@vsub@(x1);
+ @vtype@ x2_abs = avx512_abs_@vsub@(x2);
+ @vtype@ re = _mm512_permutex2var_@vsub@(x1_abs, re_indices, x2_abs);
+ @vtype@ im = _mm512_permutex2var_@vsub@(x1_abs, im_indices, x2_abs);
/*
* If real or imag = INF, then convert it to inf + j*inf
* Handles: inf + j*nan, nan + j*inf
@@ -2621,12 +2625,14 @@ static NPY_GCC_OPT_3 NPY_GCC_TARGET_@ISA@ void
* #type = npy_float, npy_double#
* #num_lanes = 16, 8#
* #vsuffix = ps, pd#
+ * #epi_vsub = epi32, epi64#
* #mask = __mmask16, __mmask8#
* #vtype = __m512, __m512d#
* #scale = 4, 8#
* #vindextype = __m512i, __m256i#
* #vindexload = _mm512_loadu_si512, _mm256_loadu_si256#
* #storemask = 0xFF, 0xF#
+ * #IS_FLOAT = 1, 0#
*/
/**begin repeat1
@@ -2669,9 +2675,8 @@ AVX512F_@func@_@TYPE@(char **args, const npy_intp *dimensions, const npy_intp *s
/**end repeat1**/
/**begin repeat1
- * #func = absolute, square, conjugate#
- * #vectorf = avx512_cabsolute, avx512_csquare, avx512_conjugate#
- * #is_out_real = 1, 0, 0#
+ * #func = square, conjugate#
+ * #vectorf = avx512_csquare, avx512_conjugate#
*/
#if defined HAVE_ATTRIBUTE_TARGET_AVX512F_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
@@ -2695,19 +2700,12 @@ AVX512F_@func@_@TYPE@(@type@ * op,
}
@vindextype@ vindex = @vindexload@((@vindextype@*)index_ip1);
@mask@ load_mask = avx512_get_full_load_mask_@vsuffix@();
-#if @is_out_real@
- @mask@ store_mask = _mm512_kand(avx512_get_full_load_mask_@vsuffix@(), @storemask@);
-#endif
@vtype@ zeros = _mm512_setzero_@vsuffix@();
while (num_remaining_elements > 0) {
if (num_remaining_elements < @num_lanes@) {
load_mask = avx512_get_partial_load_mask_@vsuffix@(
num_remaining_elements, @num_lanes@);
-#if @is_out_real@
- store_mask = avx512_get_partial_load_mask_@vsuffix@(
- num_remaining_elements/2, @num_lanes@);
-#endif
}
@vtype@ x1;
if (stride_ip1 == 1) {
@@ -2719,27 +2717,85 @@ AVX512F_@func@_@TYPE@(@type@ * op,
@vtype@ out = @vectorf@_@vsuffix@(x1);
-#if @is_out_real@
- _mm512_mask_storeu_@vsuffix@(op, store_mask, out);
- op += @num_lanes@/2;
-#else
_mm512_mask_storeu_@vsuffix@(op, load_mask, out);
op += @num_lanes@;
-#endif
-
ip += @num_lanes@*stride_ip1;
num_remaining_elements -= @num_lanes@;
}
-#if @is_out_real@
+}
+#endif
+/**end repeat1**/
+
+#if defined HAVE_ATTRIBUTE_TARGET_AVX512F_WITH_INTRINSICS && defined NPY_HAVE_SSE2_INTRINSICS
+static NPY_GCC_OPT_3 NPY_INLINE NPY_GCC_TARGET_AVX512F void
+AVX512F_absolute_@TYPE@(@type@ * op,
+ @type@ * ip,
+ const npy_intp array_size,
+ const npy_intp steps)
+{
+ npy_intp num_remaining_elements = 2*array_size;
+ const npy_intp stride_ip1 = steps/(npy_intp)sizeof(@type@)/2;
+
/*
- * Ignore invalid exception for cabsolute generated by vmaxps/vmaxpd
- * and vminps/vminpd instructions
+ * Note: while generally indices are npy_intp, we ensure that our maximum index
+ * will fit in an int32 as a precondition for this function via max_stride
*/
- npy_clear_floatstatus_barrier((char*)op);
+ npy_int32 index_ip[32];
+ for (npy_int32 ii = 0; ii < 2*@num_lanes@; ii=ii+2) {
+ index_ip[ii] = ii*stride_ip1;
+ index_ip[ii+1] = ii*stride_ip1 + 1;
+ }
+ @vindextype@ vindex1 = @vindexload@((@vindextype@*)index_ip);
+ @vindextype@ vindex2 = @vindexload@((@vindextype@*)(index_ip+@num_lanes@));
+
+ @mask@ load_mask1 = avx512_get_full_load_mask_@vsuffix@();
+ @mask@ load_mask2 = avx512_get_full_load_mask_@vsuffix@();
+ @mask@ store_mask = avx512_get_full_load_mask_@vsuffix@();
+ @vtype@ zeros = _mm512_setzero_@vsuffix@();
+
+#if @IS_FLOAT@
+ __m512i re_index = _mm512_set_epi32(30,28,26,24,22,20,18,16,14,12,10,8,6,4,2,0);
+ __m512i im_index = _mm512_set_epi32(31,29,27,25,23,21,19,17,15,13,11,9,7,5,3,1);
+#else
+ __m512i re_index = _mm512_set_epi64(14,12,10,8,6,4,2,0);
+ __m512i im_index = _mm512_set_epi64(15,13,11,9,7,5,3,1);
#endif
+
+ while (num_remaining_elements > 0) {
+ if (num_remaining_elements < @num_lanes@) {
+ load_mask1 = avx512_get_partial_load_mask_@vsuffix@(
+ num_remaining_elements, @num_lanes@);
+ load_mask2 = 0x0000;
+ store_mask = avx512_get_partial_load_mask_@vsuffix@(
+ num_remaining_elements/2, @num_lanes@);
+ } else if (num_remaining_elements < 2*@num_lanes@) {
+ load_mask1 = avx512_get_full_load_mask_@vsuffix@();
+ load_mask2 = avx512_get_partial_load_mask_@vsuffix@(
+ num_remaining_elements - @num_lanes@, @num_lanes@);
+ store_mask = avx512_get_partial_load_mask_@vsuffix@(
+ num_remaining_elements/2, @num_lanes@);
+ }
+ @vtype@ x1, x2;
+ if (stride_ip1 == 1) {
+ x1 = avx512_masked_load_@vsuffix@(load_mask1, ip);
+ x2 = avx512_masked_load_@vsuffix@(load_mask2, ip+@num_lanes@);
+ }
+ else {
+ x1 = avx512_masked_gather_@vsuffix@(zeros, ip, vindex1, load_mask1);
+ x2 = avx512_masked_gather_@vsuffix@(zeros, ip, vindex2, load_mask2);
+ }
+
+ @vtype@ out = avx512_cabsolute_@vsuffix@(x1, x2, re_index, im_index);
+
+ _mm512_mask_storeu_@vsuffix@(op, store_mask, out);
+ op += @num_lanes@;
+ ip += 2*@num_lanes@*stride_ip1;
+ num_remaining_elements -= 2*@num_lanes@;
+ }
+ npy_clear_floatstatus_barrier((char*)op);
}
+
#endif
-/**end repeat1**/
/**end repeat**/
/*
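For readers following the simd.inc.src change: the rewritten absolute kernel now loads two full vectors of interleaved (real, imag) pairs, takes their elementwise absolute values, and uses the re_index/im_index permutes to de-interleave them into a vector of real magnitudes and a vector of imaginary magnitudes before combining, promoting any inf component so that inf + j*nan and nan + j*inf both yield inf. The snippet below is only a rough NumPy model of that data flow, not the committed code; the name cabs_model and the use of np.hypot for the final combine are illustrative assumptions.

```python
import numpy as np

def cabs_model(z):
    """Rough scalar model of the vectorized cabsolute kernel (illustration only)."""
    # View the complex input as an interleaved float buffer and take
    # elementwise absolute values: |re0|, |im0|, |re1|, |im1|, ...
    buf = np.abs(np.ascontiguousarray(z, dtype=np.complex128).view(np.float64))
    # De-interleave the real and imaginary magnitudes; in the kernel this is
    # the _mm512_permutex2var_* step driven by re_index/im_index.
    re, im = buf[0::2].copy(), buf[1::2].copy()
    # If either component is inf, treat the element as inf + j*inf so that
    # inf + j*nan and nan + j*inf both produce inf.
    has_inf = np.isinf(re) | np.isinf(im)
    re[has_inf] = np.inf
    im[has_inf] = np.inf
    return np.hypot(re, im)

z = np.array([3 + 4j, complex(np.inf, np.nan), complex(np.nan, np.inf)])
print(cabs_model(z))  # [ 5. inf inf]
print(np.abs(z))      # [ 5. inf inf]
```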
diff --git a/numpy/core/tests/test_umath_complex.py b/numpy/core/tests/test_umath_complex.py
index b4f2ebfde..a21158420 100644
--- a/numpy/core/tests/test_umath_complex.py
+++ b/numpy/core/tests/test_umath_complex.py
@@ -541,7 +541,7 @@ def check_complex_value(f, x1, y1, x2, y2, exact=True):
else:
assert_almost_equal(f(z1), z2)
-class TestComplexAVX(object):
+class TestSpecialComplexAVX(object):
@pytest.mark.parametrize("stride", [-4,-2,-1,1,2,4])
@pytest.mark.parametrize("astype", [np.complex64, np.complex128])
def test_array(self, stride, astype):
@@ -567,3 +567,13 @@ class TestComplexAVX(object):
assert_equal(np.abs(arr[::stride]), abs_true[::stride])
with np.errstate(invalid='ignore'):
assert_equal(np.square(arr[::stride]), sq_true[::stride])
+
+class TestComplexAbsoluteAVX(object):
+ @pytest.mark.parametrize("arraysize", [1,2,3,4,5,6,7,8,9,10,11,13,15,17,18,19])
+ @pytest.mark.parametrize("stride", [-4,-3,-2,-1,1,2,3,4])
+ @pytest.mark.parametrize("astype", [np.complex64, np.complex128])
+ # test to ensure masking and strides work as intended in the AVX implementation
+ def test_array(self, arraysize, stride, astype):
+ arr = np.ones(arraysize, dtype=astype)
+ abs_true = np.ones(arraysize, dtype=arr.real.dtype)
+ assert_equal(np.abs(arr[::stride]), abs_true[::stride])
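
The new TestComplexAbsoluteAVX case exercises the masking (partial-vector tails) and gather (strided) paths of that kernel. A standalone sketch in the same spirit is shown below; the sizes, strides, seed, and tolerance are illustrative choices, not taken from the commit.

```python
import numpy as np

# Illustrative only: check np.abs against a scalar reference across array
# sizes (partial-vector tails) and strides (gather path) in both precisions.
rng = np.random.default_rng(0)  # arbitrary seed, not from the test
for dtype in (np.complex64, np.complex128):
    for size in range(1, 20):
        arr = (rng.normal(size=size) + 1j * rng.normal(size=size)).astype(dtype)
        for stride in (-4, -3, -2, -1, 1, 2, 3, 4):
            ref = np.array([abs(complex(x)) for x in arr[::stride]],
                           dtype=arr.real.dtype)
            np.testing.assert_allclose(np.abs(arr[::stride]), ref, rtol=1e-6)
print("strided/masked np.abs matches the scalar reference")
```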