Skip to content

Commit

Permalink
Small fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
betatim committed Nov 25, 2022
1 parent b990599 commit 152e622
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
10 changes: 9 additions & 1 deletion sklearn/ensemble/_hist_gradient_boosting/gradient_boosting.py
Expand Up @@ -93,7 +93,12 @@ class BaseHistGradientBoosting(BaseEstimator, ABC):
"min_samples_leaf": [Interval(Integral, 1, None, closed="left")],
"l2_regularization": [Interval(Real, 0, None, closed="left")],
"monotonic_cst": ["array-like", dict, None],
"interaction_cst": [list, tuple, StrOptions({"pairwise", "no_interactions"}), None],
"interaction_cst": [
list,
tuple,
StrOptions({"pairwise", "no_interactions"}),
None,
],
"n_iter_no_change": [Interval(Integral, 1, None, closed="left")],
"validation_fraction": [
Interval(Real, 0, 1, closed="neither"),
Expand Down Expand Up @@ -295,6 +300,9 @@ def _check_interaction_cst(self, n_features):
elif self.interaction_cst == "pairwise":
interaction_cst = itertools.combinations(range(n_features), 2)

else:
interaction_cst = self.interaction_cst

try:
constraints = [set(group) for group in interaction_cst]
except TypeError:
Expand Down
Expand Up @@ -1209,9 +1209,13 @@ def test_check_interaction_cst(interaction_cst, n_features, result):
for combination in itertools.product(
(HistGradientBoostingRegressor, HistGradientBoostingClassifier),
(
("no interactions", None),
("no_interactions", None),
("pairwise", None),
("pairwiseS", "not a valid interaction constraint"),
(
"pairwiseS",
"a str among {'pairwise', 'no_interactions'} or None. Got"
" 'pairwiseS' instead.",
),
),
)
],
Expand Down

0 comments on commit 152e622

Please sign in to comment.