Skip to content

Commit

Permalink
IntervalIndex
Browse files Browse the repository at this point in the history
  • Loading branch information
Terji Petersen authored and Terji Petersen committed Nov 20, 2022
1 parent 255b1a7 commit f81f0df
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 25 deletions.
18 changes: 16 additions & 2 deletions pandas/core/indexes/interval.py
Expand Up @@ -59,6 +59,8 @@
is_number,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
is_unsigned_integer_dtype,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.missing import is_valid_na_for_dtype
Expand Down Expand Up @@ -340,8 +342,8 @@ def from_tuples(
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
@cache_readonly
def _engine(self) -> IntervalTree: # type: ignore[override]
left = self._maybe_convert_i8(self.left)
right = self._maybe_convert_i8(self.right)
left = self._maybe_convert_to_64bit_if_numeric(self.left)
right = self._maybe_convert_to_64bit_if_numeric(self.right)
return IntervalTree(left, right, closed=self.closed)

def __contains__(self, key: Any) -> bool:
Expand Down Expand Up @@ -501,6 +503,18 @@ def _needs_i8_conversion(self, key) -> bool:
i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex)
return isinstance(key, i8_types)

def _maybe_convert_to_64bit_if_numeric(self, key):
    """
    Upcast a numeric key narrower than 64 bits to the 64-bit dtype of its kind.

    The key is first passed through ``self._maybe_convert_i8``. If the result
    has a signed integer, unsigned integer or float dtype that is not already
    ``int64`` / ``uint64`` / ``float64``, it is cast with ``astype`` to the
    64-bit dtype of the same kind. Anything else is returned as produced by
    ``self._maybe_convert_i8``.

    Parameters
    ----------
    key : object
        Value forwarded to ``self._maybe_convert_i8``.
        NOTE(review): the only caller visible here, ``_engine``, passes
        ``self.left`` and ``self.right``.

    Returns
    -------
    object
        The converted key, upcast to 64 bits when it is a narrower numeric
        dtype; otherwise the converted key unchanged.
    """
    key = self._maybe_convert_i8(key)

    # Objects without a ``dtype`` attribute (e.g. plain Python scalars) have
    # nothing to upcast; return them instead of raising AttributeError.
    dtype = getattr(key, "dtype", None)
    if dtype is None:
        return key

    if is_signed_integer_dtype(dtype) and dtype != "int64":
        return key.astype(np.int64)
    elif is_unsigned_integer_dtype(dtype) and dtype != "uint64":
        return key.astype(np.uint64)
    elif is_float_dtype(dtype) and dtype != "float64":
        return key.astype(np.float64)
    else:
        # Already 64-bit, or not a plain numeric dtype: leave untouched.
        return key

