
Commit ca3da55

Support early stopping with training continuation, correct num boosted rounds. (#6506)

* Implement early stopping with training continuation.

* Add new C API for obtaining boosted rounds.

* Fix off by 1 in `save_best`.

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
trivialfis and hcho3 committed Dec 17, 2020
1 parent 125b3c0 commit ca3da55
Showing 16 changed files with 210 additions and 118 deletions.
9 changes: 9 additions & 0 deletions include/xgboost/c_api.h
@@ -614,6 +614,15 @@ XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out);

/*!
 * \brief Get the number of boosted rounds from the gradient booster. When process_type is
 *        update, this number may drop because trees are removed.
* \param handle Handle to booster.
* \param out Pointer to output integer.
 * \return 0 on success, -1 on failure
*/
XGB_DLL int XGBoosterBoostedRounds(BoosterHandle handle, int* out);

/*!
* \brief set parameters
* \param handle handle
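For illustration only (not part of this commit): a minimal sketch of driving the new entry point directly through ctypes, assuming libxgboost.so has been built and is loadable and that handle is a live BoosterHandle. The Python wrapper added to core.py further down does essentially the same thing.

import ctypes

# Assumption: the shared-library name and path are platform-specific.
lib = ctypes.cdll.LoadLibrary("libxgboost.so")
lib.XGBGetLastError.restype = ctypes.c_char_p

def boosted_rounds(handle: ctypes.c_void_p) -> int:
    """Return the number of boosted rounds for a live BoosterHandle."""
    out = ctypes.c_int()
    # The new C API returns 0 on success, -1 on failure.
    if lib.XGBoosterBoostedRounds(handle, ctypes.byref(out)) != 0:
        raise RuntimeError(lib.XGBGetLastError().decode("utf-8"))
    return out.value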
3 changes: 3 additions & 0 deletions include/xgboost/gbm.h
@@ -79,6 +79,9 @@ class GradientBooster : public Model, public Configurable {
virtual bool AllowLazyCheckPoint() const {
return false;
}
/*! \brief Return number of boosted rounds. */
virtual int32_t BoostedRounds() const = 0;
/*!
* \brief perform update to the model(boosting)
* \param p_fmat feature matrix that provide access to features
5 changes: 5 additions & 0 deletions include/xgboost/learner.h
@@ -134,6 +134,11 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
HostDeviceVector<bst_float> **out_preds,
uint32_t layer_begin, uint32_t layer_end) = 0;

/*!
* \brief Get number of boosted rounds from gradient booster.
*/
virtual int32_t BoostedRounds() const = 0;

void LoadModel(Json const& in) override = 0;
void SaveModel(Json* out) const override = 0;

94 changes: 57 additions & 37 deletions python-package/xgboost/callback.py
@@ -6,7 +6,7 @@
import collections
import os
import pickle
from typing import Callable, List
from typing import Callable, List, Optional, Union, Dict, Tuple
import numpy

from . import rabit
@@ -285,11 +285,13 @@ def after_training(self, model):
'''Run after training is finished.'''
return model

def before_iteration(self, model, epoch, evals_log):
def before_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run before each iteration. Return True when training should stop.'''
return False

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: 'CallbackContainer.EvalsLog') -> bool:
'''Run after each iteration. Return True when training should stop.'''
return False
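
To make the contract above concrete, here is a hypothetical user-defined callback (an illustration, not part of this commit): it subclasses TrainingCallback and returns True from after_iteration to request a stop once a wall-clock budget is spent. evals_log is the nested dict typed as CallbackContainer.EvalsLog below.

import time

import xgboost as xgb

class TimeBudget(xgb.callback.TrainingCallback):
    """Illustrative callback: stop training after `budget` seconds."""

    def __init__(self, budget: float) -> None:
        self.budget = budget
        self.start = time.time()
        super().__init__()

    def after_iteration(self, model, epoch, evals_log) -> bool:
        # Returning True tells the CallbackContainer to stop training.
        return time.time() - self.start > self.budget

# Usage sketch: xgb.train(params, dtrain, callbacks=[TimeBudget(60.0)])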

@@ -346,16 +348,21 @@ class CallbackContainer:
.. versionadded:: 1.3.0
'''
def __init__(self, callbacks: List[TrainingCallback],
metric: Callable = None, is_cv: bool = False):

EvalsLog = Dict[str, Dict[str, Union[List[float], List[Tuple[float, float]]]]]

def __init__(self,
callbacks: List[TrainingCallback],
metric: Callable = None,
is_cv: bool = False):
self.callbacks = set(callbacks)
if metric is not None:
msg = 'metric must be callable object for monitoring. For ' + \
'builtin metrics, passing them in training parameter' + \
' will invoke monitor automatically.'
assert callable(metric), msg
self.metric = metric
self.history = collections.OrderedDict()
self.history: CallbackContainer.EvalsLog = collections.OrderedDict()
self.is_cv = is_cv

if self.is_cv:
@@ -383,7 +390,7 @@ def after_training(self, model):
assert isinstance(model, Booster), msg
return model

def before_iteration(self, model, epoch, dtrain, evals):
def before_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called before training iteration.'''
return any(c.before_iteration(model, epoch, self.history)
for c in self.callbacks)
@@ -409,7 +416,7 @@ def _update_history(self, score, epoch):
self.history[data_name][metric_name] = [s]
return False

def after_iteration(self, model, epoch, dtrain, evals):
def after_iteration(self, model, epoch, dtrain, evals) -> bool:
'''Function called after training iteration.'''
if self.is_cv:
scores = model.eval(epoch, self.metric)
@@ -445,7 +452,7 @@ class LearningRateScheduler(TrainingCallback):
rounds.
'''
def __init__(self, learning_rates):
def __init__(self, learning_rates) -> None:
assert callable(learning_rates) or \
isinstance(learning_rates, collections.abc.Sequence)
if callable(learning_rates):
@@ -454,53 +461,59 @@ def __init__(self, learning_rates):
self.learning_rates = lambda epoch: learning_rates[epoch]
super().__init__()

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch, evals_log) -> bool:
model.set_param('learning_rate', self.learning_rates(epoch))
return False
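
A usage sketch for the scheduler above (params and dtrain are assumed to be defined): learning_rates may be either a sequence indexed by epoch or a callable mapping the epoch number to a learning rate, matching the assertion in the constructor.

import xgboost as xgb

# Exponential decay expressed as a callable of the epoch number.
decay = xgb.callback.LearningRateScheduler(lambda epoch: 0.3 * 0.99 ** epoch)

# The same idea with an explicit per-round sequence.
steps = xgb.callback.LearningRateScheduler([0.3] * 50 + [0.1] * 50)

# booster = xgb.train(params, dtrain, num_boost_round=100, callbacks=[decay])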


# pylint: disable=too-many-instance-attributes
class EarlyStopping(TrainingCallback):
''' Callback function for early stopping
"""Callback function for early stopping
.. versionadded:: 1.3.0
Parameters
----------
rounds : int
rounds
Early stopping rounds.
metric_name : str
metric_name
Name of metric that is used for early stopping.
data_name: str
data_name
Name of dataset that is used for early stopping.
maximize : bool
maximize
Whether to maximize evaluation metric. None means auto (discouraged).
save_best : bool
save_best
Whether training should return the best model or the last model.
'''
"""
def __init__(self,
rounds,
metric_name=None,
data_name=None,
maximize=None,
save_best=False):
rounds: int,
metric_name: Optional[str] = None,
data_name: Optional[str] = None,
maximize: Optional[bool] = None,
save_best: Optional[bool] = False) -> None:
self.data = data_name
self.metric_name = metric_name
self.rounds = rounds
self.save_best = save_best
self.maximize = maximize
self.stopping_history = {}
self.stopping_history: CallbackContainer.EvalsLog = {}

if self.maximize is not None:
if self.maximize:
self.improve_op = lambda x, y: x > y
else:
self.improve_op = lambda x, y: x < y

self.current_rounds = 0
self.best_scores = {}
self.current_rounds: int = 0
self.best_scores: dict = {}
self.starting_round: int = 0
super().__init__()

def _update_rounds(self, score, name, metric, model, epoch):
def before_training(self, model):
self.starting_round = model.num_boosted_rounds()
return model
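
This before_training hook is the heart of the fix: on training continuation the epoch counter restarts at 0, so the callback records the booster's existing round count and offsets epoch in after_iteration below. A sketch of the intended workflow, assuming params, dtrain, and dvalid are defined:

import xgboost as xgb

# Stage 1: train 10 rounds without early stopping.
booster = xgb.train(params, dtrain, num_boost_round=10)

# Stage 2: continue from the same booster. before_training() records
# starting_round = 10, so best_iteration is reported in absolute rounds
# instead of restarting from 0.
early_stop = xgb.callback.EarlyStopping(rounds=3, save_best=True)
booster = xgb.train(params, dtrain, num_boost_round=100,
                    evals=[(dvalid, "validation")],
                    xgb_model=booster, callbacks=[early_stop])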

def _update_rounds(self, score, name, metric, model, epoch) -> bool:
# Just for compatibility with the old behavior before 1.3. We should let
# the user decide.
if self.maximize is None:
@@ -536,7 +549,9 @@ def _update_rounds(self, score, name, metric, model, epoch):
return True
return False

def after_iteration(self, model: Booster, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
epoch += self.starting_round # training continuation
msg = 'Must have at least 1 validation dataset for early stopping.'
assert len(evals_log.keys()) >= 1, msg
data_name = ''
@@ -562,12 +577,14 @@ def after_iteration(self, model: Booster, epoch, evals_log):
score = data_log[metric_name][-1]
return self._update_rounds(score, data_name, metric_name, model, epoch)

def after_training(self, model: Booster):
def after_training(self, model):
try:
if self.save_best:
model = model[: int(model.attr('best_iteration'))]
model = model[: int(model.attr("best_iteration")) + 1]
except XGBoostError as e:
raise XGBoostError('`save_best` is not applicable to current booster') from e
raise XGBoostError(
"`save_best` is not applicable to current booster"
) from e
return model
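
The `+ 1` above is the "off by 1 in save_best" fix from the commit message: Booster slicing is end-exclusive, like Python sequences, so model[:best_iteration] silently dropped the best round. A small sketch of the corrected behavior, assuming booster was trained with early stopping so the attribute is set:

best = int(booster.attr("best_iteration"))
# End-exclusive slice: keeps rounds 0 .. best inclusive.
best_model = booster[: best + 1]
assert best_model.num_boosted_rounds() == best + 1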


@@ -588,36 +605,37 @@ class EvaluationMonitor(TrainingCallback):
show_stdv : bool
Used in cv to show standard deviation. Users should not specify it.
'''
def __init__(self, rank=0, period=1, show_stdv=False):
def __init__(self, rank=0, period=1, show_stdv=False) -> None:
self.printer_rank = rank
self.show_stdv = show_stdv
self.period = period
assert period > 0
# last evaluation message, useful when early stopping and period are used together.
self._latest = None
self._latest: Optional[str] = None
super().__init__()

def _fmt_metric(self, data, metric, score, std):
def _fmt_metric(self, data, metric, score, std) -> str:
if std is not None and self.show_stdv:
msg = '\t{0}:{1:.5f}+{2:.5f}'.format(data + '-' + metric, score, std)
else:
msg = '\t{0}:{1:.5f}'.format(data + '-' + metric, score)
return msg

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if not evals_log:
return False

msg = f'[{epoch}]'
msg: str = f'[{epoch}]'
if rabit.get_rank() == self.printer_rank:
for data, metric in evals_log.items():
for metric_name, log in metric.items():
stdv: Optional[float] = None
if isinstance(log[-1], tuple):
score = log[-1][0]
stdv = log[-1][1]
else:
score = log[-1]
stdv = None
msg += self._fmt_metric(data, metric_name, score, stdv)
msg += '\n'
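
A brief usage sketch for the monitor (params, dtrain, and dvalid assumed): with period greater than 1, intermediate lines are skipped, and _latest holds the last message so the final score can still be printed when early stopping ends training.

import xgboost as xgb

# Print the evaluation log only every 10th round.
monitor = xgb.callback.EvaluationMonitor(period=10)
# booster = xgb.train(params, dtrain, evals=[(dvalid, "validation")],
#                     callbacks=[monitor])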

@@ -665,7 +683,8 @@ def __init__(self, directory: os.PathLike, name: str = 'model',
self._epoch = 0
super().__init__()

def after_iteration(self, model, epoch, evals_log):
def after_iteration(self, model, epoch: int,
evals_log: CallbackContainer.EvalsLog) -> bool:
if self._epoch == self._iterations:
path = os.path.join(self._path, self._name + '_' + str(epoch) +
('.pkl' if self._as_pickle else '.json'))
@@ -677,6 +696,7 @@ def after_iteration(self, model, epoch, evals_log):
else:
model.save_model(path)
self._epoch += 1
return False
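
For completeness, a usage sketch for this checkpointing callback; the class name is assumed to be xgboost.callback.TrainingCheckPoint (only its __init__ is visible in the hunks above), and params and dtrain are assumed to be defined.

import xgboost as xgb

# Write a JSON snapshot of the model every 25 rounds into /tmp/ckpt.
checkpoint = xgb.callback.TrainingCheckPoint(directory="/tmp/ckpt",
                                             name="model",
                                             iterations=25)
# booster = xgb.train(params, dtrain, num_boost_round=100,
#                     callbacks=[checkpoint])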


class LegacyCallbacks:
28 changes: 11 additions & 17 deletions python-package/xgboost/core.py
@@ -1177,23 +1177,6 @@ def copy(self):
"""
return self.__copy__()

def load_rabit_checkpoint(self):
"""Initialize the model by load from rabit checkpoint.
Returns
-------
version: integer
The version number of the model.
"""
version = ctypes.c_int()
_check_call(_LIB.XGBoosterLoadRabitCheckpoint(
self.handle, ctypes.byref(version)))
return version.value

def save_rabit_checkpoint(self):
"""Save the current booster to rabit checkpoint."""
_check_call(_LIB.XGBoosterSaveRabitCheckpoint(self.handle))

def attr(self, key):
"""Get attribute string from the Booster.
@@ -1745,6 +1728,17 @@ def load_model(self, fname):
else:
raise TypeError('Unknown file type: ', fname)

def num_boosted_rounds(self) -> int:
'''Get number of boosted rounds. For gblinear this is reset to 0 after
serializing the model.
'''
rounds = ctypes.c_int()
assert self.handle is not None
_check_call(_LIB.XGBoosterBoostedRounds(
self.handle, ctypes.byref(rounds)))
return rounds.value
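
A sketch of the new method combined with training continuation (params and dtrain assumed; with the default gbtree booster the count accumulates across xgb.train calls):

import xgboost as xgb

booster = xgb.train(params, dtrain, num_boost_round=4)
assert booster.num_boosted_rounds() == 4

# Continuation adds to the count rather than resetting it.
booster = xgb.train(params, dtrain, num_boost_round=4, xgb_model=booster)
assert booster.num_boosted_rounds() == 8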

def dump_model(self, fout, fmap='', with_stats=False, dump_format="text"):
"""Dump model into a text or JSON file. Unlike `save_model`, the
output format is primarily used for visualization or interpretation,
