From 235dbe1f9ea0955c0119f79a5c6614cd0268ef05 Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Sun, 25 Dec 2022 13:43:43 -0500 Subject: BUG: Fix integer overflow in in1d for mixed integer dtypes #22877 (#22878) * TST: Mixed integer types for in1d * BUG: Fix mixed dtype overflows for in1d (#22877) * BUG: Type conversion for integer overflow check * MAINT: Fix linting issues in in1d * MAINT: ar1 overflow check only for non-empty array * MAINT: Expand bounds of overflow check * TST: Fix integer overflow in mixed boolean test * TST: Include test for overflow on mixed dtypes * MAINT: Less conservative overflow checks --- numpy/lib/arraysetops.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) (limited to 'numpy/lib/arraysetops.py') diff --git a/numpy/lib/arraysetops.py b/numpy/lib/arraysetops.py index cf5f47a82..300bbda26 100644 --- a/numpy/lib/arraysetops.py +++ b/numpy/lib/arraysetops.py @@ -649,8 +649,24 @@ def in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None): ar2_range = int(ar2_max) - int(ar2_min) # Constraints on whether we can actually use the table method: - range_safe_from_overflow = ar2_range < np.iinfo(ar2.dtype).max + # 1. Assert memory usage is not too large below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size) + # 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype + range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max + # 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype + if ar1.size > 0: + ar1_min = np.min(ar1) + ar1_max = np.max(ar1) + + # After masking, the range of ar1 is guaranteed to be + # within the range of ar2: + ar1_upper = min(int(ar1_max), int(ar2_max)) + ar1_lower = max(int(ar1_min), int(ar2_min)) + + range_safe_from_overflow &= all(( + ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max, + ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min + )) # Optimal performance is for approximately # log10(size) > (log10(range) - 2.27) / 0.927. @@ -687,7 +703,7 @@ def in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None): elif kind == 'table': # not range_safe_from_overflow raise RuntimeError( "You have specified kind='table', " - "but the range of values in `ar2` exceeds the " + "but the range of values in `ar2` or `ar1` exceed the " "maximum integer of the datatype. " "Please set `kind` to None or 'sort'." ) -- cgit v1.2.1