Skip to content

Commit

Permalink
TST Relax test_gradient_boosting_early_stopping (scikit-learn#24541)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjerphan authored and glemaitre committed Oct 31, 2022
1 parent 5e3e3ca commit 683544d
Showing 1 changed file with 28 additions and 17 deletions.
45 changes: 28 additions & 17 deletions sklearn/ensemble/tests/test_gradient_boosting.py
Expand Up @@ -1102,39 +1102,49 @@ def test_sparse_input(EstimatorClass, sparse_matrix):
assert_array_almost_equal(res_sparse, res)


@pytest.mark.parametrize(
    "GradientBoostingEstimator", [GradientBoostingClassifier, GradientBoostingRegressor]
)
def test_gradient_boosting_early_stopping(GradientBoostingEstimator):
    """Check that early stopping (`n_iter_no_change`) behaves as expected.

    Rather than pinning exact iteration counts (which are brittle across
    platforms and library versions), empirically check that a smaller `tol`
    lets training run for more iterations, and that both fits stop before
    exhausting the `n_estimators` budget while still reaching a reasonable
    test score.
    """
    X, y = make_classification(n_samples=1000, random_state=0)
    n_estimators = 1000

    # Identical configuration except for the early-stopping tolerance:
    # a large tol should trigger early stopping sooner than a small tol.
    gb_large_tol = GradientBoostingEstimator(
        n_estimators=n_estimators,
        n_iter_no_change=10,
        learning_rate=0.1,
        max_depth=3,
        random_state=42,
        tol=1e-1,
    )

    gb_small_tol = GradientBoostingEstimator(
        n_estimators=n_estimators,
        n_iter_no_change=10,
        learning_rate=0.1,
        max_depth=3,
        random_state=42,
        tol=1e-3,
    )

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)

    gb_large_tol.fit(X_train, y_train)
    gb_small_tol.fit(X_train, y_train)

    # Early stopping must have kicked in for both, and the stricter
    # tolerance must have trained strictly more estimators.
    assert gb_large_tol.n_estimators_ < gb_small_tol.n_estimators_ < n_estimators

    # Sanity check: both early-stopped models still generalize reasonably.
    assert gb_large_tol.score(X_test, y_test) > 0.7
    assert gb_small_tol.score(X_test, y_test) > 0.7


def test_gradient_boosting_without_early_stopping():
# When early stopping is not used, the number of trained estimators
# must be the one specified.
X, y = make_classification(n_samples=1000, random_state=0)

# Without early stopping
gbc = GradientBoostingClassifier(
n_estimators=50, learning_rate=0.1, max_depth=3, random_state=42
)
Expand All @@ -1144,6 +1154,7 @@ def test_gradient_boosting_early_stopping():
)
gbr.fit(X, y)

# The number of trained estimators must be the one specified.
assert gbc.n_estimators_ == 50
assert gbr.n_estimators_ == 30

Expand Down

0 comments on commit 683544d

Please sign in to comment.