Skip to content

Commit

Permalink
FEA SLEP006: Metadata routing for learning_curve (#28975)
Browse files Browse the repository at this point in the history
  • Loading branch information
StefanieSenger committed May 17, 2024
1 parent b461547 commit 77fc72c
Show file tree
Hide file tree
Showing 6 changed files with 182 additions and 59 deletions.
4 changes: 2 additions & 2 deletions doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ Meta-estimators and functions supporting metadata routing:
- :class:`sklearn.linear_model.LogisticRegressionCV`
- :class:`sklearn.linear_model.MultiTaskElasticNetCV`
- :class:`sklearn.linear_model.MultiTaskLassoCV`
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
- :class:`sklearn.linear_model.RANSACRegressor`
- :class:`sklearn.linear_model.RidgeClassifierCV`
- :class:`sklearn.linear_model.RidgeCV`
Expand All @@ -302,13 +303,13 @@ Meta-estimators and functions supporting metadata routing:
- :func:`sklearn.model_selection.cross_validate`
- :func:`sklearn.model_selection.cross_val_score`
- :func:`sklearn.model_selection.cross_val_predict`
- :class:`sklearn.model_selection.learning_curve`
- :class:`sklearn.multiclass.OneVsOneClassifier`
- :class:`sklearn.multiclass.OneVsRestClassifier`
- :class:`sklearn.multiclass.OutputCodeClassifier`
- :class:`sklearn.multioutput.ClassifierChain`
- :class:`sklearn.multioutput.MultiOutputClassifier`
- :class:`sklearn.multioutput.MultiOutputRegressor`
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
- :class:`sklearn.multioutput.RegressorChain`
- :class:`sklearn.pipeline.FeatureUnion`
- :class:`sklearn.pipeline.Pipeline`
Expand All @@ -321,7 +322,6 @@ Meta-estimators and tools not supporting metadata routing yet:
- :class:`sklearn.feature_selection.RFE`
- :class:`sklearn.feature_selection.RFECV`
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
- :class:`sklearn.model_selection.learning_curve`
- :class:`sklearn.model_selection.permutation_test_score`
- :class:`sklearn.model_selection.validation_curve`
- :class:`sklearn.semi_supervised.SelfTrainingClassifier`
4 changes: 4 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ The following models now support metadata routing in one or more of their
methods. Refer to the :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

- |Feature| :func:`model_selection.learning_curve` now supports metadata routing for the
`fit` method of its estimator and for its underlying CV splitter and scorer.
:pr:`28975` by :user:`Stefanie Senger <StefanieSenger>`.

- |Feature| :class:`ensemble.StackingClassifier` and
:class:`ensemble.StackingRegressor` now support metadata routing and pass
``**fit_params`` to the underlying estimators via their `fit` methods.
Expand Down
133 changes: 106 additions & 27 deletions sklearn/model_selection/_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,22 @@
]


