Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main' into plateau-max-length
Browse files Browse the repository at this point in the history
* upstream/main:
  [DOC] Added docstring examples to load data functions (sktime#2393)
  [ENH] Capability inference for transformer and classifier pipelines (sktime#2367)
  [ENH] Proba metric grid search integration (sktime#2234)
  [ENH] Faster classifier example parameters (sktime#2378)
  [ENH] Get rid of `pd.Int64Index` (sktime#2390)
  [ENH] Allow `pd.Timedelta` values in `ForecastingHorizon` (sktime#2333)
  [ENH] Partially replace `pd.Int64Index` with `pd.Index` (sktime#2339)
  relax name rules for multiindex (sktime#2384)
  • Loading branch information
srggrs committed Apr 6, 2022
2 parents 993ccf1 + d950ff6 commit 3695a21
Show file tree
Hide file tree
Showing 59 changed files with 1,423 additions and 253 deletions.
29 changes: 19 additions & 10 deletions extension_templates/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,25 +90,34 @@ class MyTransformer(BaseTransformer):
# y_inner_mtype must be changed to one or a list of compatible sktime mtypes
# the other tags are "safe defaults" which can usually be left as-is
_tags = {
# todo: what is the scitype of X: Series, or Panel
"scitype:transform-input": "Series",
# todo: what scitype is returned: Primitives, Series, Panel
# what is the scitype of X: Series, or Panel
"scitype:transform-output": "Series",
# todo: what is the scitype of y: None (not needed), Primitives, Series, Panel
# what scitype is returned: Primitives, Series, Panel
"scitype:transform-labels": "None",
# what is the scitype of y: None (not needed), Primitives, Series, Panel
"scitype:instancewise": True, # is this an instance-wise transform?
"capability:inverse_transform": False, # can the transformer inverse transform?
"univariate-only": False, # can the transformer handle multivariate X?
"X_inner_mtype": "pd.DataFrame", # which mtypes do _fit/_predict support for X?
# X_inner_mtype can be Panel mtype even if transform-input is Series, vectorized
# this can be a Panel mtype even if transform-input is Series, vectorized
"y_inner_mtype": "None", # which mtypes do _fit/_predict support for y?
"capability:inverse_transform": True, # does transformer have inverse transform
"skip-inverse-transform": False, # is inverse-transform skipped when called?
"univariate-only": False, # can the transformer handle multivariate X?
"handles-missing-data": False, # can estimator handle missing data?
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"requires_y": False, # does y need to be passed in fit?
"enforce_index_type": None, # index type that needs to be enforced in X/y
"fit_is_empty": False, # is fit empty and can be skipped? Yes = True
"fit_is_empty": True, # is fit empty and can be skipped? Yes = True
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"transform-returns-same-time-index": False,
# does transform return have the same time index as input X
"skip-inverse-transform": False, # is inverse-transform skipped when called?
"capability:unequal_length": True,
# can the transformer handle unequal length time series (if passed Panel)?
"capability:unequal_length:removes": False,
# is transform result always guaranteed to be equal length (and series)?
# not relevant for transformers that return Primitives in transform-output
"handles-missing-data": False, # can estimator handle missing data?
# todo: rename to capability:missing_values
"capability:missing_values:removes": False,
# is transform result always guaranteed to contain no missing values?
}
# in case of inheritance, concrete class should typically set tags
# alternatively, descendants can set tags in __init__
Expand Down
29 changes: 19 additions & 10 deletions extension_templates/transformer_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,25 +85,34 @@ class MyTransformer(BaseTransformer):
# y_inner_mtype must be changed to one or a list of compatible sktime mtypes
# the other tags are "safe defaults" which can usually be left as-is
_tags = {
# todo: what is the scitype of X: Series, or Panel
"scitype:transform-input": "Series",
# todo: what scitype is returned: Primitives, Series, Panel
# what is the scitype of X: Series, or Panel
"scitype:transform-output": "Series",
# todo: what is the scitype of y: None (not needed), Primitives, Series, Panel
# what scitype is returned: Primitives, Series, Panel
"scitype:transform-labels": "None",
# what is the scitype of y: None (not needed), Primitives, Series, Panel
"scitype:instancewise": True, # is this an instance-wise transform?
"capability:inverse_transform": False, # can the transformer inverse transform?
"univariate-only": False, # can the transformer handle multivariate X?
"X_inner_mtype": "pd.DataFrame", # which mtypes do _fit/_predict support for X?
# X_inner_mtype can be Panel mtype even if transform-input is Series, vectorized
# this can be a Panel mtype even if transform-input is Series, vectorized
"y_inner_mtype": "None", # which mtypes do _fit/_predict support for y?
"capability:inverse_transform": True, # does transformer have inverse transform
"skip-inverse-transform": False, # is inverse-transform skipped when called?
"univariate-only": False, # can the transformer handle multivariate X?
"handles-missing-data": False, # can estimator handle missing data?
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"requires_y": False, # does y need to be passed in fit?
"enforce_index_type": None, # index type that needs to be enforced in X/y
"fit-in-transform": False, # is fit empty and can be skipped? Yes = True
"fit_is_empty": True, # is fit empty and can be skipped? Yes = True
"X-y-must-have-same-index": False, # can estimator handle different X/y index?
"transform-returns-same-time-index": False,
# does transform return have the same time index as input X
"skip-inverse-transform": False, # is inverse-transform skipped when called?
"capability:unequal_length": True,
# can the transformer handle unequal length time series (if passed Panel)?
"capability:unequal_length:removes": False,
# is transform result always guaranteed to be equal length (and series)?
# not relevant for transformers that return Primitives in transform-output
"handles-missing-data": False, # can estimator handle missing data?
# todo: rename to capability:missing_values
"capability:missing_values:removes": False,
# is transform result always guaranteed to contain no missing values?
}

# todo: add any hyper-parameters and components to constructor
Expand Down
82 changes: 82 additions & 0 deletions sktime/base/_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,88 @@ def _anytag_notnone_set(self, tag_name, estimators):
if tag_val != "None":
self.set_tags(**{tag_name: tag_val})

def _tagchain_is_linked(
self,
left_tag_name,
mid_tag_name,
estimators,
left_tag_val=True,
mid_tag_val=True,
):
"""Check whether all tags left of the first mid_tag/val are left_tag/val.
Useful to check, for instance, whether all instances of estimators
left of the first missing value imputer can deal with missing values.
Parameters
----------
left_tag_name : str, name of the left tag
mid_tag_name : str, name of the middle tag
estimators : list of (str, estimator) pairs to query for the tag/value
left_tag_val : value of the left tag, optional, default=True
mid_tag_val : value of the middle tag, optional, default=True
Returns
-------
chain_is_linked : bool,
True iff all "left" tag instances `left_tag_name` have value `left_tag_val`
a "left" tag instance is an instance in estimators which is earlier
than the first occurrence of `mid_tag_name` with value `mid_tag_val`
chain_is_complete : bool,
True iff chain_is_linked is True, and
there is an occurrence of `mid_tag_name` with value `mid_tag_val`
"""
for _, est in estimators:
if est.get_tag(mid_tag_name) == mid_tag_val:
return True, True
if not est.get_tag(left_tag_name) == left_tag_val:
return False, False
return True, False

def _tagchain_is_linked_set(
self,
left_tag_name,
mid_tag_name,
estimators,
left_tag_val=True,
mid_tag_val=True,
left_tag_val_not=False,
mid_tag_val_not=False,
):
"""Check if _tagchain_is_linked, then set self left_tag_name and mid_tag_name.
Writes to self:
tag with name left_tag_name, sets to left_tag_val if _tag_chain_is_linked[0]
otherwise sets to left_tag_val_not
tag with name mid_tag_name, sets to mid_tag_val if _tag_chain_is_linked[1]
otherwise sets to mid_tag_val_not
Parameters
----------
left_tag_name : str, name of the left tag
mid_tag_name : str, name of the middle tag
estimators : list of (str, estimator) pairs to query for the tag/value
left_tag_val : value of the left tag, optional, default=True
mid_tag_val : value of the middle tag, optional, default=True
left_tag_val_not : value to set if not linked, optional, default=False
mid_tag_val_not : value to set if not linked, optional, default=False
"""
linked, complete = self._tagchain_is_linked(
left_tag_name=left_tag_name,
mid_tag_name=mid_tag_name,
estimators=estimators,
left_tag_val=left_tag_val,
mid_tag_val=mid_tag_val,
)
if linked:
self.set_tags(**{left_tag_name: left_tag_val})
else:
self.set_tags(**{left_tag_name: left_tag_val_not})
if complete:
self.set_tags(**{mid_tag_name: mid_tag_val})
else:
self.set_tags(**{mid_tag_name: mid_tag_val_not})


def flatten(obj):
"""Flatten nested list/tuple structure.
Expand Down
12 changes: 9 additions & 3 deletions sktime/classification/compose/_column_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,12 +222,18 @@ class ColumnEnsembleClassifier(BaseColumnEnsembleClassifier):
Examples
--------
>>> from sktime.classification.interval_based import DrCIF
>>> from sktime.classification.dictionary_based import ContractableBOSS
>>> from sktime.classification.interval_based import CanonicalIntervalForest
>>> from sktime.datasets import load_basic_motions
>>> X_train, y_train = load_basic_motions(split="train")
>>> X_test, y_test = load_basic_motions(split="test")
>>> clf = DrCIF(n_estimators=3)
>>> estimators = [("DrCIF", clf, [0, 1])]
>>> cboss = ContractableBOSS(
... n_parameter_samples=4, max_ensemble_size=2, random_state=0
... )
>>> cif = CanonicalIntervalForest(
... n_estimators=2, n_intervals=4, att_subsample_size=4, random_state=0
... )
>>> estimators = [("cBOSS", cboss, 5), ("CIF", cif, [3, 4])]
>>> col_ens = ColumnEnsembleClassifier(estimators=estimators)
>>> col_ens.fit(X_train, y_train)
ColumnEnsembleClassifier(...)
Expand Down
37 changes: 32 additions & 5 deletions sktime/classification/compose/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,23 @@ class ClassifierPipeline(BaseClassifier, _HeterogenousMetaEstimator):
is always in (str, transformer) format, even if transformers is just a list
strings not passed in transformers are unique generated strings
i-th transformer in `transformers_` is clone of i-th in `transformers`
Examples
--------
>>> from sktime.transformations.panel.pca import PCATransformer
>>> from sktime.classification.interval_based import TimeSeriesForestClassifier
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train")
>>> X_test, y_test = load_unit_test(split="test")
>>> pipeline = ClassifierPipeline(
... TimeSeriesForestClassifier(n_estimators=5), [PCATransformer()]
... )
>>> pipeline.fit(X_train, y_train)
ClassifierPipeline(...)
>>> y_pred = pipeline.predict(X_test)
Alternative construction via dunder method:
>>> pipeline = PCATransformer() * TimeSeriesForestClassifier(n_estimators=5)
"""

_tags = {
Expand All @@ -95,17 +112,27 @@ def __init__(self, classifier, transformers):

super(ClassifierPipeline, self).__init__()

# can handle multivariate of both classifier and all transformers can
# can handle multivariate iff: both classifier and all transformers can
multivariate = classifier.get_tag("capability:multivariate", False)
multivariate = multivariate and not self.transformers_.get_tag(
"univariate-only", True
)
# can handle missing values if both classifier and all transformers can
# can handle missing values iff: both classifier and all transformers can,
# *or* transformer chain removes missing data
missing = classifier.get_tag("capability:missing_values", False)
missing = missing and self.transformer_.get_tag("handles-missing-data", False)
# can handle unequal length if classifier can
# transformers should always be able to, due to vectorization
missing = missing and self.transformers_.get_tag("handles-missing-data", False)
missing = missing or self.transformers_.get_tag(
"capability:missing_values:removes", False
)
# can handle unequal length iff: classifier can and transformers can,
# *or* transformer chain renders the series equal length
unequal = classifier.get_tag("capability:unequal_length")
unequal = unequal and self.transformers_.get_tag(
"capability:unequal_length", False
)
unequal = unequal or self.transformers_.get_tag(
"capability:unequal_length:removes", False
)
# last three tags are always False, since not supported by transformers
tags_to_set = {
"capability:multivariate": multivariate,
Expand Down
16 changes: 16 additions & 0 deletions sktime/classification/compose/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

from sktime.classification.compose import ClassifierPipeline
from sktime.classification.distance_based import KNeighborsTimeSeriesClassifier
from sktime.transformations.panel.padder import PaddingTransformer
from sktime.transformations.series.exponent import ExponentTransformer
from sktime.transformations.series.impute import Imputer
from sktime.utils._testing.estimator_checks import _assert_array_almost_equal
from sktime.utils._testing.panel import _make_classification_y, _make_panel_X

Expand Down Expand Up @@ -63,3 +65,17 @@ def test_mul_sklearn_autoadapt():

_assert_array_almost_equal(y_pred, t12c_2.fit(X, y).predict(X_test))
_assert_array_almost_equal(y_pred, t12c_3.fit(X, y).predict(X_test))


def test_missing_unequal_tag_inference():
"""Test that ClassifierPipeline infers missing/unequal tags correctly."""
c = KNeighborsTimeSeriesClassifier()
c1 = ExponentTransformer() * PaddingTransformer() * ExponentTransformer() * c
c2 = ExponentTransformer() * ExponentTransformer() * c
c3 = Imputer() * ExponentTransformer() * c
c4 = ExponentTransformer() * Imputer() * c

assert c1.get_tag("capability:unequal_length")
assert not c2.get_tag("capability:unequal_length")
assert c3.get_tag("capability:missing_values")
assert not c4.get_tag("capability:missing_values")
2 changes: 1 addition & 1 deletion sktime/classification/dictionary_based/_boss.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class BOSSEnsemble(BaseClassifier):
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = BOSSEnsemble(max_ensemble_size=5)
>>> clf = BOSSEnsemble(max_ensemble_size=3)
>>> clf.fit(X_train, y_train)
BOSSEnsemble(...)
>>> y_pred = clf.predict(X_test)
Expand Down
2 changes: 1 addition & 1 deletion sktime/classification/dictionary_based/_cboss.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ContractableBOSS(BaseClassifier):
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = ContractableBOSS(n_parameter_samples=25, max_ensemble_size=5)
>>> clf = ContractableBOSS(n_parameter_samples=10, max_ensemble_size=3)
>>> clf.fit(X_train, y_train)
ContractableBOSS(...)
>>> y_pred = clf.predict(X_test)
Expand Down
2 changes: 1 addition & 1 deletion sktime/classification/dictionary_based/_tde.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ class TemporalDictionaryEnsemble(BaseClassifier):
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = TemporalDictionaryEnsemble(
... n_parameter_samples=10,
... max_ensemble_size=5,
... max_ensemble_size=3,
... randomly_selected_params=5,
... )
>>> clf.fit(X_train, y_train)
Expand Down
8 changes: 4 additions & 4 deletions sktime/classification/distance_based/_proximity_forest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,7 @@ class value list
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = ProximityTree()
>>> clf = ProximityTree(max_depth=2, n_stump_evaluations=1)
>>> clf.fit(X_train, y_train)
ProximityTree(...)
>>> y_pred = clf.predict(X_test)
Expand Down Expand Up @@ -1315,7 +1315,7 @@ class value list
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = ProximityForest(n_estimators=5)
>>> clf = ProximityForest(n_estimators=2, max_depth=2, n_stump_evaluations=1)
>>> clf.fit(X_train, y_train)
ProximityForest(...)
>>> y_pred = clf.predict(X_test)
Expand Down Expand Up @@ -1717,10 +1717,10 @@ def _stdp(X):
for value in instance:
num_values += 1
sum += value
sum_sq += value ** 2 # todo missing values NaN messes
sum_sq += value**2 # todo missing values NaN messes
# this up!
mean = sum / num_values
stdp = np.math.sqrt(sum_sq / num_values - mean ** 2)
stdp = np.math.sqrt(sum_sq / num_values - mean**2)
return stdp


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class KNeighborsTimeSeriesClassifier(_KNeighborsClassifier, BaseClassifier):
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(return_X_y=True, split="train")
>>> X_test, y_test = load_unit_test(return_X_y=True, split="test")
>>> classifier = KNeighborsTimeSeriesClassifier()
>>> classifier = KNeighborsTimeSeriesClassifier(distance="euclidean")
>>> classifier.fit(X_train, y_train)
KNeighborsTimeSeriesClassifier(...)
>>> y_pred = classifier.predict(X_test)
Expand Down
2 changes: 1 addition & 1 deletion sktime/classification/feature_based/_catch22_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ class Catch22Classifier(BaseClassifier):
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = Catch22Classifier(
... estimator=RandomForestClassifier(n_estimators=10),
... estimator=RandomForestClassifier(n_estimators=5),
... outlier_norm=True,
... )
>>> clf.fit(X_train, y_train)
Expand Down
2 changes: 1 addition & 1 deletion sktime/classification/feature_based/_fresh_prince.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ class FreshPRINCE(BaseClassifier):
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = FreshPRINCE(
... default_fc_parameters="minimal",
... n_estimators=10,
... n_estimators=5,
... )
>>> clf.fit(X_train, y_train)
FreshPRINCE(...)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ class RandomIntervalClassifier(BaseClassifier):
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = RandomIntervalClassifier(
... n_intervals=5,
... estimator=RotationForest(n_estimators=10),
... n_intervals=3,
... estimator=RotationForest(n_estimators=5),
... )
>>> clf.fit(X_train, y_train)
RandomIntervalClassifier(...)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ class SignatureClassifier(BaseClassifier):
>>> from sktime.datasets import load_unit_test
>>> X_train, y_train = load_unit_test(split="train", return_X_y=True)
>>> X_test, y_test = load_unit_test(split="test", return_X_y=True)
>>> clf = SignatureClassifier(estimator=RandomForestClassifier(n_estimators=10))
>>> clf = SignatureClassifier(estimator=RandomForestClassifier(n_estimators=5))
>>> clf.fit(X_train, y_train)
SignatureClassifier(...)
>>> y_pred = clf.predict(X_test)
Expand Down

0 comments on commit 3695a21

Please sign in to comment.