Remove params 'evals_result' and 'early_stopping_rounds' in lightgbm version > 3.3.1 #5206

Merged · 5 commits · Dec 31, 2021
30 changes: 24 additions & 6 deletions mlflow/lightgbm.py
@@ -25,6 +25,7 @@
 import logging
 import functools
 from copy import deepcopy
+from packaging.version import Version

 import mlflow
 from mlflow import pyfunc
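
Note: the new packaging.version import matters because plain string comparison orders versions lexicographically and gets multi-digit components wrong. A quick standalone illustration (not part of the diff):

from packaging.version import Version

# Lexicographic string comparison misorders versions:
assert "3.10.0" <= "3.3.1"  # True, but semantically wrong
# Version compares release components numerically:
assert not (Version("3.10.0") <= Version("3.3.1"))
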
@@ -323,7 +324,8 @@ def autolog(

     - parameters specified in `lightgbm.train`_.
     - metrics on each iteration (if ``valid_sets`` specified).
-    - metrics at the best iteration (if ``early_stopping_rounds`` specified).
+    - metrics at the best iteration (if ``early_stopping_rounds`` specified or ``early_stopping``
+      callback is set).
     - feature importance (both "split" and "gain") as JSON files and plots.
     - trained model, including:
         - an example of valid input.
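
For orientation, a minimal sketch of the documented behavior on LightGBM > 3.3.1 (standalone example, not part of the diff; the toy data and parameter values are illustrative):

import lightgbm as lgb
import mlflow
import numpy as np

rng = np.random.default_rng(0)
train_set = lgb.Dataset(rng.random((100, 4)), label=rng.random(100))
valid_set = lgb.Dataset(rng.random((20, 4)), label=rng.random(20))

mlflow.lightgbm.autolog()
# Early stopping is requested through a callback rather than the
# removed `early_stopping_rounds` keyword argument.
model = lgb.train(
    {"objective": "regression"},
    train_set,
    num_boost_round=10,
    valid_sets=[train_set, valid_set],
    valid_names=["train", "valid"],
    callbacks=[lgb.early_stopping(5)],
)
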
@@ -447,10 +449,13 @@ def log_feature_importance_plot(features, importance, importance_type):
             "fobj",
             "feval",
             "init_model",
-            "evals_result",
             "learning_rates",
             "callbacks",
         ]
+        if Version(lightgbm.__version__) <= Version("3.3.1"):
+            # The parameter `evals_result` in `lightgbm.train` is removed in this PR:
+            # https://github.com/microsoft/LightGBM/pull/4882
+            unlogged_params.append("evals_result")

         params_to_log_for_fn = get_mlflow_run_params_for_fn_args(
             original, args, kwargs, unlogged_params
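
The gate above mirrors the call-site migration: on LightGBM <= 3.3.1, per-iteration metrics come back through the `evals_result` argument, while newer versions capture the same history with the `record_evaluation` callback. A sketch of the two call styles (the helper name and toy arguments are hypothetical):

import lightgbm as lgb
from packaging.version import Version

def train_with_eval_history(params, train_set, valid_sets, valid_names):
    """Run lgb.train and return (booster, per-iteration eval history)."""
    evals_result = {}
    if Version(lgb.__version__) <= Version("3.3.1"):
        # Old API: lightgbm fills the dict passed as `evals_result`.
        booster = lgb.train(
            params, train_set, num_boost_round=10,
            valid_sets=valid_sets, valid_names=valid_names,
            evals_result=evals_result,
        )
    else:
        # New API: a callback records the identical structure.
        booster = lgb.train(
            params, train_set, num_boost_round=10,
            valid_sets=valid_sets, valid_names=valid_names,
            callbacks=[lgb.record_evaluation(evals_result)],
        )
    return booster, evals_result
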
@@ -484,10 +489,23 @@ def log_feature_importance_plot(features, importance, importance_type):

         # If early_stopping_rounds is present, logging metrics at the best iteration
         # as extra metrics with the max step + 1.
-        early_stopping_index = all_arg_names.index("early_stopping_rounds")
-        early_stopping = (
-            num_pos_args >= early_stopping_index + 1 or "early_stopping_rounds" in kwargs
-        )
+        if Version(lightgbm.__version__) <= Version("3.3.1"):
+            # The parameter `early_stopping_rounds` in `lightgbm.train` is removed in this PR:
+            # https://github.com/microsoft/LightGBM/pull/4908
+            early_stopping_index = all_arg_names.index("early_stopping_rounds")
+            early_stopping = (
+                num_pos_args >= early_stopping_index + 1 or "early_stopping_rounds" in kwargs
+            )
+        else:
+            early_stopping = False
+            if "callbacks" in kwargs and kwargs["callbacks"] is not None:
+                for cb in kwargs["callbacks"]:
+                    if (
+                        hasattr(cb, "__qualname__")
+                        and cb.__qualname__ == "early_stopping.<locals>._callback"
Collaborator: This check is a little tricky; is there a better way? I think you want to check whether the callback is lgb.early_stopping(5). Can we directly compare the function object with it?

Collaborator (author): Directly comparing function objects does not work:

lightgbm.early_stopping(stopping_rounds=1) is lightgbm.early_stopping(stopping_rounds=1)  # False
lightgbm.early_stopping(stopping_rounds=1) == lightgbm.early_stopping(stopping_rounds=1)  # False
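
To illustrate the reply: lightgbm.early_stopping is a factory that builds a new closure on each call, so neither identity nor equality can match two of its results; what stays stable is the inner function's __qualname__, which is what the diff keys on. A sketch (assumes the closure-based callback of LightGBM 3.x):

import lightgbm as lgb

cb1 = lgb.early_stopping(stopping_rounds=1)
cb2 = lgb.early_stopping(stopping_rounds=1)

print(cb1 is cb2)  # False -- each call returns a fresh closure
print(cb1 == cb2)  # False -- closures define no structural equality

# The qualified name of the inner function is the same across calls.
print(cb1.__qualname__)  # early_stopping.<locals>._callback
print(cb1.__qualname__ == cb2.__qualname__)  # True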

+                    ):
+                        early_stopping = True
+                        break
         if early_stopping:
             extra_step = len(eval_results)
             autologging_client.log_metrics(
175 changes: 123 additions & 52 deletions tests/lightgbm/test_lightgbm_autolog.py
@@ -102,10 +102,13 @@ def test_lgb_autolog_logs_default_params(bst_params, train_set):
         "fobj",
         "feval",
         "init_model",
-        "evals_result",
         "learning_rates",
         "callbacks",
     ]
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        # The parameter `evals_result` in `lightgbm.train` is removed in this PR:
+        # https://github.com/microsoft/LightGBM/pull/4882
+        unlogged_params.append("evals_result")

     for param in unlogged_params:
         assert param not in params
@@ -116,12 +119,14 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
     mlflow.lightgbm.autolog()
     expected_params = {
         "num_boost_round": 10,
-        "early_stopping_rounds": 5,
     }
     if Version(lgb.__version__) <= Version("3.3.1"):
         # The parameter `verbose_eval` in `lightgbm.train` is removed in this PR:
         # https://github.com/microsoft/LightGBM/pull/4878
+        # The parameter `early_stopping_rounds` in `lightgbm.train` is removed in this PR:
+        # https://github.com/microsoft/LightGBM/pull/4908
         expected_params["verbose_eval"] = False
+        expected_params["early_stopping_rounds"] = 5
     lgb.train(bst_params, train_set, valid_sets=[train_set], **expected_params)
     run = get_latest_run()
     params = run.data.params
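
The tests read logged parameters back through the tracking API; outside the get_latest_run() test helper, an equivalent lookup might be (assumes a default experiment with at least one autologged run; MLflow stores params as strings):

import mlflow

client = mlflow.tracking.MlflowClient()
experiment = client.get_experiment_by_name("Default")
run = client.search_runs(
    [experiment.experiment_id],
    order_by=["attributes.start_time DESC"],
    max_results=1,
)[0]
assert run.data.params["num_boost_round"] == "10"
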
@@ -140,10 +145,13 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
         "fobj",
         "feval",
         "init_model",
-        "evals_result",
         "learning_rates",
         "callbacks",
     ]
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        # The parameter `evals_result` in `lightgbm.train` is removed in this PR:
+        # https://github.com/microsoft/LightGBM/pull/4882
+        unlogged_params.append("evals_result")

     for param in unlogged_params:
         assert param not in params
@@ -153,14 +161,24 @@ def test_lgb_autolog_logs_specified_params(bst_params, train_set):
 def test_lgb_autolog_logs_metrics_with_validation_data(bst_params, train_set):
     mlflow.lightgbm.autolog()
     evals_result = {}
-    lgb.train(
-        bst_params,
-        train_set,
-        num_boost_round=10,
-        valid_sets=[train_set],
-        valid_names=["train"],
-        evals_result=evals_result,
-    )
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        lgb.train(
+            bst_params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=[train_set],
+            valid_names=["train"],
+            evals_result=evals_result,
+        )
+    else:
+        lgb.train(
+            bst_params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=[train_set],
+            valid_names=["train"],
+            callbacks=[lgb.record_evaluation(evals_result)],
+        )
     run = get_latest_run()
     data = run.data
     client = mlflow.tracking.MlflowClient()
@@ -179,14 +197,24 @@ def test_lgb_autolog_logs_metrics_with_multi_validation_data(bst_params, train_set):
     # To avoid that, create a new Dataset object.
     valid_sets = [train_set, lgb.Dataset(train_set.data)]
     valid_names = ["train", "valid"]
-    lgb.train(
-        bst_params,
-        train_set,
-        num_boost_round=10,
-        valid_sets=valid_sets,
-        valid_names=valid_names,
-        evals_result=evals_result,
-    )
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        lgb.train(
+            bst_params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            evals_result=evals_result,
+        )
+    else:
+        lgb.train(
+            bst_params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[lgb.record_evaluation(evals_result)],
+        )
     run = get_latest_run()
     data = run.data
     client = mlflow.tracking.MlflowClient()
@@ -206,14 +234,24 @@ def test_lgb_autolog_logs_metrics_with_multi_metrics(bst_params, train_set):
     params.update(bst_params)
     valid_sets = [train_set]
     valid_names = ["train"]
-    lgb.train(
-        params,
-        train_set,
-        num_boost_round=10,
-        valid_sets=valid_sets,
-        valid_names=valid_names,
-        evals_result=evals_result,
-    )
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            evals_result=evals_result,
+        )
+    else:
+        lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[lgb.record_evaluation(evals_result)],
+        )
     run = get_latest_run()
     data = run.data
     client = mlflow.tracking.MlflowClient()
@@ -233,14 +271,24 @@ def test_lgb_autolog_logs_metrics_with_multi_validation_data_and_metrics(bst_params, train_set):
     params.update(bst_params)
     valid_sets = [train_set, lgb.Dataset(train_set.data)]
     valid_names = ["train", "valid"]
-    lgb.train(
-        params,
-        train_set,
-        num_boost_round=10,
-        valid_sets=valid_sets,
-        valid_names=valid_names,
-        evals_result=evals_result,
-    )
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            evals_result=evals_result,
+        )
+    else:
+        lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[lgb.record_evaluation(evals_result)],
+        )
     run = get_latest_run()
     data = run.data
     client = mlflow.tracking.MlflowClient()
@@ -279,14 +327,24 @@ def record_metrics_side_effect(self, metrics, step=None):
     params.update(bst_params)
     valid_sets = [train_set, lgb.Dataset(train_set.data)]
     valid_names = ["train", "valid"]
-    lgb.train(
-        params,
-        train_set,
-        num_boost_round=10,
-        valid_sets=valid_sets,
-        valid_names=valid_names,
-        evals_result=evals_result,
-    )
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            evals_result=evals_result,
+        )
+    else:
+        lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[lgb.record_evaluation(evals_result)],
+        )

     run = get_latest_run()
     original_metrics = run.data.metrics
@@ -307,15 +365,28 @@ def test_lgb_autolog_logs_metrics_with_early_stopping(bst_params, train_set):
     params.update(bst_params)
     valid_sets = [train_set, lgb.Dataset(train_set.data)]
    valid_names = ["train", "valid"]
-    model = lgb.train(
-        params,
-        train_set,
-        num_boost_round=10,
-        early_stopping_rounds=5,
-        valid_sets=valid_sets,
-        valid_names=valid_names,
-        evals_result=evals_result,
-    )
+    if Version(lgb.__version__) <= Version("3.3.1"):
+        model = lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            early_stopping_rounds=5,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            evals_result=evals_result,
+        )
+    else:
+        model = lgb.train(
+            params,
+            train_set,
+            num_boost_round=10,
+            valid_sets=valid_sets,
+            valid_names=valid_names,
+            callbacks=[
+                lgb.record_evaluation(evals_result),
+                lgb.early_stopping(5),
+            ],
+        )
     run = get_latest_run()
     data = run.data
     client = mlflow.tracking.MlflowClient()