Skip to content

Commit

Permalink
Fix Python callback. (#6320)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 29, 2020
1 parent b181a88 commit 6ff331b
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
4 changes: 3 additions & 1 deletion python-package/xgboost/training.py
Expand Up @@ -3,6 +3,8 @@
# pylint: disable=too-many-branches, too-many-statements
"""Training Library containing training routines."""
import warnings
import copy

import numpy as np
from .core import Booster, XGBoostError
from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold)
Expand Down Expand Up @@ -57,7 +59,7 @@ def _train_internal(params, dtrain,
evals_result=None, maximize=None,
verbose_eval=None, early_stopping_rounds=None):
"""internal training function"""
callbacks = [] if callbacks is None else callbacks
callbacks = [] if callbacks is None else copy.copy(callbacks)
evals = list(evals)
params = _configure_metrics(params.copy())

Expand Down
13 changes: 13 additions & 0 deletions tests/python/test_callback.py
Expand Up @@ -232,3 +232,16 @@ def test_check_point(self):
for i in range(1, 10):
assert os.path.exists(
os.path.join(tmpdir, 'model_' + str(i) + '.pkl'))

def test_callback_list(self):
X, y = tm.get_boston()
m = xgb.DMatrix(X, y)
callbacks = [xgb.callback.EarlyStopping(rounds=10)]
for i in range(4):
xgb.train({'objective': 'reg:squarederror',
'eval_metric': 'rmse'}, m,
evals=[(m, 'Train')],
num_boost_round=1,
verbose_eval=True,
callbacks=callbacks)
assert len(callbacks) == 1

0 comments on commit 6ff331b

Please sign in to comment.