Remove params 'evals_result' and 'early_stopping_rounds' in lightgbm version > 3.3.1 #5206

Merged (5 commits, Dec 31, 2021)
16 changes: 9 additions & 7 deletions mlflow/lightgbm.py
@@ -25,6 +25,7 @@
import logging
import functools
from copy import deepcopy
from packaging.version import Version

import mlflow
from mlflow import pyfunc
@@ -323,7 +324,8 @@ def autolog(

- parameters specified in `lightgbm.train`_.
- metrics on each iteration (if ``valid_sets`` specified).
- metrics at the best iteration (if ``early_stopping_rounds`` specified).
- metrics at the best iteration (if ``early_stopping_rounds`` specified or ``early_stopping``
callback is set).
- feature importance (both "split" and "gain") as JSON files and plots.
- trained model, including:
- an example of valid input.
@@ -447,10 +449,13 @@ def log_feature_importance_plot(features, importance, importance_type):
"fobj",
"feval",
"init_model",
"evals_result",
"learning_rates",
"callbacks",
]
if Version(lightgbm.__version__) <= Version("3.3.1"):
# The parameter `evals_result` in `lightgbm.train` is removed in this PR:
# https://github.com/microsoft/LightGBM/pull/4882
unlogged_params.append("evals_result")

params_to_log_for_fn = get_mlflow_run_params_for_fn_args(
original, args, kwargs, unlogged_params
@@ -482,12 +487,9 @@ def log_feature_importance_plot(features, importance, importance_type):
# training model
model = original(*args, **kwargs)

# If early_stopping_rounds is present, logging metrics at the best iteration
# If early stopping is activated, logging metrics at the best iteration
# as extra metrics with the max step + 1.
early_stopping_index = all_arg_names.index("early_stopping_rounds")
early_stopping = (
num_pos_args >= early_stopping_index + 1 or "early_stopping_rounds" in kwargs
)
early_stopping = model.best_iteration > 0
Review comment from the PR author (collaborator):
`model.best_iteration > 0` can be used to determine whether early stopping was activated.

`model.best_iteration` is initialized to 0 and is assigned a positive int when `EarlyStopException` is raised. This covers every way early stopping can be activated: the `early_stopping_rounds` param, the `lightgbm.early_stopping` callback, or even a user-defined early-stopping callback that raises `EarlyStopException`.
`model.best_iteration` remains 0 when early stopping is not activated.

Reference: https://github.com/microsoft/LightGBM/blob/ce486e5b45a6f5e67743e14765ed139ff8d532e5/python-package/lightgbm/engine.py#L226-L263
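
To illustrate the check, a minimal runnable sketch on synthetic data (not from this PR; the dataset and parameter values are made up, and early stopping is only expected, not guaranteed, to fire on random labels):

```python
import numpy as np
import lightgbm as lgb

rng = np.random.default_rng(0)
train = lgb.Dataset(rng.random((200, 5)), label=rng.random(200))
# Unrelated random labels, so the validation metric stops improving quickly.
valid = lgb.Dataset(rng.random((100, 5)), label=rng.random(100), reference=train)

params = {"objective": "regression", "verbose": -1}

# No early stopping configured: best_iteration is never set to a positive int.
bst = lgb.train(params, train, num_boost_round=10)
print(bst.best_iteration > 0)  # False

# With the early_stopping callback, EarlyStopException is raised internally
# and best_iteration is assigned a positive int.
bst = lgb.train(
    params,
    train,
    num_boost_round=100,
    valid_sets=[valid],
    callbacks=[lgb.early_stopping(5)],
)
print(bst.best_iteration > 0)  # True, provided early stopping fired
```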

if early_stopping:
extra_step = len(eval_results)
autologging_client.log_metrics(
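For reference, here is what a user call site looks like on LightGBM > 3.3.1, where the removed `evals_result` and `early_stopping_rounds` keyword arguments are replaced by callbacks and autologging detects early stopping through `model.best_iteration`. A minimal sketch with made-up data and parameter values, not part of this diff:

```python
import lightgbm as lgb
import mlflow
import numpy as np

mlflow.lightgbm.autolog()

rng = np.random.default_rng(0)
train = lgb.Dataset(rng.random((200, 5)), label=rng.random(200))
valid = lgb.Dataset(rng.random((100, 5)), label=rng.random(100), reference=train)

with mlflow.start_run():
    evals_result = {}
    model = lgb.train(
        {"objective": "regression", "verbose": -1},
        train,
        num_boost_round=100,
        valid_sets=[train, valid],
        valid_names=["train", "valid"],
        callbacks=[
            # Replacement for the removed `evals_result` kwarg.
            lgb.record_evaluation(evals_result),
            # Replacement for the removed `early_stopping_rounds` kwarg.
            lgb.early_stopping(5),
        ],
    )
```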
175 changes: 123 additions & 52 deletions tests/lightgbm/test_lightgbm_autolog.py
@@ -102,10 +102,13 @@ def test_lgb_autolog_logs_default_params(bst_params, train_set):
"fobj",
"feval",
"init_model",
"evals_result",
"learning_rates",
"callbacks",
]
if Version(lgb.__version__) <= Version("3.3.1"):
# The parameter `evals_result` in `lightgbm.train` is removed in this PR:
# https://github.com/microsoft/LightGBM/pull/4882
unlogged_params.append("evals_result")

for param in unlogged_params:
assert param not in params
@@ -116,12 +119,14 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
mlflow.lightgbm.autolog()
expected_params = {
"num_boost_round": 10,
"early_stopping_rounds": 5,
}
if Version(lgb.__version__) <= Version("3.3.1"):
# The parameter `verbose_eval` in `lightgbm.train` is removed in this PR:
# https://github.com/microsoft/LightGBM/pull/4878
# The parameter `early_stopping_rounds` in `lightgbm.train` is removed in this PR:
# https://github.com/microsoft/LightGBM/pull/4908
expected_params["verbose_eval"] = False
expected_params["early_stopping_rounds"] = 5
lgb.train(bst_params, train_set, valid_sets=[train_set], **expected_params)
run = get_latest_run()
params = run.data.params
@@ -140,10 +145,13 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
"fobj",
"feval",
"init_model",
"evals_result",
"learning_rates",
"callbacks",
]
if Version(lgb.__version__) <= Version("3.3.1"):
# The parameter `evals_result` in `lightgbm.train` is removed in this PR:
# https://github.com/microsoft/LightGBM/pull/4882
unlogged_params.append("evals_result")

for param in unlogged_params:
assert param not in params
@@ -153,14 +161,24 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
def test_lgb_autolog_logs_metrics_with_validation_data(bst_params, train_set):
mlflow.lightgbm.autolog()
evals_result = {}
lgb.train(
bst_params,
train_set,
num_boost_round=10,
valid_sets=[train_set],
valid_names=["train"],
evals_result=evals_result,
)
if Version(lgb.__version__) <= Version("3.3.1"):
lgb.train(
bst_params,
train_set,
num_boost_round=10,
valid_sets=[train_set],
valid_names=["train"],
evals_result=evals_result,
)
else:
lgb.train(
bst_params,
train_set,
num_boost_round=10,
valid_sets=[train_set],
valid_names=["train"],
callbacks=[lgb.record_evaluation(evals_result)],
)
run = get_latest_run()
data = run.data
client = mlflow.tracking.MlflowClient()
Expand All @@ -179,14 +197,24 @@ def test_lgb_autolog_logs_metrics_with_multi_validation_data(bst_params, train_s
# To avoid that, create a new Dataset object.
valid_sets = [train_set, lgb.Dataset(train_set.data)]
valid_names = ["train", "valid"]
lgb.train(
bst_params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
if Version(lgb.__version__) <= Version("3.3.1"):
lgb.train(
bst_params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
else:
lgb.train(
bst_params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=[lgb.record_evaluation(evals_result)],
)
run = get_latest_run()
data = run.data
client = mlflow.tracking.MlflowClient()
Expand All @@ -206,14 +234,24 @@ def test_lgb_autolog_logs_metrics_with_multi_metrics(bst_params, train_set):
params.update(bst_params)
valid_sets = [train_set]
valid_names = ["train"]
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
if Version(lgb.__version__) <= Version("3.3.1"):
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
else:
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=[lgb.record_evaluation(evals_result)],
)
run = get_latest_run()
data = run.data
client = mlflow.tracking.MlflowClient()
Expand All @@ -233,14 +271,24 @@ def test_lgb_autolog_logs_metrics_with_multi_validation_data_and_metrics(bst_par
params.update(bst_params)
valid_sets = [train_set, lgb.Dataset(train_set.data)]
valid_names = ["train", "valid"]
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
if Version(lgb.__version__) <= Version("3.3.1"):
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
else:
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=[lgb.record_evaluation(evals_result)],
)
run = get_latest_run()
data = run.data
client = mlflow.tracking.MlflowClient()
Expand Down Expand Up @@ -279,14 +327,24 @@ def record_metrics_side_effect(self, metrics, step=None):
params.update(bst_params)
valid_sets = [train_set, lgb.Dataset(train_set.data)]
valid_names = ["train", "valid"]
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
if Version(lgb.__version__) <= Version("3.3.1"):
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
else:
lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=[lgb.record_evaluation(evals_result)],
)

run = get_latest_run()
original_metrics = run.data.metrics
Expand All @@ -307,15 +365,28 @@ def test_lgb_autolog_logs_metrics_with_early_stopping(bst_params, train_set):
params.update(bst_params)
valid_sets = [train_set, lgb.Dataset(train_set.data)]
valid_names = ["train", "valid"]
model = lgb.train(
params,
train_set,
num_boost_round=10,
early_stopping_rounds=5,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
if Version(lgb.__version__) <= Version("3.3.1"):
model = lgb.train(
params,
train_set,
num_boost_round=10,
early_stopping_rounds=5,
valid_sets=valid_sets,
valid_names=valid_names,
evals_result=evals_result,
)
else:
model = lgb.train(
params,
train_set,
num_boost_round=10,
valid_sets=valid_sets,
valid_names=valid_names,
callbacks=[
lgb.record_evaluation(evals_result),
lgb.early_stopping(5),
],
)
run = get_latest_run()
data = run.data
client = mlflow.tracking.MlflowClient()
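One possible follow-up (a hypothetical sketch, not part of this PR): the version gate repeated across these tests could be factored into a small shared helper, so each test reduces to a single call.

```python
import lightgbm as lgb
from packaging.version import Version


def train_with_eval_recording(params, train_set, evals_result, stopping_rounds=None, **kwargs):
    # Hypothetical test helper: record eval results (and optionally enable
    # early stopping) in a way that works on both old and new LightGBM.
    if Version(lgb.__version__) <= Version("3.3.1"):
        if stopping_rounds is not None:
            kwargs["early_stopping_rounds"] = stopping_rounds
        return lgb.train(params, train_set, evals_result=evals_result, **kwargs)
    callbacks = [lgb.record_evaluation(evals_result)]
    if stopping_rounds is not None:
        callbacks.append(lgb.early_stopping(stopping_rounds))
    return lgb.train(params, train_set, callbacks=callbacks, **kwargs)
```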