Skip to content

Commit

Permalink
Fix handling of print period in EvaluationMonitor (dmlc#6499)
Browse files Browse the repository at this point in the history
Co-authored-by: Kirill Shvets <kirill.shvets@intel.com>
  • Loading branch information
2 people authored and trivialfis committed Dec 20, 2020
1 parent bce7ca3 commit 230f28c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 11 deletions.
2 changes: 1 addition & 1 deletion python-package/xgboost/callback.py
Expand Up @@ -622,7 +622,7 @@ def after_iteration(self, model, epoch, evals_log):
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'

if (epoch % self.period) != 0 or self.period == 1:
if (epoch % self.period) == 0 or self.period == 1:
rabit.tracker_print(msg)
self._latest = None
else:
Expand Down
25 changes: 15 additions & 10 deletions tests/python/test_callback.py
Expand Up @@ -33,15 +33,18 @@ def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
verbose_eval=verbose_eval)
output: str = out.getvalue().strip()

pos = 0
msg = 'Train-error'
for i in range(rounds // int(verbose_eval)):
pos = output.find('Train-error', pos)
assert pos != -1
pos += len(msg)

assert output.find('Train-error', pos) == -1

if int(verbose_eval) == 1:
# Should print each iteration info
assert len(output.split('\n')) == rounds
elif int(verbose_eval) > rounds:
# Should print first and latest iteration info
assert len(output.split('\n')) == 2
else:
# Should print info by each period additionaly to first and latest iteration
num_periods = rounds // int(verbose_eval)
# Extra information is required for latest iteration
is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1)
assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required)

def test_evaluation_monitor(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
Expand All @@ -57,8 +60,10 @@ def test_evaluation_monitor(self):
assert len(evals_result['Train']['error']) == rounds
assert len(evals_result['Valid']['error']) == rounds

self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, True)
self.run_evaluation_monitor(D_train, D_valid, rounds, 2)
self.run_evaluation_monitor(D_train, D_valid, rounds, 4)
self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1)

def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
Expand Down

0 comments on commit 230f28c

Please sign in to comment.