diff options
author | AngelGris <lucianogarciabes@gmail.com> | 2021-02-08 21:55:01 +0100 |
---|---|---|
committer | AngelGris <lucianogarciabes@gmail.com> | 2021-02-08 21:55:01 +0100 |
commit | 8fbd472e562237dd56ce251e266e2090d6c5003b (patch) | |
tree | 3bc47cb21fffea5f36f3c2f38818a65dad98f007 /numpy/lib/arraysetops.py | |
parent | 5936386ff9b5674385ccee9154e320c686e4a28e (diff) | |
download | numpy-8fbd472e562237dd56ce251e266e2090d6c5003b.tar.gz |
Implement different approach to fix bug
Diffstat (limited to 'numpy/lib/arraysetops.py')
-rw-r--r-- | numpy/lib/arraysetops.py | 28 |
1 files changed, 8 insertions, 20 deletions
diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index e7f9add20..eb5c488e4 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -565,6 +565,10 @@ def in1d(ar1, ar2, assume_unique=False, invert=False): ar1 = np.asarray(ar1).ravel() ar2 = np.asarray(ar2).ravel() + # Ensure that iteration through object arrays yields size-1 arrays + if ar2.dtype == object: + ar2 = ar2.reshape(-1, 1) + # Check if one of the arrays may contain arbitrary objects contains_object = ar1.dtype.hasobject or ar2.dtype.hasobject @@ -575,28 +579,12 @@ def in1d(ar1, ar2, assume_unique=False, invert=False): if len(ar2) < 10 * len(ar1) ** 0.145 or contains_object: if invert: mask = np.ones(len(ar1), dtype=bool) - # If ar2.dtype is object, store is used to wrap the a value - # in an array to prevent tuples from being unpacked before the comparison - if ar2.dtype == object: - store = np.empty(shape=1, dtype=object) - for a in ar2: - store[0] = a - mask &= (ar1 != store) - else: - for a in ar2: - mask &= (ar1 != a) + for a in ar2: + mask &= (ar1 != a) else: mask = np.zeros(len(ar1), dtype=bool) - # If ar2.dtype is object, store is used to wrap the a value - # in an array to prevent tuples from being unpacked before the comparison - if ar2.dtype == object: - store = np.empty(shape=1, dtype=object) - for a in ar2: - store[0] = a - mask |= (ar1 == store) - else: - for a in ar2: - mask |= (ar1 == a) + for a in ar2: + mask |= (ar1 == a) return mask # Otherwise use sorting |