summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDeveloper-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com>2022-12-21 11:30:48 -0800
committerDeveloper-Ecosystem-Engineering <65677710+Developer-Ecosystem-Engineering@users.noreply.github.com>2022-12-21 11:30:48 -0800
commitf2fa2e533b0fee50aee2db8cfe2906fd072774c7 (patch)
tree5dec4905416c417b644370ebafdbb2525c07a8ec
parent2f64274d360b25d509636897d5d39157080a42cc (diff)
downloadnumpy-f2fa2e533b0fee50aee2db8cfe2906fd072774c7.tar.gz
BUG, SIMD: Restore behavior converting non bool input to 0x00/0xff
This resolves https://github.com/numpy/numpy/issues/22845 by restoring prior behavior to convert non bool input
-rw-r--r--numpy/core/src/umath/loops_logical.dispatch.c.src12
1 files changed, 12 insertions, 0 deletions
diff --git a/numpy/core/src/umath/loops_logical.dispatch.c.src b/numpy/core/src/umath/loops_logical.dispatch.c.src
index 793a2af19..4a021d50b 100644
--- a/numpy/core/src/umath/loops_logical.dispatch.c.src
+++ b/numpy/core/src/umath/loops_logical.dispatch.c.src
@@ -63,6 +63,10 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
const int vstep = npyv_nlanes_u8;
const int wstep = vstep * UNROLL;
+#if @and@
+ const npyv_u8 zero = npyv_zero_u8();
+#endif
+
// Unrolled vectors loop
for (; len >= wstep; len -= wstep, ip1 += wstep, ip2 += wstep, op += wstep) {
/**begin repeat1
@@ -71,6 +75,10 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
#if UNROLL > @unroll@
npyv_u8 a@unroll@ = npyv_load_u8(ip1 + vstep * @unroll@);
npyv_u8 b@unroll@ = npyv_load_u8(ip2 + vstep * @unroll@);
+#if @and@
+ // a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly.
+ a@unroll@ = npyv_cvt_u8_b8(npyv_cmpgt_u8(a@unroll@, zero));
+#endif
npyv_u8 r@unroll@ = npyv_@intrin@_u8(a@unroll@, b@unroll@);
npyv_store_u8(op + vstep * @unroll@, byte_to_true(r@unroll@));
#endif
@@ -82,6 +90,10 @@ simd_binary_@kind@_BOOL(npy_bool * op, npy_bool * ip1, npy_bool * ip2, npy_intp
for (; len >= vstep; len -= vstep, ip1 += vstep, ip2 += vstep, op += vstep) {
npyv_u8 a = npyv_load_u8(ip1);
npyv_u8 b = npyv_load_u8(ip2);
+#if @and@
+ // a = 0x00/0xff if any bit is set. ensures non-bool inputs are handled properly.
+ a = npyv_cvt_u8_b8(npyv_cmpgt_u8(a, zero));
+#endif
npyv_u8 r = npyv_@intrin@_u8(a, b);
npyv_store_u8(op, byte_to_true(r));
}