From 230f28c62a0f0a2874948c69641c41442dbd9555 Mon Sep 17 00:00:00 2001 From: ShvetsKS <33296480+ShvetsKS@users.noreply.github.com> Date: Tue, 15 Dec 2020 14:20:19 +0300 Subject: [PATCH] Fix handling of print period in EvaluationMonitor (#6499) Co-authored-by: Kirill Shvets --- python-package/xgboost/callback.py | 2 +- tests/python/test_callback.py | 25 +++++++++++++++---------- 2 files changed, 16 insertions(+), 11 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index f3960c2c3ded..c9cdb04eaaa1 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -622,7 +622,7 @@ def after_iteration(self, model, epoch, evals_log): msg += self._fmt_metric(data, metric_name, score, stdv) msg += '\n' - if (epoch % self.period) != 0 or self.period == 1: + if (epoch % self.period) == 0 or self.period == 1: rabit.tracker_print(msg) self._latest = None else: diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index 8d4778d35f04..b8f2c0785e27 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -33,15 +33,18 @@ def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval): verbose_eval=verbose_eval) output: str = out.getvalue().strip() - pos = 0 - msg = 'Train-error' - for i in range(rounds // int(verbose_eval)): - pos = output.find('Train-error', pos) - assert pos != -1 - pos += len(msg) - - assert output.find('Train-error', pos) == -1 - + if int(verbose_eval) == 1: + # Should print each iteration info + assert len(output.split('\n')) == rounds + elif int(verbose_eval) > rounds: + # Should print first and latest iteration info + assert len(output.split('\n')) == 2 + else: + # Should print info by each period additionaly to first and latest iteration + num_periods = rounds // int(verbose_eval) + # Extra information is required for latest iteration + is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1) + assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required) def test_evaluation_monitor(self): D_train = xgb.DMatrix(self.X_train, self.y_train) @@ -57,8 +60,10 @@ def test_evaluation_monitor(self): assert len(evals_result['Train']['error']) == rounds assert len(evals_result['Valid']['error']) == rounds - self.run_evaluation_monitor(D_train, D_valid, rounds, 2) self.run_evaluation_monitor(D_train, D_valid, rounds, True) + self.run_evaluation_monitor(D_train, D_valid, rounds, 2) + self.run_evaluation_monitor(D_train, D_valid, rounds, 4) + self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1) def test_early_stopping(self): D_train = xgb.DMatrix(self.X_train, self.y_train)