Skip to content

Commit

Permalink
IntervalIndex tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Terji Petersen authored and Terji Petersen committed Nov 20, 2022
1 parent 1ce1e8a commit 25c71c7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 22 deletions.
18 changes: 16 additions & 2 deletions pandas/core/indexes/interval.py
Expand Up @@ -59,6 +59,8 @@
is_number,
is_object_dtype,
is_scalar,
is_signed_integer_dtype,
is_unsigned_integer_dtype,
)
from pandas.core.dtypes.dtypes import IntervalDtype
from pandas.core.dtypes.missing import is_valid_na_for_dtype
Expand Down Expand Up @@ -340,8 +342,8 @@ def from_tuples(
# "Union[IndexEngine, ExtensionEngine]" in supertype "Index"
@cache_readonly
def _engine(self) -> IntervalTree: # type: ignore[override]
left = self._maybe_convert_i8(self.left)
right = self._maybe_convert_i8(self.right)
left = self._maybe_convert_to_64bit_if_numeric(self.left)
right = self._maybe_convert_to_64bit_if_numeric(self.right)
return IntervalTree(left, right, closed=self.closed)

def __contains__(self, key: Any) -> bool:
Expand Down Expand Up @@ -501,6 +503,18 @@ def _needs_i8_conversion(self, key) -> bool:
i8_types = (Timestamp, Timedelta, DatetimeIndex, TimedeltaIndex)
return isinstance(key, i8_types)

def _maybe_convert_to_64bit_if_numeric(self, key):
key = self._maybe_convert_i8(key)
dtype = key.dtype
if is_signed_integer_dtype(dtype) and dtype != "int64":
return key.astype(np.int64)
elif is_unsigned_integer_dtype(dtype) and dtype != "uint64":
return key.astype(np.uint64)
elif is_float_dtype(dtype) and dtype != "float64":
return key.astype(np.float64)
else:
return key

def _maybe_convert_i8(self, key):
"""
Maybe convert a given key to its equivalent i8 value(s). Used as a
Expand Down
36 changes: 16 additions & 20 deletions pandas/tests/indexes/interval/test_constructors.py
Expand Up @@ -19,11 +19,7 @@
)
import pandas._testing as tm
from pandas.api.types import is_unsigned_integer_dtype
from pandas.core.api import (
Float64Index,
Int64Index,
UInt64Index,
)
from pandas.core.api import NumericIndex
from pandas.core.arrays import IntervalArray
import pandas.core.common as com

Expand All @@ -44,9 +40,9 @@ class ConstructorTests:
params=[
([3, 14, 15, 92, 653], np.int64),
(np.arange(10, dtype="int64"), np.int64),
(Int64Index(range(-10, 11)), np.int64),
(UInt64Index(range(10, 31)), np.uint64),
(Float64Index(np.arange(20, 30, 0.5)), np.float64),
(NumericIndex(range(-10, 11), dtype=np.int64), np.int64),
(NumericIndex(range(10, 31), dtype=np.uint64), np.uint64),
(NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64), np.float64),
(date_range("20180101", periods=10), "<M8[ns]"),
(
date_range("20180101", periods=10, tz="US/Eastern"),
Expand Down Expand Up @@ -74,10 +70,10 @@ def test_constructor(self, constructor, breaks_and_expected_subtype, closed, nam
@pytest.mark.parametrize(
"breaks, subtype",
[
(Int64Index([0, 1, 2, 3, 4]), "float64"),
(Int64Index([0, 1, 2, 3, 4]), "datetime64[ns]"),
(Int64Index([0, 1, 2, 3, 4]), "timedelta64[ns]"),
(Float64Index([0, 1, 2, 3, 4]), "int64"),
(NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), "float64"),
(NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), "datetime64[ns]"),
(NumericIndex([0, 1, 2, 3, 4], dtype=np.int64), "timedelta64[ns]"),
(NumericIndex([0, 1, 2, 3, 4], dtype=np.float64), "int64"),
(date_range("2017-01-01", periods=5), "int64"),
(timedelta_range("1 day", periods=5), "int64"),
],
Expand All @@ -96,9 +92,9 @@ def test_constructor_dtype(self, constructor, breaks, subtype):
@pytest.mark.parametrize(
"breaks",
[
Int64Index([0, 1, 2, 3, 4]),
UInt64Index([0, 1, 2, 3, 4]),
Float64Index([0, 1, 2, 3, 4]),
NumericIndex([0, 1, 2, 3, 4], dtype=np.int64),
NumericIndex([0, 1, 2, 3, 4], dtype=np.uint64),
NumericIndex([0, 1, 2, 3, 4], dtype=np.float64),
date_range("2017-01-01", periods=5),
timedelta_range("1 day", periods=5),
],
Expand Down Expand Up @@ -255,8 +251,8 @@ def test_mixed_float_int(self, left_subtype, right_subtype):
right = np.arange(1, 10, dtype=right_subtype)
result = IntervalIndex.from_arrays(left, right)

expected_left = Float64Index(left)
expected_right = Float64Index(right)
expected_left = NumericIndex(left, dtype=np.float64)
expected_right = NumericIndex(right, dtype=np.float64)
expected_subtype = np.float64

tm.assert_index_equal(result.left, expected_left)
Expand Down Expand Up @@ -307,9 +303,9 @@ class TuplesClassConstructorTests(ConstructorTests):
params=[
([3, 14, 15, 92, 653], np.int64),
(np.arange(10, dtype="int64"), np.int64),
(Int64Index(range(-10, 11)), np.int64),
(UInt64Index(range(10, 31)), np.int64),
(Float64Index(np.arange(20, 30, 0.5)), np.float64),
(NumericIndex(range(-10, 11), dtype=np.int64), np.int64),
(NumericIndex(range(10, 31), dtype=np.uint64), np.int64),
(NumericIndex(np.arange(20, 30, 0.5), dtype=np.float64), np.float64),
(date_range("20180101", periods=10), "<M8[ns]"),
(
date_range("20180101", periods=10, tz="US/Eastern"),
Expand Down

0 comments on commit 25c71c7

Please sign in to comment.