diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h index dcc02fae2748..d07ae685f77c 100644 --- a/include/xgboost/c_api.h +++ b/include/xgboost/c_api.h @@ -614,6 +614,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, int end_layer, int step, BoosterHandle *out); +/*! + * \brief Get number of boosted rounds from gradient booster. When process_type is + * update, this number might drop due to removed tree. + * \param handle Handle to booster. + * \param out Pointer to output integer. + * \return 0 when success, -1 when failure happens + */ +XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out); + /*! * \brief set parameters * \param handle handle diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h index 389593d7f098..9de6e04f20fb 100644 --- a/include/xgboost/gbm.h +++ b/include/xgboost/gbm.h @@ -79,6 +79,9 @@ class GradientBooster : public Model, public Configurable { virtual bool AllowLazyCheckPoint() const { return false; } + /*! \brief Return number of boosted rounds. + */ + virtual int32_t BoostedRounds() const = 0; /*! * \brief perform update to the model(boosting) * \param p_fmat feature matrix that provide access to features diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h index 0098059d318c..1b399bbbdf3f 100644 --- a/include/xgboost/learner.h +++ b/include/xgboost/learner.h @@ -134,6 +134,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable { HostDeviceVector **out_preds, uint32_t layer_begin, uint32_t layer_end) = 0; + /* + * \brief Get number of boosted rounds from gradient booster. + */ + virtual int32_t BoostedRounds() const = 0; + void LoadModel(Json const& in) override = 0; void SaveModel(Json* out) const override = 0; diff --git a/python-package/xgboost/callback.py b/python-package/xgboost/callback.py index dec1e5f35817..3c66ccfb8d58 100644 --- a/python-package/xgboost/callback.py +++ b/python-package/xgboost/callback.py @@ -6,7 +6,7 @@ import collections import os import pickle -from typing import Callable, List +from typing import Callable, List, Optional, Union, Dict, Tuple import numpy from . import rabit @@ -285,11 +285,13 @@ def after_training(self, model): '''Run after training is finished.''' return model - def before_iteration(self, model, epoch, evals_log): + def before_iteration(self, model, epoch: int, + evals_log: 'CallbackContainer.EvalsLog') -> bool: '''Run before each iteration. Return True when training should stop.''' return False - def after_iteration(self, model, epoch, evals_log): + def after_iteration(self, model, epoch: int, + evals_log: 'CallbackContainer.EvalsLog') -> bool: '''Run after each iteration. Return True when training should stop.''' return False @@ -346,8 +348,13 @@ class CallbackContainer: .. versionadded:: 1.3.0 ''' - def __init__(self, callbacks: List[TrainingCallback], - metric: Callable = None, is_cv: bool = False): + + EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]] + + def __init__(self, + callbacks: List[TrainingCallback], + metric: Callable = None, + is_cv: bool = False): self.callbacks = set(callbacks) if metric is not None: msg = 'metric must be callable object for monitoring. For ' + \ @@ -355,7 +362,7 @@ def __init__(self, callbacks: List[TrainingCallback], ' will invoke monitor automatically.' assert callable(metric), msg self.metric = metric - self.history = collections.OrderedDict() + self.history: CallbackContainer.EvalsLog = collections.OrderedDict() self.is_cv = is_cv if self.is_cv: @@ -383,7 +390,7 @@ def after_training(self, model): assert isinstance(model, Booster), msg return model - def before_iteration(self, model, epoch, dtrain, evals): + def before_iteration(self, model, epoch, dtrain, evals) -> bool: '''Function called before training iteration.''' return any(c.before_iteration(model, epoch, self.history) for c in self.callbacks) @@ -409,7 +416,7 @@ def _update_history(self, score, epoch): self.history[data_name][metric_name] = [s] return False - def after_iteration(self, model, epoch, dtrain, evals): + def after_iteration(self, model, epoch, dtrain, evals) -> bool: '''Function called after training iteration.''' if self.is_cv: scores = model.eval(epoch, self.metric) @@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback): rounds. ''' - def __init__(self, learning_rates): + def __init__(self, learning_rates) -> None: assert callable(learning_rates) or \ isinstance(learning_rates, collections.abc.Sequence) if callable(learning_rates): @@ -454,41 +461,42 @@ def __init__(self, learning_rates): self.learning_rates = lambda epoch: learning_rates[epoch] super().__init__() - def after_iteration(self, model, epoch, evals_log): + def after_iteration(self, model, epoch, evals_log) -> bool: model.set_param('learning_rate', self.learning_rates(epoch)) + return False # pylint: disable=too-many-instance-attributes class EarlyStopping(TrainingCallback): - ''' Callback function for early stopping + """Callback function for early stopping .. versionadded:: 1.3.0 Parameters ---------- - rounds : int + rounds Early stopping rounds. - metric_name : str + metric_name Name of metric that is used for early stopping. - data_name: str + data_name Name of dataset that is used for early stopping. - maximize : bool + maximize Whether to maximize evaluation metric. None means auto (discouraged). - save_best : bool + save_best Whether training should return the best model or the last model. - ''' + """ def __init__(self, - rounds, - metric_name=None, - data_name=None, - maximize=None, - save_best=False): + rounds: int, + metric_name: Optional[str] = None, + data_name: Optional[str] = None, + maximize: Optional[bool] = None, + save_best: Optional[bool] = False) -> None: self.data = data_name self.metric_name = metric_name self.rounds = rounds self.save_best = save_best self.maximize = maximize - self.stopping_history = {} + self.stopping_history: CallbackContainer.EvalsLog = {} if self.maximize is not None: if self.maximize: @@ -496,11 +504,16 @@ def __init__(self, else: self.improve_op = lambda x, y: x < y - self.current_rounds = 0 - self.best_scores = {} + self.current_rounds: int = 0 + self.best_scores: dict = {} + self.starting_round: int = 0 super().__init__() - def _update_rounds(self, score, name, metric, model, epoch): + def before_training(self, model): + self.starting_round = model.num_boosted_rounds() + return model + + def _update_rounds(self, score, name, metric, model, epoch) -> bool: # Just to be compatibility with old behavior before 1.3. We should let # user to decide. if self.maximize is None: @@ -536,7 +549,9 @@ def _update_rounds(self, score, name, metric, model, epoch): return True return False - def after_iteration(self, model: Booster, epoch, evals_log): + def after_iteration(self, model, epoch: int, + evals_log: CallbackContainer.EvalsLog) -> bool: + epoch += self.starting_round # training continuation msg = 'Must have at least 1 validation dataset for early stopping.' assert len(evals_log.keys()) >= 1, msg data_name = '' @@ -562,12 +577,14 @@ def after_iteration(self, model: Booster, epoch, evals_log): score = data_log[metric_name][-1] return self._update_rounds(score, data_name, metric_name, model, epoch) - def after_training(self, model: Booster): + def after_training(self, model): try: if self.save_best: - model = model[: int(model.attr('best_iteration'))] + model = model[: int(model.attr("best_iteration")) + 1] except XGBoostError as e: - raise XGBoostError('`save_best` is not applicable to current booster') from e + raise XGBoostError( + "`save_best` is not applicable to current booster" + ) from e return model @@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback): show_stdv : bool Used in cv to show standard deviation. Users should not specify it. ''' - def __init__(self, rank=0, period=1, show_stdv=False): + def __init__(self, rank=0, period=1, show_stdv=False) -> None: self.printer_rank = rank self.show_stdv = show_stdv self.period = period assert period > 0 # last error message, useful when early stopping and period are used together. - self._latest = None + self._latest: Optional[str] = None super().__init__() - def _fmt_metric(self, data, metric, score, std): + def _fmt_metric(self, data, metric, score, std) -> str: if std is not None and self.show_stdv: msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std) else: msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score) return msg - def after_iteration(self, model, epoch, evals_log): + def after_iteration(self, model, epoch: int, + evals_log: CallbackContainer.EvalsLog) -> bool: if not evals_log: return False - msg = f'[{epoch}]' + msg: str = f'[{epoch}]' if rabit.get_rank() == self.printer_rank: for data, metric in evals_log.items(): for metric_name, log in metric.items(): + stdv: Optional[float] = None if isinstance(log[-1], tuple): score = log[-1][0] stdv = log[-1][1] else: score = log[-1] - stdv = None msg += self._fmt_metric(data, metric_name, score, stdv) msg += '\n' @@ -665,7 +683,8 @@ def __init__(self, directory: os.PathLike, name: str = 'model', self._epoch = 0 super().__init__() - def after_iteration(self, model, epoch, evals_log): + def after_iteration(self, model, epoch: int, + evals_log: CallbackContainer.EvalsLog) -> bool: if self._epoch == self._iterations: path = os.path.join(self._path, self._name + '_' + str(epoch) + ('.pkl' if self._as_pickle else '.json')) @@ -677,6 +696,7 @@ def after_iteration(self, model, epoch, evals_log): else: model.save_model(path) self._epoch += 1 + return False class LegacyCallbacks: diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index cbf61d9ff6f4..6426be00d190 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -1177,23 +1177,6 @@ def copy(self): """ return self.__copy__() - def load_rabit_checkpoint(self): - """Initialize the model by load from rabit checkpoint. - - Returns - ------- - version: integer - The version number of the model. - """ - version = ctypes.c_int() - _check_call(_LIB.XGBoosterLoadRabitCheckpoint( - self.handle, ctypes.byref(version))) - return version.value - - def save_rabit_checkpoint(self): - """Save the current booster to rabit checkpoint.""" - _check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle)) - def attr(self, key): """Get attribute string from the Booster. @@ -1745,6 +1728,17 @@ def load_model(self, fname): else: raise TypeError('Unknown file type: ', fname) + def num_boosted_rounds(self) -> int: + '''Get number of boosted rounds. For gblinear this is reset to 0 after + serializing the model. + + ''' + rounds = ctypes.c_int() + assert self.handle is not None + _check_call(_LIB.XGBoosterBoostedRounds( + self.handle, ctypes.byref(rounds))) + return rounds.value + def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"): """Dump model into a text or JSON file. Unlike `save_model`, the output format is primarily used for visualization or interpretation, diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index d3b2a1bf85b2..e1cde586543a 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -4,6 +4,7 @@ import copy import warnings import json +from typing import Union, Optional, List, Dict, Callable, Tuple, Any import numpy as np from .core import Booster, DMatrix, XGBoostError, _deprecate_positional_args from .training import train @@ -494,6 +495,22 @@ def load_model(self, fname): # Delete the attribute after load self.get_booster().set_attr(scikit_learn=None) + def _configure_fit( + self, + booster: Optional[Booster], + eval_metric: Optional[Union[Callable, str, List[str]]], + params: Dict[str, Any], + ) -> Tuple[Booster, Optional[Union[Callable, str, List[str]]], Dict[str, Any]]: + model = self._Booster if hasattr(self, "_Booster") else None + model = booster if booster is not None else model + feval = eval_metric if callable(eval_metric) else None + if eval_metric is not None: + if callable(eval_metric): + eval_metric = None + else: + params.update({"eval_metric": eval_metric}) + return model, feval, params + @_deprecate_positional_args def fit(self, X, y, *, sample_weight=None, base_margin=None, eval_set=None, eval_metric=None, early_stopping_rounds=None, @@ -586,19 +603,13 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None, else: obj = None - feval = eval_metric if callable(eval_metric) else None - if eval_metric is not None: - if callable(eval_metric): - eval_metric = None - else: - params.update({'eval_metric': eval_metric}) - + model, feval, params = self._configure_fit(xgb_model, eval_metric, params) self._Booster = train(params, train_dmatrix, self.get_num_boosting_rounds(), evals=evals, early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, obj=obj, feval=feval, - verbose_eval=verbose, xgb_model=xgb_model, + verbose_eval=verbose, xgb_model=model, callbacks=callbacks) if evals_result: @@ -857,27 +868,20 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None, not np.array_equal(self.classes_, np.arange(self.n_classes_))): raise ValueError(label_encoding_check_error) - xgb_options = self.get_xgb_params() + params = self.get_xgb_params() if callable(self.objective): obj = _objective_decorator(self.objective) # Use default value. Is it really not used ? - xgb_options["objective"] = "binary:logistic" + params["objective"] = "binary:logistic" else: obj = None if self.n_classes_ > 2: # Switch to using a multiclass objective in the underlying # XGB instance - xgb_options['objective'] = 'multi:softprob' - xgb_options['num_class'] = self.n_classes_ - - feval = eval_metric if callable(eval_metric) else None - if eval_metric is not None: - if callable(eval_metric): - eval_metric = None - else: - xgb_options.update({"eval_metric": eval_metric}) + params['objective'] = 'multi:softprob' + params['num_class'] = self.n_classes_ if self.use_label_encoder: if not can_use_label_encoder: @@ -891,6 +895,7 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None, else: label_transform = (lambda x: x) + model, feval, params = self._configure_fit(xgb_model, eval_metric, params) if len(X.shape) != 2: # Simply raise an error here since there might be many # different ways of reshaping @@ -906,15 +911,15 @@ def fit(self, X, y, *, sample_weight=None, base_margin=None, eval_set=eval_set, sample_weight_eval_set=sample_weight_eval_set, eval_group=None, label_transform=label_transform) - self._Booster = train(xgb_options, train_dmatrix, + self._Booster = train(params, train_dmatrix, self.get_num_boosting_rounds(), evals=evals, early_stopping_rounds=early_stopping_rounds, evals_result=evals_result, obj=obj, feval=feval, - verbose_eval=verbose, xgb_model=xgb_model, + verbose_eval=verbose, xgb_model=model, callbacks=callbacks) - self.objective = xgb_options["objective"] + self.objective = params["objective"] if evals_result: for val in evals_result.items(): evals_result_key = list(val[1].keys())[0] diff --git a/python-package/xgboost/training.py b/python-package/xgboost/training.py index 8db3a9798a9e..1467fe72641e 100644 --- a/python-package/xgboost/training.py +++ b/python-package/xgboost/training.py @@ -4,11 +4,10 @@ """Training Library containing training routines.""" import warnings import copy - +import json import numpy as np from .core import Booster, XGBoostError from .compat import (SKLEARN_INSTALLED, XGBStratifiedKFold) -from . import rabit from . import callback @@ -51,28 +50,12 @@ def _train_internal(params, dtrain, evals = list(evals) bst = Booster(params, [dtrain] + [d[0] for d in evals]) - nboost = 0 - num_parallel_tree = 1 if xgb_model is not None: bst = Booster(params, [dtrain] + [d[0] for d in evals], model_file=xgb_model) - nboost = len(bst.get_dump()) - - _params = dict(params) if isinstance(params, list) else params - if 'num_parallel_tree' in _params and _params[ - 'num_parallel_tree'] is not None: - num_parallel_tree = _params['num_parallel_tree'] - nboost //= num_parallel_tree - if 'num_class' in _params and _params['num_class'] is not None: - nboost //= _params['num_class'] - - # Distributed code: Load the checkpoint from rabit. - version = bst.load_rabit_checkpoint() - assert rabit.get_world_size() != 1 or version == 0 - start_iteration = int(version / 2) - nboost += start_iteration + start_iteration = 0 is_new_callback = _is_new_callback(callbacks) if is_new_callback: @@ -92,26 +75,13 @@ def _train_internal(params, dtrain, show_stdv=False, cvfolds=None) bst = callbacks.before_training(bst) + for i in range(start_iteration, num_boost_round): if callbacks.before_iteration(bst, i, dtrain, evals): break - # Distributed code: need to resume to this point. - # Skip the first update if it is a recovery step. - if version % 2 == 0: - bst.update(dtrain, i, obj) - bst.save_rabit_checkpoint() - version += 1 - - assert rabit.get_world_size() == 1 or version == rabit.version_number() - - nboost += 1 - # check evaluation result. + bst.update(dtrain, i, obj) if callbacks.after_iteration(bst, i, dtrain, evals): break - # do checkpoint after evaluation, in case evaluation also updates - # booster. - bst.save_rabit_checkpoint() - version += 1 bst = callbacks.after_training(bst) @@ -122,7 +92,12 @@ def _train_internal(params, dtrain, bst.best_score = float(bst.attr('best_score')) bst.best_iteration = int(bst.attr('best_iteration')) else: - bst.best_iteration = nboost - 1 + bst.best_iteration = bst.num_boosted_rounds() - 1 + try: + num_parallel_tree = int(json.loads(bst.save_config())['learner'][ + 'gradient_booster']['gbtree_train_param']['num_parallel_tree']) + except KeyError: # gblinear + num_parallel_tree = 1 bst.best_ntree_limit = (bst.best_iteration + 1) * num_parallel_tree # Copy to serialise and unserialise booster to reset state and free # training memory @@ -234,7 +209,7 @@ def eval(self, iteration, feval): class _PackedBooster: - def __init__(self, cvfolds): + def __init__(self, cvfolds) -> None: self.cvfolds = cvfolds def update(self, iteration, obj): @@ -262,6 +237,10 @@ def best_iteration(self): ret = self.cvfolds[0].bst.attr('best_iteration') return int(ret) + def num_boosted_rounds(self) -> int: + '''Number of boosted rounds.''' + return self.cvfolds[0].bst.num_boosted_rounds() + def groups_to_rows(groups, boundaries): """ diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc index 727544fd9894..0ce1563c3453 100644 --- a/src/c_api/c_api.cc +++ b/src/c_api/c_api.cc @@ -502,6 +502,14 @@ XGB_DLL int XGBoosterGetNumFeature(BoosterHandle handle, API_END(); } +XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out) { + API_BEGIN(); + CHECK_HANDLE(); + static_cast(handle)->Configure(); + *out = static_cast(handle)->BoostedRounds(); + API_END(); +} + XGB_DLL int XGBoosterLoadJsonConfig(BoosterHandle handle, char const* json_parameters) { API_BEGIN(); CHECK_HANDLE(); diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc index b6ba17269c23..2e94e8626925 100644 --- a/src/gbm/gblinear.cc +++ b/src/gbm/gblinear.cc @@ -73,6 +73,10 @@ class GBLinear : public GradientBooster { } } + int32_t BoostedRounds() const override { + return model_.num_boosted_rounds; + } + void Load(dmlc::Stream* fi) override { model_.Load(fi); } @@ -122,7 +126,7 @@ class GBLinear : public GradientBooster { if (!this->CheckConvergence()) { updater_->Update(in_gpair, p_fmat, &model_, sum_instance_weight_); } - + model_.num_boosted_rounds++; monitor_.Stop("DoBoost"); } diff --git a/src/gbm/gblinear_model.h b/src/gbm/gblinear_model.h index 9f4888d82f51..53121ae8411f 100644 --- a/src/gbm/gblinear_model.h +++ b/src/gbm/gblinear_model.h @@ -44,11 +44,12 @@ class GBLinearModel : public Model { DeprecatedGBLinearModelParam param_; public: + int32_t num_boosted_rounds; LearnerModelParam const* learner_model_param; public: explicit GBLinearModel(LearnerModelParam const* learner_model_param) : - learner_model_param {learner_model_param} {} + num_boosted_rounds{0}, learner_model_param {learner_model_param} {} void Configure(Args const &) { } // weight for each of feature, bias is the last one diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index d67f94c2c75b..25c9809ca4bd 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -249,10 +249,17 @@ class GBTree : public GradientBooster { auto n_trees = model_.learner_model_param->num_output_group * tparam_.num_parallel_tree; return n_trees; } + // slice the trees, out must be already allocated void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, GradientBooster *out, bool* out_of_bound) const override; + int32_t BoostedRounds() const override { + CHECK_NE(tparam_.num_parallel_tree, 0); + CHECK_NE(model_.learner_model_param->num_output_group, 0); + return model_.trees.size() / this->LayerTrees(); + } + void PredictBatch(DMatrix* p_fmat, PredictionCacheEntry* out_preds, bool training, diff --git a/src/learner.cc b/src/learner.cc index a11dd5c4638e..71d9bb8583f7 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -1107,6 +1107,12 @@ class LearnerImpl : public LearnerIO { } } + int32_t BoostedRounds() const override { + if (!this->gbm_) { return 0; } // haven't call train or LoadModel. + CHECK(!this->need_configuration_); + return this->gbm_->BoostedRounds(); + } + XGBAPIThreadLocalEntry& GetThreadLocal() const override { return (*LearnerAPIThreadLocalStore::Get())[this]; } diff --git a/tests/python/test_basic_models.py b/tests/python/test_basic_models.py index 55c75dcc6112..6b7956801935 100644 --- a/tests/python/test_basic_models.py +++ b/tests/python/test_basic_models.py @@ -124,6 +124,20 @@ def test_boost_from_prediction(self): predt_2 = bst.predict(dtrain) assert np.all(np.abs(predt_2 - predt_1) < 1e-6) + def test_boost_from_existing_model(self): + X = xgb.DMatrix(dpath + 'agaricus.txt.train') + booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4) + assert booster.num_boosted_rounds() == 4 + booster = xgb.train({'tree_method': 'hist'}, X, num_boost_round=4, + xgb_model=booster) + assert booster.num_boosted_rounds() == 8 + booster = xgb.train({'updater': 'prune', 'process_type': 'update'}, X, + num_boost_round=4, xgb_model=booster) + # Trees are moved for update, the rounds is reduced. This test is + # written for being compatible with current code (1.0.0). If the + # behaviour is considered sub-optimal, feel free to change. + assert booster.num_boosted_rounds() == 4 + def test_custom_objective(self): param = {'max_depth': 2, 'eta': 1, 'objective': 'reg:logistic'} watchlist = [(dtest, 'eval'), (dtrain, 'train')] diff --git a/tests/python/test_callback.py b/tests/python/test_callback.py index bdbc4cdfc32b..d1b7f17ab85e 100644 --- a/tests/python/test_callback.py +++ b/tests/python/test_callback.py @@ -43,7 +43,7 @@ def run_evaluation_monitor(self, D_train, D_valid, rounds, verbose_eval): # Should print info by each period additionaly to first and latest iteration num_periods = rounds // int(verbose_eval) # Extra information is required for latest iteration - is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1) + is_extra_info_required = num_periods * int(verbose_eval) < (rounds - 1) assert len(output.split('\n')) == 1 + num_periods + int(is_extra_info_required) def test_evaluation_monitor(self): @@ -63,7 +63,7 @@ def test_evaluation_monitor(self): self.run_evaluation_monitor(D_train, D_valid, rounds, True) self.run_evaluation_monitor(D_train, D_valid, rounds, 2) self.run_evaluation_monitor(D_train, D_valid, rounds, 4) - self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1) + self.run_evaluation_monitor(D_train, D_valid, rounds, rounds + 1) def test_early_stopping(self): D_train = xgb.DMatrix(self.X_train, self.y_train) @@ -81,6 +81,15 @@ def test_early_stopping(self): dump = booster.get_dump(dump_format='json') assert len(dump) - booster.best_iteration == early_stopping_rounds + 1 + # No early stopping, best_iteration should be set to last epoch + booster = xgb.train({'objective': 'binary:logistic', + 'eval_metric': 'error'}, D_train, + evals=[(D_train, 'Train'), (D_valid, 'Valid')], + num_boost_round=10, + evals_result=evals_result, + verbose_eval=True) + assert booster.num_boosted_rounds() - 1 == booster.best_iteration + def test_early_stopping_custom_eval(self): D_train = xgb.DMatrix(self.X_train, self.y_train) D_valid = xgb.DMatrix(self.X_valid, self.y_valid) @@ -153,7 +162,7 @@ def test_early_stopping_save_best_model(self): eval_metric=tm.eval_error_metric, callbacks=[early_stop]) booster = cls.get_booster() dump = booster.get_dump(dump_format='json') - assert len(dump) == booster.best_iteration + assert len(dump) == booster.best_iteration + 1 early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, save_best=True) @@ -170,6 +179,32 @@ def test_early_stopping_save_best_model(self): eval_metric=tm.eval_error_metric, callbacks=[early_stop]) + def test_early_stopping_continuation(self): + from sklearn.datasets import load_breast_cancer + X, y = load_breast_cancer(return_X_y=True) + cls = xgb.XGBClassifier() + early_stopping_rounds = 5 + early_stop = xgb.callback.EarlyStopping(rounds=early_stopping_rounds, + save_best=True) + cls.fit(X, y, eval_set=[(X, y)], + eval_metric=tm.eval_error_metric, + callbacks=[early_stop]) + booster = cls.get_booster() + assert booster.num_boosted_rounds() == booster.best_iteration + 1 + + with tempfile.TemporaryDirectory() as tmpdir: + path = os.path.join(tmpdir, 'model.json') + cls.save_model(path) + cls = xgb.XGBClassifier() + cls.load_model(path) + assert cls._Booster is not None + early_stopping_rounds = 3 + cls.fit(X, y, eval_set=[(X, y)], eval_metric=tm.eval_error_metric, + early_stopping_rounds=early_stopping_rounds) + booster = cls.get_booster() + assert booster.num_boosted_rounds() == \ + booster.best_iteration + early_stopping_rounds + 1 + def run_eta_decay(self, tree_method, deprecated_callback): if deprecated_callback: scheduler = xgb.callback.reset_learning_rate diff --git a/tests/python/test_early_stopping.py b/tests/python/test_early_stopping.py index d685fb3faa8b..67b7c1a560d5 100644 --- a/tests/python/test_early_stopping.py +++ b/tests/python/test_early_stopping.py @@ -46,7 +46,7 @@ def evalerror(self, preds, dtrain): @staticmethod def assert_metrics_length(cv, expected_length): for key, value in cv.items(): - assert len(value) == expected_length + assert len(value) == expected_length @pytest.mark.skipif(**tm.no_sklearn()) def test_cv_early_stopping(self): diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py index 836eaf4dbfd9..75cd1963622e 100644 --- a/tests/python/test_with_sklearn.py +++ b/tests/python/test_with_sklearn.py @@ -62,6 +62,8 @@ def check_pred(preds, labels, output_margin): kf = KFold(n_splits=2, shuffle=True, random_state=rng) for train_index, test_index in kf.split(X, y): xgb_model = xgb.XGBClassifier().fit(X[train_index], y[train_index]) + assert (xgb_model.get_booster().num_boosted_rounds() == + xgb_model.n_estimators) preds = xgb_model.predict(X[test_index]) # test other params in XGBClassifier().fit preds2 = xgb_model.predict(X[test_index], output_margin=True,