def _maybe_convert_i8(self, key):
"""
Maybe convert a given key to its equivalent i8 value(s). Used as a
Expand Down
71 changes: 48 additions & 23 deletions pandas/tests/indexes/interval/test_constructors.py
Expand Up @@ -18,12 +18,12 @@
timedelta_range,
)
import pandas._testing as tm
from pandas.api.types import is_unsigned_integer_dtype
from pandas.core.api import (
Float64Index,
Int64Index,
UInt64Index,
from pandas.api.types import (
is_float_dtype,
is_signed_integer_dtype,
is_unsigned_integer_dtype,
)
from pandas.core.api import NumericIndex
from pandas.core.arrays import IntervalArray
import pandas.core.common as com

Expand All @@ -50,9 +50,17 @@ def _skip_test_constructor(self, dtype):
[
[3, 14, 15, 92, 653],
np.arange(10, dtype="int64"),
Int64Index(range(-10, 11)),
UInt64Index(range(10, 31)),
Float64Index(np.arange(20, 30, 0.5)),
NumericIndex(range(-10, 11), dtype=np.int64),
NumericIndex(range(-10, 11), dtype=np.int32),
NumericIndex(range(-10, 11), dtype=np.int16),
NumericIndex(range(-10, 11), dtype=np.int8),
NumericIndex(range(10, 31), dtype=np.uint64),
NumericIndex(range(10, 31), dtype=np.uint32),
NumericIndex(range(10, 31), dtype=np.uint16),
NumericIndex(range(10, 31), dtype=np.uint8),
NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64),
NumericIndex(np.arange(20, 30, 0.5), dtype=np.float32),
NumericIndex(np.arange(20, 30, 0.5), dtype=np.float16),
date_range("20180101", periods=10),
date_range("20180101", periods=10, tz="US/Eastern"),
timedelta_range("1 day", periods=10),
Expand Down Expand Up @@ -81,10 +89,10 @@ def test_constructor(self, constructor, breaks, closed, name, use_dtype):
@pytest.mark.parametrize(
"breaks, subtype",
[
(Int64Index([0, 1, 2, 3, 4]), "float64"),
(Int64Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
(Int64Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
(Float64Index([0, 1, 2, 3, 4]), "int64"),
(Index([0, 1, 2, 3, 4]), "float64"),
(Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
(Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
(Index([0, 1, 2, 3, 4], dtype=np.float64), "int64"),
(date_range("2017-01-01", periods=5), "int64"),
(timedelta_range("1 day", periods=5), "int64"),
],
Expand All @@ -103,9 +111,18 @@ def test_constructor_dtype(self, constructor, breaks, subtype):
@pytest.mark.parametrize(
"breaks",
[
Int64Index([0, 1, 2, 3, 4]),
UInt64Index([0, 1, 2, 3, 4]),
Float64Index([0, 1, 2, 3, 4]),
NumericIndex([0, 1, 2, 3, 4], dtype=np.int64),
NumericIndex([0, 1, 2, 3, 4], dtype=np.int32),
NumericIndex([0, 1, 2, 3, 4], dtype=np.int32),
NumericIndex([0, 1, 2, 3, 4], dtype=np.int16),
NumericIndex([0, 1, 2, 3, 4], dtype=np.int8),
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint64),
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint32),
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint16),
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint8),
NumericIndex([0, 1, 2, 3, 4], dtype=np.float64),
NumericIndex([0, 1, 2, 3, 4], dtype=np.float32),
NumericIndex([0, 1, 2, 3, 4], dtype=np.float16),
date_range("2017-01-01", periods=5),
timedelta_range("1 day", periods=5),
],
Expand Down Expand Up @@ -262,8 +279,8 @@ def test_mixed_float_int(self, left_subtype, right_subtype):
right = np.arange(1, 10, dtype=right_subtype)
result = IntervalIndex.from_arrays(left, right)

expected_left = Float64Index(left)
expected_right = Float64Index(right)
expected_left = Index(left, dtype=np.float64)
expected_right = Index(right, dtype=np.float64)
expected_subtype = np.float64

tm.assert_index_equal(result.left, expected_left)
Expand Down Expand Up @@ -313,8 +330,13 @@ class TestFromTuples(ConstructorTests):
"""Tests specific to IntervalIndex.from_tuples"""

def _skip_test_constructor(self, dtype):
if is_unsigned_integer_dtype(dtype):
return True, "tuples don't have a dtype"
msg = f"tuples don't have a dtype, so constructor won't see dtype {dtype}"
if is_signed_integer_dtype(dtype) and dtype != "int64":
return True, msg
elif is_float_dtype(dtype) and dtype != "float64":
return True, msg
elif is_unsigned_integer_dtype(dtype):
return True, msg
else:
return False, ""

Expand Down Expand Up @@ -366,10 +388,13 @@ class TestClassConstructors(ConstructorTests):
"""Tests specific to the IntervalIndex/Index constructors"""

def _skip_test_constructor(self, dtype):
# get_kwargs_from_breaks in TestFromTuples and TestClassconstructors just return
# tuples of ints, so IntervalIndex can't know the original dtype
if is_unsigned_integer_dtype(dtype):
return True, "tuples don't have a dtype"
msg = f"tuples don't have a dtype, so constructor won't see dtype {dtype}"
if is_signed_integer_dtype(dtype) and dtype != "int64":
return True, msg
elif is_float_dtype(dtype) and dtype != "float64":
return True, msg
elif is_unsigned_integer_dtype(dtype):
return True, msg
else:
return False, ""

Expand Down

0 comments on commit f81f0df

Please sign in to comment.