From ec50033bfd87fcc32f6a4be665665514cd6d8561 Mon Sep 17 00:00:00 2001 From: fis Date: Thu, 17 Dec 2020 20:07:51 +0800 Subject: [PATCH] Fix. --- python-package/xgboost/callback.py | 4 +++- tests/python/test_callback.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index e15bf699febb..f3960c2c3ded 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -456,6 +456,7 @@ def __init__(self, learning_rates): def after_iteration(self, model, epoch, evals_log): model.set_param('learning_rate', self.learning_rates(epoch)) + return False # pylint: disable=too-many-instance-attributes @@ -565,7 +566,7 @@ def after_iteration(self, model: Booster, epoch, evals_log): def after_training(self, model: Booster): try: if self.save_best: - model = model[: int(model.attr('best_iteration'))] + model = model[: int(model.attr('best_iteration')) + 1] except XGBoostError as e: raise XGBoostError('`save_best` is not applicable to current booster') from e return model @@ -677,6 +678,7 @@ def after_iteration(self, model, epoch, evals_log): else: model.save_model(path) self._epoch += 1 + return False class LegacyCallbacks: diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index bdae94f87390..8d4778d35f04 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -148,7 +148,7 @@ def test_early_stopping_save_best_model(self): eval_metric=tm.eval_error_metric, callbacks=[early_stop]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') - assert len(dump) == booster.best_iteration + assert len(dump) == booster.best_iteration + 1 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True)