def _check_params_groups_deprecation(fit_params, params, groups):
def _check_params_groups_deprecation(fit_params, params, groups, version):
"""A helper function to check deprecations on `groups` and `fit_params`.
To be removed when set_config(enable_metadata_routing=False) is not possible.
# TODO(SLEP6): To be removed when set_config(enable_metadata_routing=False) is not
# possible.
"""
if params is not None and fit_params is not None:
raise ValueError(
"`params` and `fit_params` cannot both be provided. Pass parameters "
"via `params`. `fit_params` is deprecated and will be removed in "
"version 1.6."
f"version {version}."
)
elif fit_params is not None:
warnings.warn(
(
"`fit_params` is deprecated and will be removed in version 1.6. "
"`fit_params` is deprecated and will be removed in version {version}. "
"Pass parameters via `params` instead."
),
FutureWarning,
Expand Down Expand Up @@ -346,7 +347,7 @@ def cross_validate(
>>> print(scores['train_r2'])
[0.28009951 0.3908844 0.22784907]
"""
params = _check_params_groups_deprecation(fit_params, params, groups)
params = _check_params_groups_deprecation(fit_params, params, groups, "1.6")

X, y = indexable(X, y)

Expand Down Expand Up @@ -602,10 +603,8 @@ def cross_val_score(
``cross_val_score(..., params={'groups': groups})``.
scoring : str or callable, default=None
A str (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)`` which should return only
a single value.
A str (see :ref:`scoring_parameter`) or a scorer callable object / function with
signature ``scorer(estimator, X, y)`` which should return only a single value.
Similar to :func:`cross_validate`
but only a single metric is permitted.
Expand Down Expand Up @@ -1206,7 +1205,7 @@ def cross_val_predict(
>>> lasso = linear_model.Lasso()
>>> y_pred = cross_val_predict(lasso, X, y, cv=3)
"""
params = _check_params_groups_deprecation(fit_params, params, groups)
params = _check_params_groups_deprecation(fit_params, params, groups, "1.6")
X, y = indexable(X, y)

if _routing_enabled():
Expand Down Expand Up @@ -1718,6 +1717,7 @@ def _shuffle(y, groups, random_state):
"error_score": [StrOptions({"raise"}), Real],
"return_times": ["boolean"],
"fit_params": [dict, None],
"params": [dict, None],
},
prefer_skip_nested_validation=False, # estimator is not validated yet
)
Expand All @@ -1739,6 +1739,7 @@ def learning_curve(
error_score=np.nan,
return_times=False,
fit_params=None,
params=None,
):
"""Learning curve.
Expand Down Expand Up @@ -1773,14 +1774,21 @@ def learning_curve(
train/test set. Only used in conjunction with a "Group" :term:`cv`
instance (e.g., :class:`GroupKFold`).
.. versionchanged:: 1.6
``groups`` can only be passed if metadata routing is not enabled
via ``sklearn.set_config(enable_metadata_routing=True)``. When routing
is enabled, pass ``groups`` alongside other metadata via the ``params``
argument instead. E.g.:
``learning_curve(..., params={'groups': groups})``.
train_sizes : array-like of shape (n_ticks,), \
default=np.linspace(0.1, 1.0, 5)
Relative or absolute numbers of training examples that will be used to
generate the learning curve. If the dtype is float, it is regarded as a
fraction of the maximum size of the training set (that is determined
by the selected validation method), i.e. it has to be within (0, 1].
Otherwise it is interpreted as absolute sizes of the training sets.
Note that for classification the number of samples usually have to
Note that for classification the number of samples usually has to
be big enough to contain at least one sample from each class.
cv : int, cross-validation generator or an iterable, default=None
Expand All @@ -1804,9 +1812,8 @@ def learning_curve(
``cv`` default value if None changed from 3-fold to 5-fold.
scoring : str or callable, default=None
A str (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
A str (see :ref:`scoring_parameter`) or a scorer callable object / function with
signature ``scorer(estimator, X, y)``.
exploit_incremental_learning : bool, default=False
If the estimator supports incremental learning, this will be
Expand Down Expand Up @@ -1849,7 +1856,22 @@ def learning_curve(
fit_params : dict, default=None
Parameters to pass to the fit method of the estimator.
.. versionadded:: 0.24
.. deprecated:: 1.6
This parameter is deprecated and will be removed in version 1.6. Use
``params`` instead.
params : dict, default=None
Parameters to pass to the `fit` method of the estimator and to the scorer.
- If `enable_metadata_routing=False` (default):
Parameters directly passed to the `fit` method of the estimator.
- If `enable_metadata_routing=True`:
Parameters safely routed to the `fit` method of the estimator.
See :ref:`Metadata Routing User Guide <metadata_routing>` for more
details.
.. versionadded:: 1.6
Returns
-------
Expand Down Expand Up @@ -1903,14 +1925,69 @@ def learning_curve(
"An estimator must support the partial_fit interface "
"to exploit incremental learning"
)

params = _check_params_groups_deprecation(fit_params, params, groups, "1.8")

X, y, groups = indexable(X, y, groups)

cv = check_cv(cv, y, classifier=is_classifier(estimator))
# Store it as list as we will be iterating over the list multiple times
cv_iter = list(cv.split(X, y, groups))

scorer = check_scoring(estimator, scoring=scoring)

if _routing_enabled():
router = (
MetadataRouter(owner="learning_curve")
.add(
estimator=estimator,
# TODO(SLEP6): also pass metadata to the predict method for
# scoring?
method_mapping=MethodMapping()
.add(caller="fit", callee="fit")
.add(caller="fit", callee="partial_fit"),
)
.add(
splitter=cv,
method_mapping=MethodMapping().add(caller="fit", callee="split"),
)
.add(
scorer=scorer,
method_mapping=MethodMapping().add(caller="fit", callee="score"),
)
)

try:
routed_params = process_routing(router, "fit", **params)
except UnsetMetadataPassedError as e:
# The default exception would mention `fit` since in the above
# `process_routing` code, we pass `fit` as the caller. However,
# the user is not calling `fit` directly, so we change the message
# to make it more suitable for this case.
unrequested_params = sorted(e.unrequested_params)
raise UnsetMetadataPassedError(
message=(
f"{unrequested_params} are passed to `learning_curve` but are not"
" explicitly set as requested or not requested for learning_curve's"
f" estimator: {estimator.__class__.__name__}. Call"
" `.set_fit_request({{metadata}}=True)` on the estimator for"
f" each metadata in {unrequested_params} that you"
" want to use and `metadata=False` for not using it. See the"
" Metadata Routing User guide"
" <https://scikit-learn.org/stable/metadata_routing.html> for more"
" information."
),
unrequested_params=e.unrequested_params,
routed_params=e.routed_params,
)

else:
routed_params = Bunch()
routed_params.estimator = Bunch(fit=params, partial_fit=params)
routed_params.splitter = Bunch(split={"groups": groups})
routed_params.scorer = Bunch(score={})

# Store cv as list as we will be iterating over the list multiple times
cv_iter = list(cv.split(X, y, **routed_params.splitter.split))

n_max_training_samples = len(cv_iter[0][0])
# Because the lengths of folds can be significantly different, it is
# not guaranteed that we use all of the available training data when we
Expand Down Expand Up @@ -1940,7 +2017,8 @@ def learning_curve(
scorer,
return_times,
error_score=error_score,
fit_params=fit_params,
fit_params=routed_params.estimator.partial_fit,
score_params=routed_params.scorer.score,
)
for train, test in cv_iter
)
Expand All @@ -1961,9 +2039,8 @@ def learning_curve(
test=test,
verbose=verbose,
parameters=None,
fit_params=fit_params,
# TODO(SLEP6): support score params here
score_params=None,
fit_params=routed_params.estimator.fit,
score_params=routed_params.scorer.score,
return_train_score=True,
error_score=error_score,
return_times=return_times,
Expand Down Expand Up @@ -2069,6 +2146,7 @@ def _incremental_fit_estimator(
return_times,
error_score,
fit_params,
score_params,
):
"""Train estimator on training subsets incrementally and compute scores."""
train_scores, test_scores, fit_times, score_times = [], [], [], []
Expand All @@ -2079,6 +2157,9 @@ def _incremental_fit_estimator(
partial_fit_func = partial(estimator.partial_fit, **fit_params)
else:
partial_fit_func = partial(estimator.partial_fit, classes=classes, **fit_params)
score_params = score_params if score_params is not None else {}
score_params_train = _check_method_params(X, params=score_params, indices=train)
score_params_test = _check_method_params(X, params=score_params, indices=test)

for n_train_samples, partial_train in partitions:
train_subset = train[:n_train_samples]
Expand All @@ -2095,14 +2176,13 @@ def _incremental_fit_estimator(

start_score = time.time()

# TODO(SLEP6): support score params in the following two calls
test_scores.append(
_score(
estimator,
X_test,
y_test,
scorer,
score_params=None,
score_params=score_params_test,
error_score=error_score,
)
)
Expand All @@ -2112,7 +2192,7 @@ def _incremental_fit_estimator(
X_train,
y_train,
scorer,
score_params=None,
score_params=score_params_train,
error_score=error_score,
)
)
Expand Down Expand Up @@ -2220,9 +2300,8 @@ def validation_curve(
``cv`` default value if None changed from 3-fold to 5-fold.
scoring : str or callable, default=None
A str (see model evaluation documentation) or
a scorer callable object / function with signature
``scorer(estimator, X, y)``.
A str (see :ref:`scoring_parameter`) or a scorer callable object / function with
signature ``scorer(estimator, X, y)``.
n_jobs : int, default=None
Number of jobs to run in parallel. Training the estimator and computing
Expand Down

0 comments on commit 77fc72c

Please sign in to comment.