Skip to content

Commit

Permalink
[ENH] Partially replace pd.Int64Index with pd.Index (#2339)
Browse files Browse the repository at this point in the history
Replaces `pd.Int64Index` with `pd.Index`, and `VALID_INDEX_TYPES` checks with dedicated functions in `sktime.utils.validation.series`.
  • Loading branch information
Stanislav Khrapov committed Apr 4, 2022
1 parent 2c012eb commit 4b9003c
Show file tree
Hide file tree
Showing 9 changed files with 76 additions and 47 deletions.
8 changes: 5 additions & 3 deletions sktime/datatypes/_series/_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import numpy as np
import pandas as pd

from sktime.utils.validation.series import is_in_valid_index_types

VALID_INDEX_TYPES = (pd.Int64Index, pd.RangeIndex, pd.PeriodIndex, pd.DatetimeIndex)

# whether the checks insist on freq attribute is set
Expand Down Expand Up @@ -69,7 +71,7 @@ def ret(valid, msg, metadata, return_metadata):
metadata["is_univariate"] = len(obj.columns) < 2

# check whether the time index is of valid type
if not type(index) in VALID_INDEX_TYPES:
if not is_in_valid_index_types(index):
msg = (
f"{type(index)} is not supported for {var_name}, use "
f"one of {VALID_INDEX_TYPES} instead."
Expand Down Expand Up @@ -131,7 +133,7 @@ def ret(valid, msg, metadata, return_metadata):
return ret(False, msg, None, return_metadata)

# check whether the time index is of valid type
if not type(index) in VALID_INDEX_TYPES:
if not is_in_valid_index_types(index):
msg = (
f"{type(index)} is not supported for {var_name}, use "
f"one of {VALID_INDEX_TYPES} instead."
Expand Down Expand Up @@ -214,7 +216,7 @@ def _index_equally_spaced(index):
-------
equally_spaced: bool - whether index is equally spaced
"""
if not isinstance(index, VALID_INDEX_TYPES):
if not is_in_valid_index_types(index):
raise TypeError(f"index must be one of {VALID_INDEX_TYPES}")

# empty and single element indices are equally spaced
Expand Down
30 changes: 16 additions & 14 deletions sktime/forecasting/base/_fh.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,13 @@
is_int,
is_timedelta_or_date_offset,
)
from sktime.utils.validation.series import VALID_INDEX_TYPES
from sktime.utils.validation.series import (
VALID_INDEX_TYPES,
is_in_valid_absolute_index_types,
is_in_valid_index_types,
is_in_valid_relative_index_types,
)

RELATIVE_TYPES = (pd.Int64Index, pd.RangeIndex, pd.TimedeltaIndex)
ABSOLUTE_TYPES = (pd.Int64Index, pd.RangeIndex, pd.DatetimeIndex, pd.PeriodIndex)
assert set(RELATIVE_TYPES).issubset(VALID_INDEX_TYPES)
assert set(ABSOLUTE_TYPES).issubset(VALID_INDEX_TYPES)
VALID_FORECASTING_HORIZON_TYPES = (int, list, np.ndarray, pd.Index)

DELEGATED_METHODS = (
Expand Down Expand Up @@ -95,19 +96,20 @@ def _check_values(values: Union[VALID_FORECASTING_HORIZON_TYPES]) -> pd.Index:
# isinstance() does not work here, because index types inherit from each
# other,
# hence we check for type equality here
if type(values) in VALID_INDEX_TYPES:
if is_in_valid_index_types(values):
pass

# convert single integer to pandas index, no further checks needed
# convert single integer or timedelta or dateoffset
# to pandas index, no further checks needed
elif is_int(values):
return pd.Int64Index([values], dtype=int)
values = pd.Index([values], dtype=int)

elif is_timedelta_or_date_offset(values):
return pd.Index([values])
values = pd.Index([values])

# convert np.array or list to pandas index
elif is_array(values) and array_is_int(values):
values = pd.Int64Index(values, dtype=int)
values = pd.Index(values, dtype=int)

elif is_array(values) and array_is_timedelta_or_date_offset(values):
values = pd.Index(values)
Expand Down Expand Up @@ -181,17 +183,17 @@ def __init__(
# types inherit from each other, hence we check for type equality
error_msg = f"`values` type is not compatible with `is_relative={is_relative}`."
if is_relative is None:
if type(values) in RELATIVE_TYPES:
if is_in_valid_relative_index_types(values):
is_relative = True
elif type(values) in ABSOLUTE_TYPES:
elif is_in_valid_absolute_index_types(values):
is_relative = False
else:
raise TypeError(f"{type(values)} is not a supported fh index type")
if is_relative:
if not type(values) in RELATIVE_TYPES:
if not is_in_valid_relative_index_types(values):
raise TypeError(error_msg)
else:
if not type(values) in ABSOLUTE_TYPES:
if not is_in_valid_absolute_index_types(values):
raise TypeError(error_msg)

self._values = values
Expand Down
4 changes: 2 additions & 2 deletions sktime/forecasting/base/adapters/_statsmodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _fit(self, y, X=None, fh=None):
"""
# statsmodels does not support the pd.Int64Index as required,
# so we coerce them here to pd.RangeIndex
if isinstance(y, pd.Series) and type(y.index) == pd.Int64Index:
if isinstance(y, pd.Series) and y.index.is_integer():
y, X = _coerce_int_to_range_index(y, X)
self._fit_forecaster(y, X)
return self
Expand Down Expand Up @@ -113,7 +113,7 @@ def _coerce_int_to_range_index(y, X=None):
np.testing.assert_array_equal(y.index, new_index)
except AssertionError:
raise ValueError(
"Coercion of pd.Int64Index to pd.RangeIndex "
"Coercion of integer pd.Index to pd.RangeIndex "
"failed. Please provide `y_train` with a "
"pd.RangeIndex."
)
Expand Down
34 changes: 21 additions & 13 deletions sktime/forecasting/base/tests/test_fh.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,11 @@
_get_intervals_count_and_unit,
_shift,
)
from sktime.utils.validation.series import VALID_INDEX_TYPES
from sktime.utils.validation.series import (
VALID_INDEX_TYPES,
is_in_valid_index_types,
is_integer_index,
)


def _assert_index_equal(a, b):
Expand All @@ -49,7 +53,10 @@ def test_fh(index_type, fh_type, is_relative, steps):
"""Testing ForecastingHorizon conversions."""
# generate data
y = make_forecasting_problem(index_type=index_type)
assert isinstance(y.index, INDEX_TYPE_LOOKUP.get(index_type))
if index_type == "int":
assert is_integer_index(y.index)
else:
assert isinstance(y.index, INDEX_TYPE_LOOKUP.get(index_type))

# split data
y_train, y_test = temporal_train_test_split(y, test_size=10)
Expand All @@ -59,12 +66,15 @@ def test_fh(index_type, fh_type, is_relative, steps):

# generate fh
fh = _make_fh(cutoff, steps, fh_type, is_relative)
assert isinstance(fh.to_pandas(), INDEX_TYPE_LOOKUP.get(fh_type))
if fh_type == "int":
assert is_integer_index(fh.to_pandas())
else:
assert isinstance(fh.to_pandas(), INDEX_TYPE_LOOKUP.get(fh_type))

# get expected outputs
if isinstance(steps, int):
steps = np.array([steps])
fh_relative = pd.Int64Index(steps).sort_values()
fh_relative = pd.Index(steps).sort_values()
fh_absolute = y.index[np.where(y.index == cutoff)[0] + steps].sort_values()
fh_indexer = fh_relative - 1
fh_oos = fh.to_pandas()[fh_relative > 0]
Expand Down Expand Up @@ -102,7 +112,7 @@ def test_fh(index_type, fh_type, is_relative, steps):


def test_fh_method_delegation():
"""Test ForecastinHorizon delegated methods."""
"""Test ForecastingHorizon delegated methods."""
fh = ForecastingHorizon(1)
for method in DELEGATED_METHODS:
assert hasattr(fh, method)
Expand All @@ -125,10 +135,7 @@ def test_check_fh_values_bad_input_types(arg):
ForecastingHorizon(arg)


DUPLICATE_INPUT_ARGS = (
np.array([1, 2, 2]),
[3, 3, 1],
)
DUPLICATE_INPUT_ARGS = (np.array([1, 2, 2]), [3, 3, 1])


@pytest.mark.parametrize("arg", DUPLICATE_INPUT_ARGS)
Expand All @@ -139,7 +146,7 @@ def test_check_fh_values_duplicate_input_values(arg):


GOOD_ABSOLUTE_INPUT_ARGS = (
pd.Int64Index([1, 2, 3]),
pd.Index([1, 2, 3]),
pd.period_range("2000-01-01", periods=3, freq="D"),
pd.date_range("2000-01-01", periods=3, freq="M"),
np.array([1, 2, 3]),
Expand All @@ -151,8 +158,9 @@ def test_check_fh_values_duplicate_input_values(arg):
@pytest.mark.parametrize("arg", GOOD_ABSOLUTE_INPUT_ARGS)
def test_check_fh_absolute_values_input_conversion_to_pandas_index(arg):
"""Test conversion of absolute horizons to pandas index."""
output = ForecastingHorizon(arg, is_relative=False).to_pandas()
assert type(output) in VALID_INDEX_TYPES
assert is_in_valid_index_types(
ForecastingHorizon(arg, is_relative=False).to_pandas()
)


GOOD_RELATIVE_INPUT_ARGS = [
Expand Down Expand Up @@ -212,7 +220,7 @@ def test_coerce_duration_to_int(duration):
ret = _coerce_duration_to_int(duration, freq=_get_freq(duration))

# check output type is always integer
assert type(ret) in (pd.Int64Index, np.integer, int)
assert (type(ret) in (np.integer, int)) or is_integer_index(ret)

# check result
if isinstance(duration, pd.Index):
Expand Down
2 changes: 1 addition & 1 deletion sktime/forecasting/tests/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
]

INDEX_TYPE_LOOKUP = {
"int": pd.Int64Index,
"int": pd.Index,
"range": pd.RangeIndex,
"datetime": pd.DatetimeIndex,
"period": pd.PeriodIndex,
Expand Down
2 changes: 1 addition & 1 deletion sktime/utils/_testing/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _make_index(n_timepoints, index_type=None):

elif index_type == "int":
start = 3
return pd.Int64Index(np.arange(start, start + n_timepoints))
return pd.Index(np.arange(start, start + n_timepoints), dtype=int)

else:
raise ValueError(f"index_class: {index_type} is not supported")
6 changes: 3 additions & 3 deletions sktime/utils/datetime.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import numpy as np
import pandas as pd

from sktime.utils.validation.series import check_time_index
from sktime.utils.validation.series import check_time_index, is_integer_index


def _coerce_duration_to_int(duration, freq=None):
Expand All @@ -37,7 +37,7 @@ def _coerce_duration_to_int(duration, freq=None):
duration[0], pd.tseries.offsets.BaseOffset
):
count = _get_intervals_count_and_unit(freq)[0]
return pd.Int64Index([d.n / count for d in duration])
return pd.Index([d.n / count for d in duration], dtype=int)
elif isinstance(duration, (pd.Timedelta, pd.TimedeltaIndex)):
count, unit = _get_intervals_count_and_unit(freq)
# integer conversion only works reliably with non-ambiguous units (
Expand Down Expand Up @@ -100,7 +100,7 @@ def _shift(x, by=1):
Shifted time point
"""
assert isinstance(x, (pd.Period, pd.Timestamp, int, np.integer)), type(x)
assert isinstance(by, (int, np.integer, pd.Int64Index)), type(by)
assert isinstance(by, (int, np.integer)) or is_integer_index(by), type(by)
if isinstance(x, pd.Timestamp):
if not hasattr(x, "freq") or x.freq is None:
raise ValueError("No `freq` information available")
Expand Down
35 changes: 26 additions & 9 deletions sktime/utils/validation/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,28 @@
pd.DatetimeIndex,
pd.TimedeltaIndex,
)
RELATIVE_INDEX_TYPES = (pd.RangeIndex, pd.TimedeltaIndex)
ABSOLUTE_INDEX_TYPES = (pd.RangeIndex, pd.DatetimeIndex, pd.PeriodIndex)
assert set(RELATIVE_INDEX_TYPES).issubset(VALID_INDEX_TYPES)
assert set(ABSOLUTE_INDEX_TYPES).issubset(VALID_INDEX_TYPES)


def is_integer_index(x) -> bool:
"""Check that the input is an integer pd.Index."""
return isinstance(x, pd.Index) and x.is_integer()


def is_in_valid_index_types(x) -> bool:
"""Check that the input type belongs to the valid index types."""
return isinstance(x, VALID_INDEX_TYPES) or is_integer_index(x)


def is_in_valid_relative_index_types(x) -> bool:
return isinstance(x, RELATIVE_INDEX_TYPES) or is_integer_index(x)


def is_in_valid_absolute_index_types(x) -> bool:
return isinstance(x, ABSOLUTE_INDEX_TYPES) or is_integer_index(x)


def _check_is_univariate(y, var_name="input"):
Expand Down Expand Up @@ -185,7 +207,7 @@ def check_time_index(

# We here check for type equality because isinstance does not
# work reliably because index types inherit from each other.
if not type(index) in VALID_INDEX_TYPES:
if not is_in_valid_index_types(index):
raise NotImplementedError(
f"{type(index)} is not supported for {var_name}, use "
f"one of {VALID_INDEX_TYPES} instead."
Expand All @@ -194,7 +216,7 @@ def check_time_index(
if enforce_index_type and type(index) is not enforce_index_type:
raise NotImplementedError(
f"{type(index)} is not supported for {var_name}, use "
f"type: {enforce_index_type} instead."
f"type: {enforce_index_type} or integer pd.Index instead."
)

# Check time index is ordered in time
Expand Down Expand Up @@ -274,11 +296,6 @@ def check_equal_time_index(*ys, mode="equal"):
raise ValueError(msg)


def _is_int_index(index):
"""Check if index type is one of pd.RangeIndex or pd.Int64Index."""
return type(index) in (pd.Int64Index, pd.RangeIndex)


def check_consistent_index_type(a, b):
"""Check that two indices have consistent types.
Expand All @@ -299,8 +316,8 @@ def check_consistent_index_type(a, b):
"series have the same index type."
)

if _is_int_index(a):
if not _is_int_index(b):
if is_integer_index(a):
if not is_integer_index(b):
raise TypeError(msg)

else:
Expand Down
2 changes: 1 addition & 1 deletion sktime/utils/validation/tests/test_forecasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from sktime.utils.validation.forecasting import check_fh

empty_input = (np.array([]), [], pd.Int64Index([]))
empty_input = (np.array([], dtype=int), [], pd.Index([], dtype=int))


@pytest.mark.parametrize("arg", empty_input)
Expand Down

0 comments on commit 4b9003c

Please sign in to comment.