summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Sayed Adel <seiko@imavr.com> 2022-12-21 11:56:20 +0200
committer: Sayed Adel <seiko@imavr.com> 2022-12-21 12:05:22 +0200
commit: 2862a1c63244809993ddb415c4f858ff699ba9c5 (patch)
tree: 03c902977f5bacbbe153a560f36eb395534e758a
parent: 0d1bb8e42228776dc8d35bdcfacf2ff3af366ade (diff)
download: numpy-2862a1c63244809993ddb415c4f858ff699ba9c5.tar.gz
BUG, SIMD: Fix memory overlap on comparison accumulate
-rw-r--r-- numpy/core/src/umath/loops_comparison.dispatch.c.src 30
-rw-r--r-- numpy/core/tests/test_umath.py 40
2 files changed, 57 insertions, 13 deletions
diff --git a/numpy/core/src/umath/loops_comparison.dispatch.c.src b/numpy/core/src/umath/loops_comparison.dispatch.c.src
index f61ef4113..751080871 100644
--- a/numpy/core/src/umath/loops_comparison.dispatch.c.src
+++ b/numpy/core/src/umath/loops_comparison.dispatch.c.src
@@ -312,19 +312,23 @@ static inline void
run_binary_simd_@kind@_@sfx@(char **args, npy_intp const *dimensions, npy_intp const *steps)
{
#if @VECTOR@
- /* argument one scalar */
- if (IS_BLOCKABLE_BINARY_SCALAR1_BOOL(sizeof(@type@), NPY_SIMD_WIDTH)) {
- simd_binary_scalar1_@kind@_@sfx@(args, dimensions[0]);
- return;
- }
- /* argument two scalar */
- else if (IS_BLOCKABLE_BINARY_SCALAR2_BOOL(sizeof(@type@), NPY_SIMD_WIDTH)) {
- simd_binary_scalar2_@kind@_@sfx@(args, dimensions[0]);
- return;
- }
- else if (IS_BLOCKABLE_BINARY_BOOL(sizeof(@type@), NPY_SIMD_WIDTH)) {
- simd_binary_@kind@_@sfx@(args, dimensions[0]);
- return;
+ if (!is_mem_overlap(args[0], steps[0], args[2], steps[2], dimensions[0]) &&
+ !is_mem_overlap(args[1], steps[1], args[2], steps[2], dimensions[0])
+ ) {
+ /* argument one scalar */
+ if (IS_BINARY_CONT_S1(@type@, npy_bool)) {
+ simd_binary_scalar1_@kind@_@sfx@(args, dimensions[0]);
+ return;
+ }
+ /* argument two scalar */
+ else if (IS_BINARY_CONT_S2(@type@, npy_bool)) {
+ simd_binary_scalar2_@kind@_@sfx@(args, dimensions[0]);
+ return;
+ }
+ else if (IS_BINARY_CONT(@type@, npy_bool)) {
+ simd_binary_@kind@_@sfx@(args, dimensions[0]);
+ return;
+ }
}
#endif
diff --git a/numpy/core/tests/test_umath.py b/numpy/core/tests/test_umath.py
index 3364eb7e6..26e4e39af 100644
--- a/numpy/core/tests/test_umath.py
+++ b/numpy/core/tests/test_umath.py
@@ -31,6 +31,13 @@ UFUNCS_UNARY_FP = [
uf for uf in UFUNCS_UNARY if 'f->f' in uf.types
]
+UFUNCS_BINARY = [
+ uf for uf in UFUNCS if uf.nin == 2
+]
+UFUNCS_BINARY_ACC = [
+ uf for uf in UFUNCS_BINARY if hasattr(uf, "accumulate") and uf.nout == 1
+]
+
def interesting_binop_operands(val1, val2, dtype):
"""
Helper to create "interesting" operands to cover common code paths:
@@ -4329,6 +4336,7 @@ def test_rint_big_int():
# Rint should not change the value
assert_equal(val, np.rint(val))
+
@pytest.mark.parametrize('ftype', [np.float32, np.float64])
def test_memoverlap_accumulate(ftype):
# Reproduces bug https://github.com/numpy/numpy/issues/15597
@@ -4338,6 +4346,38 @@ def test_memoverlap_accumulate(ftype):
assert_equal(np.maximum.accumulate(arr), out_max)
assert_equal(np.minimum.accumulate(arr), out_min)
+@pytest.mark.parametrize("ufunc, dtype", [
+ (ufunc, t[0])
+ for ufunc in UFUNCS_BINARY_ACC
+ for t in ufunc.types
+ if t[-1] == '?' and t[0] not in 'DFGMmO'
+])
+def test_memoverlap_accumulate_cmp(ufunc, dtype):
+ if ufunc.signature:
+ pytest.skip('For generic signatures only')
+ for size in (2, 8, 32, 64, 128, 256):
+ arr = np.array([0, 1, 1]*size, dtype=dtype)
+ acc = ufunc.accumulate(arr, dtype='?')
+ acc_u8 = acc.view(np.uint8)
+ exp = np.array(list(itertools.accumulate(arr, ufunc)), dtype=np.uint8)
+ assert_equal(exp, acc_u8)
+
+@pytest.mark.parametrize("ufunc, dtype", [
+ (ufunc, t[0])
+ for ufunc in UFUNCS_BINARY_ACC
+ for t in ufunc.types
+ if t[0] == t[1] and t[0] == t[-1] and t[0] not in 'DFGMmO?'
+])
+def test_memoverlap_accumulate_symmetric(ufunc, dtype):
+ if ufunc.signature:
+ pytest.skip('For generic signatures only')
+ with np.errstate(all='ignore'):
+ for size in (2, 8, 32, 64, 128, 256):
+ arr = np.array([0, 1, 2]*size).astype(dtype)
+ acc = ufunc.accumulate(arr, dtype=dtype)
+ exp = np.array(list(itertools.accumulate(arr, ufunc)), dtype=dtype)
+ assert_equal(exp, acc)
+
def test_signaling_nan_exceptions():
with assert_no_warnings():
a = np.ndarray(shape=(), dtype='float32', buffer=b'\x00\xe0\xbf\xff')