summaryrefslogtreecommitdiff
path: root/numpy/lib/arraysetops.py
diff options
context:
space:
mode:
authorAngelGris <lucianogarciabes@gmail.com>2021-02-08 18:38:06 +0100
committerAngelGris <lucianogarciabes@gmail.com>2021-02-08 18:38:06 +0100
commit5936386ff9b5674385ccee9154e320c686e4a28e (patch)
treeb816d2e7935058fe8fafb4398253cc234e8e80e7 /numpy/lib/arraysetops.py
parent54eed9828ff0d5a78c5761c7d212bbc8e4cc1a09 (diff)
downloadnumpy-5936386ff9b5674385ccee9154e320c686e4a28e.tar.gz
BUG: np.in1d bug on the object array (issue 17923)
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r--numpy/lib/arraysetops.py24
1 files changed, 20 insertions, 4 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index 6c6c1ff80..e7f9add20 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -575,12 +575,28 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:
if invert:
mask = np.ones(len(ar1), dtype=bool)
- for a in ar2:
- mask &= (ar1 != a)
+ # If ar2.dtype is object, store is used to wrap the a value
+ # in an array to prevent tuples from being unpacked before the comparison
+ if ar2.dtype == object:
+ store = np.empty(shape=1, dtype=object)
+ for a in ar2:
+ store[0] = a
+ mask &= (ar1 != store)
+ else:
+ for a in ar2:
+ mask &= (ar1 != a)
else:
mask = np.zeros(len(ar1), dtype=bool)
- for a in ar2:
- mask |= (ar1 == a)
+ # If ar2.dtype is object, store is used to wrap the a value
+ # in an array to prevent tuples from being unpacked before the comparison
+ if ar2.dtype == object:
+ store = np.empty(shape=1, dtype=object)
+ for a in ar2:
+ store[0] = a
+ mask |= (ar1 == store)
+ else:
+ for a in ar2:
+ mask |= (ar1 == a)
return mask
# Otherwise use sorting