Skip to content

Commit

Permalink
Fix period in evaluation monitor. (#6441)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis authored and hcho3 committed Dec 4, 2020
1 parent 8a0db29 commit a2c778e
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 16 deletions.
2 changes: 1 addition & 1 deletion python-package/xgboost/callback.py
Expand Up @@ -621,7 +621,7 @@ def after_iteration(self, model, epoch, evals_log):
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'

if (epoch % self.period) != 0:
if (epoch % self.period) != 0 or self.period == 1:
rabit.tracker_print(msg)
self._latest = None
else:
Expand Down
36 changes: 21 additions & 15 deletions tests/python/test_callback.py
Expand Up @@ -22,38 +22,44 @@ def setup_class(cls):
cls.X_valid = X[split:, ...]
cls.y_valid = y[split:, ...]

def test_evaluation_monitor(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval):
    """Train for ``rounds`` iterations with the given ``verbose_eval``
    period and check that the evaluation monitor printed exactly
    ``rounds // int(verbose_eval)`` log lines — no more, no fewer.

    ``verbose_eval`` may be an int period or ``True`` (``int(True) == 1``,
    i.e. a line is expected for every round).
    """
    evals_result = {}
    # Capture stdout so the monitor's printed lines can be counted.
    with tm.captured_output() as (out, err):
        xgb.train({'objective': 'binary:logistic',
                   'eval_metric': 'error'}, D_train,
                  evals=[(D_train, 'Train'), (D_valid, 'Valid')],
                  num_boost_round=rounds,
                  evals_result=evals_result,
                  verbose_eval=verbose_eval)
    output: str = out.getvalue().strip()

    # One 'Train-error' entry is expected per printed period.
    pos = 0
    msg = 'Train-error'
    for i in range(rounds // int(verbose_eval)):
        pos = output.find('Train-error', pos)
        assert pos != -1
        pos += len(msg)

    # Nothing is printed beyond the expected number of periods.
    assert output.find('Train-error', pos) == -1


def test_evaluation_monitor(self):
    """Evaluation history is recorded for every boosting round, and the
    printed monitor output respects the ``verbose_eval`` period."""
    D_train = xgb.DMatrix(self.X_train, self.y_train)
    D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
    rounds = 10
    evals_result = {}
    xgb.train({'objective': 'binary:logistic',
               'eval_metric': 'error'}, D_train,
              evals=[(D_train, 'Train'), (D_valid, 'Valid')],
              num_boost_round=rounds,
              evals_result=evals_result,
              verbose_eval=True)
    # Every round must contribute one entry per evaluation set.
    for name in ('Train', 'Valid'):
        assert len(evals_result[name]['error']) == rounds

    # Check the printed output for an integer period and for the
    # boolean default (presumably treated as a period of 1 — the
    # helper computes rounds // int(verbose_eval)).
    for period in (2, True):
        self.run_evaluation_monitor(D_train, D_valid, rounds, period)

def test_early_stopping(self):
D_train = xgb.DMatrix(self.X_train, self.y_train)
D_valid = xgb.DMatrix(self.X_valid, self.y_valid)
Expand Down

0 comments on commit a2c778e

Please sign in to comment.