From 152e622af3989a7a17d4d2bab7464aeb984952aa Mon Sep 17 00:00:00 2001 From: Tim Head Date: Fri, 25 Nov 2022 13:49:11 +0100 Subject: [PATCH] Small fixes --- .../_hist_gradient_boosting/gradient_boosting.py | 10 +++++++++- .../tests/test_gradient_boosting.py | 8 ++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py index 66bd2b1f11d65..ba44532077bc1 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py @@ -93,7 +93,12 @@ class BaseHistGradientBoosting(BaseEstimator, ABC): "min_samples_leaf": [Interval(Integral, 1, None, closed="left")], "l2_regularization": [Interval(Real, 0, None, closed="left")], "monotonic_cst": ["array-like", dict, None], - "interaction_cst": [list, tuple, StrOptions({"pairwise", "no_interactions"}), None], + "interaction_cst": [ + list, + tuple, + StrOptions({"pairwise", "no_interactions"}), + None, + ], "n_iter_no_change": [Interval(Integral, 1, None, closed="left")], "validation_fraction": [ Interval(Real, 0, 1, closed="neither"), @@ -295,6 +300,9 @@ def _check_interaction_cst(self, n_features): elif self.interaction_cst == "pairwise": interaction_cst = itertools.combinations(range(n_features), 2) + else: + interaction_cst = self.interaction_cst + try: constraints = [set(group) for group in interaction_cst] except TypeError: diff --git a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py index 6d8e01500a550..33c83de0959f7 100644 --- a/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/_hist_gradient_boosting/tests/test_gradient_boosting.py @@ -1209,9 +1209,13 @@ def test_check_interaction_cst(interaction_cst, n_features, result): for combination in itertools.product( (HistGradientBoostingRegressor, HistGradientBoostingClassifier), ( - ("no interactions", None), + ("no_interactions", None), ("pairwise", None), - ("pairwiseS", "not a valid interaction constraint"), + ( + "pairwiseS", + "a str among {'pairwise', 'no_interactions'} or None. Got" + " 'pairwiseS' instead.", + ), ), ) ],