Skip to content

Commit

Permalink
Fix CV.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Oct 30, 2020
1 parent 40b5562 commit 5cb2d5e
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
12 changes: 10 additions & 2 deletions python-package/xgboost/callback.py
Expand Up @@ -365,14 +365,22 @@ def before_training(self, model):
'''Function called before training.'''
for c in self.callbacks:
model = c.before_training(model=model)
assert isinstance(model, Booster), 'before_training should return the Booster'
msg = 'before_training should return the model'
if self.is_cv:
assert isinstance(model.cvfolds, list), msg
else:
assert isinstance(model, Booster), msg
return model

def after_training(self, model):
'''Function called after training.'''
for c in self.callbacks:
model = c.after_training(model=model)
assert isinstance(model, Booster), 'after_training should return the Booster'
msg = 'after_training should return the model'
if self.is_cv:
assert isinstance(model.cvfolds, list), msg
else:
assert isinstance(model, Booster), msg
return model

def before_iteration(self, model, epoch, dtrain, evals):
Expand Down
6 changes: 4 additions & 2 deletions python-package/xgboost/training.py
Expand Up @@ -493,9 +493,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
verbose_eval, early_stopping_rounds, maximize, 0,
num_boost_round, feval, None, callbacks,
show_stdv=show_stdv, cvfolds=cvfolds)
callbacks.before_training(cvfolds)

booster = _PackedBooster(cvfolds)
callbacks.before_training(booster)

for i in range(num_boost_round):
if callbacks.before_iteration(booster, i, dtrain, None):
Expand All @@ -522,4 +521,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
results = pd.DataFrame.from_dict(results)
except ImportError:
pass

callbacks.after_training(booster)

return results

0 comments on commit 5cb2d5e

Please sign in to comment.