diff options
author | Sayed Adel <seiko@imavr.com> | 2022-12-21 11:56:20 +0200 |
---|---|---|
committer | Sayed Adel <seiko@imavr.com> | 2022-12-21 12:05:22 +0200 |
commit | 2862a1c63244809993ddb415c4f858ff699ba9c5 (patch) | |
tree | 03c902977f5bacbbe153a560f36eb395534e758a | |
parent | 0d1bb8e42228776dc8d35bdcfacf2ff3af366ade (diff) | |
download | numpy-2862a1c63244809993ddb415c4f858ff699ba9c5.tar.gz |
BUG, SIMD: Fix memory overlap on comparison accumulate
-rw-r--r-- | numpy/core/src/umath/loops_comparison.dispatch.c.src | 30 | ||||
-rw-r--r-- | numpy/core/tests/test_umath.py | 40 |
2 files changed, 57 insertions, 13 deletions
diff --git a/numpy/core/src/umath/loops_comparison.dispatch.c.src b/numpy/core/src/umath/loops_comparison.dispatch.c.src index f61ef4113..751080871 100644 --- a/numpy/core/src/umath/loops_comparison.dispatch.c.src +++ b/numpy/core/src/umath/loops_comparison.dispatch.c.src @@ -312,19 +312,23 @@ static inline void run_binary_simd_@kind@_@sfx@(char **args, npy_intp const *dimensions, npy_intp const *steps) { #if @VECTOR@ - /* argument one scalar */ - if (IS_BLOCKABLE_BINARY_SCALAR1_BOOL(sizeof(@type@), NPY_SIMD_WIDTH)) { - simd_binary_scalar1_@kind@_@sfx@(args, dimensions[0]); - return; - } - /* argument two scalar */ - else if (IS_BLOCKABLE_BINARY_SCALAR2_BOOL(sizeof(@type@), NPY_SIMD_WIDTH)) { - simd_binary_scalar2_@kind@_@sfx@(args, dimensions[0]); - return; - } - else if (IS_BLOCKABLE_BINARY_BOOL(sizeof(@type@), NPY_SIMD_WIDTH)) { - simd_binary_@kind@_@sfx@(args, dimensions[0]); - return; + if (!is_mem_overlap(args[0], steps[0], args[2], steps[2], dimensions[0]) && + !is_mem_overlap(args[1], steps[1], args[2], steps[2], dimensions[0]) + ) { + /* argument one scalar */ + if (IS_BINARY_CONT_S1(@type@, npy_bool)) { + simd_binary_scalar1_@kind@_@sfx@(args, dimensions[0]); + return; + } + /* argument two scalar */ + else if (IS_BINARY_CONT_S2(@type@, npy_bool)) { + simd_binary_scalar2_@kind@_@sfx@(args, dimensions[0]); + return; + } + else if (IS_BINARY_CONT(@type@, npy_bool)) { + simd_binary_@kind@_@sfx@(args, dimensions[0]); + return; + } } #endif diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py index 3364eb7e6..26e4e39af 100644 --- a/numpy/core/tests/test_umath.py +++ b/numpy/core/tests/test_umath.py @@ -31,6 +31,13 @@ UFUNCS_UNARY_FP = [ uf for uf in UFUNCS_UNARY if 'f->f' in uf.types ] +UFUNCS_BINARY = [ + uf for uf in UFUNCS if uf.nin == 2 +] +UFUNCS_BINARY_ACC = [ + uf for uf in UFUNCS_BINARY if hasattr(uf, "accumulate") and uf.nout == 1 +] + def interesting_binop_operands(val1, val2, dtype): """ Helper to create "interesting" operands to cover common code paths: @@ -4329,6 +4336,7 @@ def test_rint_big_int(): # Rint should not change the value assert_equal(val, np.rint(val)) + @pytest.mark.parametrize('ftype', [np.float32, np.float64]) def test_memoverlap_accumulate(ftype): # Reproduces bug https://github.com/numpy/numpy/issues/15597 @@ -4338,6 +4346,38 @@ def test_memoverlap_accumulate(ftype): assert_equal(np.maximum.accumulate(arr), out_max) assert_equal(np.minimum.accumulate(arr), out_min) +@pytest.mark.parametrize("ufunc, dtype", [ + (ufunc, t[0]) + for ufunc in UFUNCS_BINARY_ACC + for t in ufunc.types + if t[-1] == '?' and t[0] not in 'DFGMmO' +]) +def test_memoverlap_accumulate_cmp(ufunc, dtype): + if ufunc.signature: + pytest.skip('For generic signatures only') + for size in (2, 8, 32, 64, 128, 256): + arr = np.array([0, 1, 1]*size, dtype=dtype) + acc = ufunc.accumulate(arr, dtype='?') + acc_u8 = acc.view(np.uint8) + exp = np.array(list(itertools.accumulate(arr, ufunc)), dtype=np.uint8) + assert_equal(exp, acc_u8) + +@pytest.mark.parametrize("ufunc, dtype", [ + (ufunc, t[0]) + for ufunc in UFUNCS_BINARY_ACC + for t in ufunc.types + if t[0] == t[1] and t[0] == t[-1] and t[0] not in 'DFGMmO?' +]) +def test_memoverlap_accumulate_symmetric(ufunc, dtype): + if ufunc.signature: + pytest.skip('For generic signatures only') + with np.errstate(all='ignore'): + for size in (2, 8, 32, 64, 128, 256): + arr = np.array([0, 1, 2]*size).astype(dtype) + acc = ufunc.accumulate(arr, dtype=dtype) + exp = np.array(list(itertools.accumulate(arr, ufunc)), dtype=dtype) + assert_equal(exp, acc) + def test_signaling_nan_exceptions(): with assert_no_warnings(): a = np.ndarray(shape=(), dtype='float32', buffer=b'\x00\xe0\xbf\xff') |