Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG: Fix integer overflow in in1d for mixed integer dtypes #22877 #22878

Merged
merged 9 commits into from Dec 25, 2022
20 changes: 18 additions & 2 deletions numpy/lib/arraysetops.py
Expand Up @@ -649,8 +649,24 @@ def in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):
ar2_range = int(ar2_max) - int(ar2_min)

# Constraints on whether we can actually use the table method:
range_safe_from_overflow = ar2_range < np.iinfo(ar2.dtype).max
# 1. Assert memory usage is not too large
below_memory_constraint = ar2_range <= 6 * (ar1.size + ar2.size)
# 2. Check overflows for (ar2 - ar2_min); dtype=ar2.dtype
range_safe_from_overflow = ar2_range <= np.iinfo(ar2.dtype).max
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This PR also corrects the bounds of the overflow check. It should have really been <= in #12065, rather than <. I noticed this after adding some new tests.

# 3. Check overflows for (ar1 - ar2_min); dtype=ar1.dtype
if ar1.size > 0:
ar1_min = np.min(ar1)
ar1_max = np.max(ar1)

# After masking, the range of ar1 is guaranteed to be
# within the range of ar2:
ar1_upper = min(int(ar1_max), int(ar2_max))
ar1_lower = max(int(ar1_min), int(ar2_min))

range_safe_from_overflow &= all((
ar1_upper - int(ar2_min) <= np.iinfo(ar1.dtype).max,
ar1_lower - int(ar2_min) >= np.iinfo(ar1.dtype).min
))

# Optimal performance is for approximately
# log10(size) > (log10(range) - 2.27) / 0.927.
Expand Down Expand Up @@ -687,7 +703,7 @@ def in1d(ar1, ar2, assume_unique=False, invert=False, *, kind=None):
elif kind == 'table': # not range_safe_from_overflow
raise RuntimeError(
"You have specified kind='table', "
"but the range of values in `ar2` exceeds the "
"but the range of values in `ar2` or `ar1` exceed the "
"maximum integer of the datatype. "
"Please set `kind` to None or 'sort'."
)
Expand Down
39 changes: 37 additions & 2 deletions numpy/lib/tests/test_arraysetops.py
Expand Up @@ -414,13 +414,48 @@ def test_in1d_table_timedelta_fails(self):
with pytest.raises(ValueError):
in1d(a, b, kind="table")

@pytest.mark.parametrize(
"dtype1,dtype2",
[
(np.int8, np.int16),
(np.int16, np.int8),
(np.uint8, np.uint16),
(np.uint16, np.uint8),
(np.uint8, np.int16),
(np.int16, np.uint8),
]
)
@pytest.mark.parametrize("kind", [None, "sort", "table"])
def test_in1d_mixed_dtype(self, dtype1, dtype2, kind):
"""Test that in1d works as expected for mixed dtype input."""
is_dtype2_signed = np.issubdtype(dtype2, np.signedinteger)
ar1 = np.array([0, 0, 1, 1], dtype=dtype1)

if is_dtype2_signed:
ar2 = np.array([-128, 0, 127], dtype=dtype2)
else:
ar2 = np.array([127, 0, 255], dtype=dtype2)

expected = np.array([True, True, False, False])

expect_failure = kind == "table" and any((
dtype1 == np.int8 and dtype2 == np.int16,
dtype1 == np.int16 and dtype2 == np.int8
))

if expect_failure:
with pytest.raises(RuntimeError, match="exceed the maximum"):
in1d(ar1, ar2, kind=kind)
else:
assert_array_equal(in1d(ar1, ar2, kind=kind), expected)

@pytest.mark.parametrize("kind", [None, "sort", "table"])
def test_in1d_mixed_boolean(self, kind):
"""Test that in1d works as expected for bool/int input."""
for dtype in np.typecodes["AllInteger"]:
a = np.array([True, False, False], dtype=bool)
b = np.array([1, 1, 1, 1], dtype=dtype)
expected = np.array([True, False, False], dtype=bool)
b = np.array([0, 0, 0, 0], dtype=dtype)
expected = np.array([False, True, True], dtype=bool)
assert_array_equal(in1d(a, b, kind=kind), expected)

a, b = b, a
Expand Down