diff --git a/pandas/core/indexes/interval.py b/pandas/core/indexes/interval.py index ae2d054fe94fed..2d14b98a6f7797 100644 --- a/pandas/core/indexes/interval.py +++ b/pandas/core/indexes/interval.py @@ -59,6 +59,8 @@ is_number, is_object_dtype, is_scalar, + is_signed_integer_dtype, + is_unsigned_integer_dtype, ) from pandas.core.dtypes.dtypes import IntervalDtype from pandas.core.dtypes.missing import is_valid_na_for_dtype @@ -340,8 +342,8 @@ def from_tuples( # "Union[IndexEngine, ExtensionEngine]" in supertype "Index" @cache_readonly def _engine(self) -> IntervalTree: # type: ignore[override] - left = self._maybe_convert_i8(self.left) - right = self._maybe_convert_i8(self.right) + left = self._maybe_convert_to_64bit_if_numeric(self.left) + right = self._maybe_convert_to_64bit_if_numeric(self.right) return IntervalTree(left, right, closed=self.closed) def __contains__(self, key: Any) -> bool: @@ -501,6 +503,18 @@ def _needs_i8_conversion(self, key) -> bool: i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex) return isinstance(key, i8_types) + def _maybe_convert_to_64bit_if_numeric(self, key): + key = self._maybe_convert_i8(key) + dtype = key.dtype + if is_signed_integer_dtype(dtype) and dtype != "int64": + return key.astype(np.int64) + elif is_unsigned_integer_dtype(dtype) and dtype != "uint64": + return key.astype(np.uint64) + elif is_float_dtype(dtype) and dtype != "float64": + return key.astype(np.float64) + else: + return key + def _maybe_convert_i8(self, key): """ Maybe convert a given key to its equivalent i8 value(s). Used as a diff --git a/pandas/tests/indexes/interval/test_constructors.py b/pandas/tests/indexes/interval/test_constructors.py index 1c8697e96c2e99..e6c7f664624173 100644 --- a/pandas/tests/indexes/interval/test_constructors.py +++ b/pandas/tests/indexes/interval/test_constructors.py @@ -18,12 +18,12 @@ timedelta_range, ) import pandas._testing as tm -from pandas.api.types import is_unsigned_integer_dtype -from pandas.core.api import ( - Float64Index, - Int64Index, - UInt64Index, +from pandas.api.types import ( + is_float_dtype, + is_signed_integer_dtype, + is_unsigned_integer_dtype, ) +from pandas.core.api import NumericIndex from pandas.core.arrays import IntervalArray import pandas.core.common as com @@ -50,9 +50,17 @@ def _skip_test_constructor(self, dtype): [ [3, 14, 15, 92, 653], np.arange(10, dtype="int64"), - Int64Index(range(-10, 11)), - UInt64Index(range(10, 31)), - Float64Index(np.arange(20, 30, 0.5)), + NumericIndex(range(-10, 11), dtype=np.int64), + NumericIndex(range(-10, 11), dtype=np.int32), + NumericIndex(range(-10, 11), dtype=np.int16), + NumericIndex(range(-10, 11), dtype=np.int8), + NumericIndex(range(10, 31), dtype=np.uint64), + NumericIndex(range(10, 31), dtype=np.uint32), + NumericIndex(range(10, 31), dtype=np.uint16), + NumericIndex(range(10, 31), dtype=np.uint8), + NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64), + NumericIndex(np.arange(20, 30, 0.5), dtype=np.float32), + NumericIndex(np.arange(20, 30, 0.5), dtype=np.float16), date_range("20180101", periods=10), date_range("20180101", periods=10, tz="US/Eastern"), timedelta_range("1 day", periods=10), @@ -81,10 +89,10 @@ def test_constructor(self, constructor, breaks, closed, name, use_dtype): @pytest.mark.parametrize( "breaks, subtype", [ - (Int64Index([0, 1, 2, 3, 4]), "float64"), - (Int64Index([0, 1, 2, 3, 4]), "datetime64[ns]"), - (Int64Index([0, 1, 2, 3, 4]), "timedelta64[ns]"), - (Float64Index([0, 1, 2, 3, 4]), "int64"), + (Index([0, 1, 2, 3, 4]), "float64"), + (Index([0, 1, 2, 3, 4]), "datetime64[ns]"), + (Index([0, 1, 2, 3, 4]), "timedelta64[ns]"), + (Index([0, 1, 2, 3, 4], dtype=np.float64), "int64"), (date_range("2017-01-01", periods=5), "int64"), (timedelta_range("1 day", periods=5), "int64"), ], @@ -103,9 +111,18 @@ def test_constructor_dtype(self, constructor, breaks, subtype): @pytest.mark.parametrize( "breaks", [ - Int64Index([0, 1, 2, 3, 4]), - UInt64Index([0, 1, 2, 3, 4]), - Float64Index([0, 1, 2, 3, 4]), + NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), + NumericIndex([0, 1, 2, 3, 4], dtype=np.int32), + NumericIndex([0, 1, 2, 3, 4], dtype=np.int32), + NumericIndex([0, 1, 2, 3, 4], dtype=np.int16), + NumericIndex([0, 1, 2, 3, 4], dtype=np.int8), + NumericIndex([0, 1, 2, 3, 4], dtype=np.uint64), + NumericIndex([0, 1, 2, 3, 4], dtype=np.uint32), + NumericIndex([0, 1, 2, 3, 4], dtype=np.uint16), + NumericIndex([0, 1, 2, 3, 4], dtype=np.uint8), + NumericIndex([0, 1, 2, 3, 4], dtype=np.float64), + NumericIndex([0, 1, 2, 3, 4], dtype=np.float32), + NumericIndex([0, 1, 2, 3, 4], dtype=np.float16), date_range("2017-01-01", periods=5), timedelta_range("1 day", periods=5), ], @@ -262,8 +279,8 @@ def test_mixed_float_int(self, left_subtype, right_subtype): right = np.arange(1, 10, dtype=right_subtype) result = IntervalIndex.from_arrays(left, right) - expected_left = Float64Index(left) - expected_right = Float64Index(right) + expected_left = Index(left, dtype=np.float64) + expected_right = Index(right, dtype=np.float64) expected_subtype = np.float64 tm.assert_index_equal(result.left, expected_left) @@ -313,8 +330,13 @@ class TestFromTuples(ConstructorTests): """Tests specific to IntervalIndex.from_tuples""" def _skip_test_constructor(self, dtype): - if is_unsigned_integer_dtype(dtype): - return True, "tuples don't have a dtype" + msg = f"tuples don't have a dtype, so constructor won't see dtype {dtype}" + if is_signed_integer_dtype(dtype) and dtype != "int64": + return True, msg + elif is_float_dtype(dtype) and dtype != "float64": + return True, msg + elif is_unsigned_integer_dtype(dtype): + return True, msg else: return False, "" @@ -366,10 +388,13 @@ class TestClassConstructors(ConstructorTests): """Tests specific to the IntervalIndex/Index constructors""" def _skip_test_constructor(self, dtype): - # get_kwargs_from_breaks in TestFromTuples and TestClassconstructors just return - # tuples of ints, so IntervalIndex can't know the original dtype - if is_unsigned_integer_dtype(dtype): - return True, "tuples don't have a dtype" + msg = f"tuples don't have a dtype, so constructor won't see dtype {dtype}" + if is_signed_integer_dtype(dtype) and dtype != "int64": + return True, msg + elif is_float_dtype(dtype) and dtype != "float64": + return True, msg + elif is_unsigned_integer_dtype(dtype): + return True, msg else: return False, ""