From 5fbb2f2ea0e2d59e9102e2a2ac7c854691561f59 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 28 Oct 2020 18:04:41 +0800 Subject: [PATCH] Fix CV. --- python-package/xgboost/callback.py | 12 ++++++++++-- python-package/xgboost/training.py | 6 ++++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index 3588d62dc368..ddd9710ea9be 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -365,14 +365,22 @@ def before_training(self, model): '''Function called before training.''' for c in self.callbacks: model = c.before_training(model=model) - assert isinstance(model, Booster), 'before_training should return the Booster' + msg = 'before_training should return the model' + if self.is_cv: + assert isinstance(model.cvfolds, list), msg + else: + assert isinstance(model, Booster), msg return model def after_training(self, model): '''Function called after training.''' for c in self.callbacks: model = c.after_training(model=model) - assert isinstance(model, Booster), 'after_training should return the Booster' + msg = 'after_training should return the model' + if self.is_cv: + assert isinstance(model.cvfolds, list), msg + else: + assert isinstance(model, Booster), msg return model def before_iteration(self, model, epoch, dtrain, evals): diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index ef280dba4786..7ca5922905dd 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -493,9 +493,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None verbose_eval, early_stopping_rounds, maximize, 0, num_boost_round, feval, None, callbacks, show_stdv=show_stdv, cvfolds=cvfolds) - callbacks.before_training(cvfolds) - booster = _PackedBooster(cvfolds) + callbacks.before_training(booster) for i in range(num_boost_round): if callbacks.before_iteration(booster, i, dtrain, None): @@ -522,4 +521,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None results = pd.DataFrame.from_dict(results) except ImportError: pass + + callbacks.after_training(booster) + return results