summaryrefslogtreecommitdiff
path: root/numpy/lib/arraysetops.py
diff options
context:
space:
mode:
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r--numpy/lib/arraysetops.py28
1 files changed, 8 insertions, 20 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py
index e7f9add20..eb5c488e4 100644
--- a/numpy/lib/arraysetops.py
+++ b/numpy/lib/arraysetops.py
@@ -565,6 +565,10 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
ar1 = np.asarray(ar1).ravel()
ar2 = np.asarray(ar2).ravel()
+ # Ensure that iteration through object arrays yields size-1 arrays
+ if ar2.dtype == object:
+ ar2 = ar2.reshape(-1, 1)
+
# Check if one of the arrays may contain arbitrary objects
contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject
@@ -575,28 +579,12 @@ def in1d(ar1, ar2, assume_unique=False, invert=False):
if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object:
if invert:
mask = np.ones(len(ar1), dtype=bool)
- # If ar2.dtype is object, store is used to wrap the a value
- # in an array to prevent tuples from being unpacked before the comparison
- if ar2.dtype == object:
- store = np.empty(shape=1, dtype=object)
- for a in ar2:
- store[0] = a
- mask &= (ar1 != store)
- else:
- for a in ar2:
- mask &= (ar1 != a)
+ for a in ar2:
+ mask &= (ar1 != a)
else:
mask = np.zeros(len(ar1), dtype=bool)
- # If ar2.dtype is object, store is used to wrap the a value
- # in an array to prevent tuples from being unpacked before the comparison
- if ar2.dtype == object:
- store = np.empty(shape=1, dtype=object)
- for a in ar2:
- store[0] = a
- mask |= (ar1 == store)
- else:
- for a in ar2:
- mask |= (ar1 == a)
+ for a in ar2:
+ mask |= (ar1 == a)
return mask
# Otherwise use sorting