Support slicing tree model (#6302)
This PR is meant to end the confusion around best_ntree_limit and to unify model slicing. With multi-class models and random forests, asking users to understand how to set ntree_limit is difficult and error-prone.

* Implement the save_best option in early stopping.
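For context, a minimal sketch of the new slicing syntax next to the old ``ntree_limit`` parameter (data and parameters are illustrative only):

.. code-block:: python

    import xgboost as xgb
    from sklearn.datasets import make_classification

    X, y = make_classification(n_samples=200, n_classes=3, n_informative=5)
    dtrain = xgb.DMatrix(X, label=y)
    booster = xgb.train({'num_class': 3}, dtrain, num_boost_round=10)

    # Before: users had to reason about how ntree_limit interacts with
    # num_class and num_parallel_tree.
    preds = booster.predict(dtrain, ntree_limit=4)

    # After: slice by boosting round and predict with the sub-model.
    preds = booster[:4].predict(dtrain)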

Co-authored-by: Philip Hyunsu Cho <chohyu01@cs.washington.edu>
trivialfis and hcho3 committed Nov 3, 2020
1 parent 29745c6 commit 2cc9662
Showing 19 changed files with 550 additions and 37 deletions.
8 changes: 4 additions & 4 deletions doc/python/callbacks.rst
@@ -7,9 +7,9 @@
package. In XGBoost 1.3, a new callback interface is designed for the Python package, which
provides the flexibility of designing various extensions for training. Also, XGBoost has a
number of pre-defined callbacks for supporting early stopping, checkpoints, etc.


-#######################
Using builtin callbacks
-#######################
+-----------------------

By default, training methods in XGBoost have parameters like ``early_stopping_rounds`` and
``verbose``/``verbose_eval``, when specified the training procedure will define the
@@ -50,9 +50,9 @@
this callback function directly into XGBoost:
    dump = booster.get_dump(dump_format='json')
    assert len(early_stop.stopping_history['Valid']['CustomErr']) == len(dump)
-##########################
Defining your own callback
-##########################
+--------------------------

XGBoost provides a callback interface class: ``xgboost.callback.TrainingCallback``; user-defined
callbacks should inherit this class and override the corresponding methods. There's a
1 change: 1 addition & 0 deletions doc/python/index.rst
@@ -12,4 +12,5 @@ Contents
   python_intro
   python_api
   callbacks
+   model
   Python examples <https://github.com/dmlc/xgboost/tree/master/demo/guide-python>
38 changes: 38 additions & 0 deletions doc/python/model.rst
@@ -0,0 +1,38 @@
#####
Model
#####

Slice tree model
----------------

When ``booster`` is set to ``gbtree`` or ``dart``, XGBoost builds a tree model, which is a
list of trees and can be sliced into multiple sub-models.

.. code-block:: python

    import xgboost as xgb
    from sklearn.datasets import make_classification

    num_classes = 3
    X, y = make_classification(n_samples=1000, n_informative=5,
                               n_classes=num_classes)
    dtrain = xgb.DMatrix(data=X, label=y)
    num_parallel_tree = 4
    num_boost_round = 16
    # Total number of built trees is
    # num_parallel_tree * num_classes * num_boost_round.
    # We build a boosted random forest for classification here.
    booster = xgb.train({'num_parallel_tree': num_parallel_tree,
                         'subsample': 0.5, 'num_class': num_classes},
                        num_boost_round=num_boost_round, dtrain=dtrain)

    # This is the sliced model, containing the trees from boosting
    # rounds [3, 7). A step is also supported, with some limitations:
    # e.g. a negative step is invalid.
    sliced: xgb.Booster = booster[3:7]

    # Access individual tree layers.
    trees = [_ for _ in booster]
    assert len(trees) == num_boost_round

The sliced model is a copy of the selected trees, which means the model itself is
immutable during slicing. This feature is the basis of the ``save_best`` option in
the early stopping callback.
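Continuing the example above, a short usage sketch of what that immutability means in practice:

.. code-block:: python

    sliced = booster[:4]            # an independent copy of the first 4 rounds
    preds = sliced.predict(dtrain)  # the sub-model predicts on its own
    # The original booster keeps all of its trees.
    assert len([_ for _ in booster]) == num_boost_round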
17 changes: 17 additions & 0 deletions include/xgboost/c_api.h
@@ -580,6 +580,23 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[],
*/
XGB_DLL int XGBoosterFree(BoosterHandle handle);

/*!
* \brief Slice a model using boosting index. The slice m:n indicates taking all trees
* that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
*
* \param handle Booster to be sliced.
* \param begin_layer start of the slice
* \param end_layer end of the slice; end_layer=0 is equivalent to
* end_layer=num_boost_round
* \param step step size of the slice
* \param out Sliced booster.
*
* \return 0 on success, -1 on failure, -2 when the index is out of bound.
*/
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
int end_layer, int step,
BoosterHandle *out);
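As a rough sketch of how these parameters surface in the Python wrapper introduced below (names illustrative; ``booster`` is a trained ``xgb.Booster``):

.. code-block:: python

    sub = booster[3:7]         # XGBoosterSlice(handle, 3, 7, 1, &out)
    strided = booster[0:16:2]  # XGBoosterSlice(handle, 0, 16, 2, &out)
    one = booster[5]           # XGBoosterSlice(handle, 5, 6, 1, &out)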

/*!
* \brief set parameters
* \param handle handle
11 changes: 11 additions & 0 deletions include/xgboost/gbm.h
@@ -60,6 +60,17 @@ class GradientBooster : public Model, public Configurable {
* \param fo output stream
*/
virtual void Save(dmlc::Stream* fo) const = 0;
  /*!
   * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
   *        that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end End of booster layer. 0 means do not limit trees.
   * \param step Step size of the slice.
   * \param out Output gradient booster.
   * \param out_of_bound Set to true if the end layer is out of bound.
   */
  virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
                     GradientBooster *out, bool* out_of_bound) const {
    LOG(FATAL) << "Slice is not supported by current booster.";
  }
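Only tree boosters override this default; a sketch of how the fallback surfaces in Python (assuming ``dtrain`` as in the docs example above; the C API failure is raised as ``xgb.core.XGBoostError``):

.. code-block:: python

    import xgboost as xgb

    linear = xgb.train({'booster': 'gblinear'}, dtrain, num_boost_round=4)
    try:
        linear[0:2]
    except xgb.core.XGBoostError:
        print('slicing is only supported by tree boosters')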
/*!
* \brief whether the model allow lazy checkpoint
* return true if model is only updated in DoBoost
12 changes: 12 additions & 0 deletions include/xgboost/learner.h
@@ -195,6 +195,18 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
* \return whether the model allow lazy checkpoint in rabit.
*/
bool AllowLazyCheckPoint() const;
  /*!
   * \brief Slice the model.
   *
   * See InplacePredict for layer parameters.
   *
   * \param step Step size between slices.
   * \param out_of_bound Set to true if the end layer is out of bound.
   *
   * \return The sliced model.
   */
  virtual Learner *Slice(int32_t begin_layer, int32_t end_layer, int32_t step,
                         bool *out_of_bound) = 0;
/*!
* \brief dump the model in the requested format
* \param fmap feature map that may help give interpretations of feature
37 changes: 29 additions & 8 deletions python-package/xgboost/callback.py
@@ -10,7 +10,7 @@
import numpy

from . import rabit
-from .core import EarlyStopException, CallbackEnv
+from .core import EarlyStopException, CallbackEnv, Booster, XGBoostError
from .compat import STRING_TYPES


@@ -279,9 +279,11 @@ def __init__(self):

    def before_training(self, model):
        '''Run before training starts.'''
+        return model

    def after_training(self, model):
        '''Run after training is finished.'''
+        return model

    def before_iteration(self, model, epoch, evals_log):
        '''Run before each iteration. Return True when training should stop.'''
@@ -362,12 +364,24 @@ def __init__(self, callbacks: List[TrainingCallback],
    def before_training(self, model):
        '''Function called before training.'''
        for c in self.callbacks:
-            c.before_training(model=model)
+            model = c.before_training(model=model)
+            msg = 'before_training should return the model'
+            if self.is_cv:
+                assert isinstance(model.cvfolds, list), msg
+            else:
+                assert isinstance(model, Booster), msg
+        return model

    def after_training(self, model):
        '''Function called after training.'''
        for c in self.callbacks:
-            c.after_training(model)
+            model = c.after_training(model=model)
+            msg = 'after_training should return the model'
+            if self.is_cv:
+                assert isinstance(model.cvfolds, list), msg
+            else:
+                assert isinstance(model, Booster), msg
+        return model
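Under this contract, user-defined callbacks should hand the model back from both hooks. A minimal sketch (class name and training inputs are illustrative; ``params`` and ``dtrain`` are assumed defined):

.. code-block:: python

    import xgboost as xgb

    class NoOpCallback(xgb.callback.TrainingCallback):
        '''Both hooks must return the (possibly replaced) model.'''
        def before_training(self, model):
            return model

        def after_training(self, model):
            return model

    booster = xgb.train(params, dtrain, callbacks=[NoOpCallback()])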

    def before_iteration(self, model, epoch, dtrain, evals):
        '''Function called before training iteration.'''
@@ -461,7 +475,7 @@ class EarlyStopping(TrainingCallback):
    maximize : bool
        Whether to maximize evaluation metric. None means auto (discouraged).
    save_best : bool
-        Placeholder, the feature is not yet supported.
+        Whether training should return the best model or the last model.
    '''
def __init__(self,
rounds,
@@ -473,9 +487,6 @@ def __init__(self,
        self.metric_name = metric_name
        self.rounds = rounds
        self.save_best = save_best
-        # https://github.com/dmlc/xgboost/issues/5531
-        assert self.save_best is False, 'save best is not yet supported.'
-
        self.maximize = maximize
        self.stopping_history = {}

@@ -525,7 +536,7 @@ def _update_rounds(self, score, name, metric, model, epoch):
            return True
        return False

-    def after_iteration(self, model, epoch, evals_log):
+    def after_iteration(self, model: Booster, epoch, evals_log):
        msg = 'Must have at least 1 validation dataset for early stopping.'
        assert len(evals_log.keys()) >= 1, msg
        data_name = ''
@@ -551,6 +562,14 @@ def after_iteration(self, model, epoch, evals_log):
        score = data_log[metric_name][-1]
        return self._update_rounds(score, data_name, metric_name, model, epoch)

+    def after_training(self, model: Booster):
+        try:
+            if self.save_best:
+                model = model[: int(model.attr('best_iteration'))]
+        except XGBoostError as e:
+            raise XGBoostError('`save_best` is not applicable to current booster') from e
+        return model
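A usage sketch (training inputs are illustrative; ``params``, ``dtrain`` and ``dvalid`` are assumed defined): with ``save_best=True`` the returned booster is sliced at the recorded best iteration:

.. code-block:: python

    import xgboost as xgb

    early_stop = xgb.callback.EarlyStopping(rounds=3, save_best=True)
    booster = xgb.train(params, dtrain, num_boost_round=100,
                        evals=[(dvalid, 'Valid')],
                        callbacks=[early_stop])
    # booster now contains only the trees up to the best iteration.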


class EvaluationMonitor(TrainingCallback):
'''Print the evaluation result at each iteration.
@@ -684,9 +703,11 @@ def __init__(self, callbacks, start_iteration, end_iteration,

    def before_training(self, model):
        '''Nothing to do for legacy callbacks'''
+        return model

    def after_training(self, model):
        '''Nothing to do for legacy callbacks'''
+        return model

    def before_iteration(self, model, epoch, dtrain, evals):
        '''Called before each iteration.'''
41 changes: 39 additions & 2 deletions python-package/xgboost/core.py
@@ -944,8 +944,8 @@ def __init__(self, params=None, cache=(), model_file=None):
            Parameters for boosters.
        cache : list
            List of cache items.
-        model_file : string or os.PathLike
-            Path to the model file.
+        model_file : string/os.PathLike/Booster/bytearray
+            Path to the model file if it's string or PathLike.
"""
for d in cache:
if not isinstance(d, DMatrix):
@@ -1021,6 +1021,43 @@ def __setstate__(self, state):
state['handle'] = handle
self.__dict__.update(state)

    def __getitem__(self, val):
        if isinstance(val, int):
            val = slice(val, val+1)
        if isinstance(val, tuple):
            raise ValueError('Only supports slicing through 1 dimension.')
        if not isinstance(val, slice):
            msg = _expect((int, slice), type(val))
            raise TypeError(msg)
        if isinstance(val.start, type(Ellipsis)) or val.start is None:
            start = 0
        else:
            start = val.start
        if isinstance(val.stop, type(Ellipsis)) or val.stop is None:
            stop = 0
        else:
            stop = val.stop
        if stop < start:
            raise ValueError('Invalid slice', val)

        step = val.step if val.step is not None else 1

        start = ctypes.c_int(start)
        stop = ctypes.c_int(stop)
        step = ctypes.c_int(step)

        sliced_handle = ctypes.c_void_p()
        status = _LIB.XGBoosterSlice(self.handle, start, stop, step,
                                     ctypes.byref(sliced_handle))
        if status == -2:
            raise IndexError('Layer index out of range')
        _check_call(status)

        sliced = Booster()
        _check_call(_LIB.XGBoosterFree(sliced.handle))
        sliced.handle = sliced_handle
        return sliced
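In use, the error codes map onto Python exceptions. A sketch (``booster`` as in the docs example above):

.. code-block:: python

    single = booster[5]      # a single boosting round, returned as a Booster
    try:
        booster[0:10**6]     # more layers than the model has -> status -2
    except IndexError:
        pass
    try:
        booster[7:3]         # stop < start is rejected before the C call
    except ValueError:
        pass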

    def save_config(self):
        '''Output internal parameter configuration of Booster as a JSON
        string.
10 changes: 6 additions & 4 deletions python-package/xgboost/training.py
@@ -103,7 +103,7 @@ def _train_internal(params, dtrain,
num_boost_round, feval, evals_result, callbacks,
show_stdv=False, cvfolds=None)

-    callbacks.before_training(bst)
+    bst = callbacks.before_training(bst)
for i in range(start_iteration, num_boost_round):
if callbacks.before_iteration(bst, i, dtrain, evals):
break
@@ -125,7 +125,7 @@ def _train_internal(params, dtrain,
bst.save_rabit_checkpoint()
version += 1

-    callbacks.after_training(bst)
+    bst = callbacks.after_training(bst)

if evals_result is not None and is_new_callback:
evals_result.update(callbacks.history)
@@ -495,9 +495,8 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
verbose_eval, early_stopping_rounds, maximize, 0,
num_boost_round, feval, None, callbacks,
show_stdv=show_stdv, cvfolds=cvfolds)
-    callbacks.before_training(cvfolds)
-
    booster = _PackedBooster(cvfolds)
+    callbacks.before_training(booster)

for i in range(num_boost_round):
if callbacks.before_iteration(booster, i, dtrain, None):
@@ -524,4 +523,7 @@ def cv(params, dtrain, num_boost_round=10, nfold=3, stratified=False, folds=None
        results = pd.DataFrame.from_dict(results)
    except ImportError:
        pass

+    callbacks.after_training(booster)
+
    return results
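A usage sketch for ``cv`` with the new callbacks (inputs illustrative; ``params`` and ``dtrain`` assumed defined). Note that ``save_best`` raises ``XGBoostError`` here, since the packed cv booster cannot be sliced:

.. code-block:: python

    import xgboost as xgb

    results = xgb.cv(params, dtrain, num_boost_round=100, nfold=5,
                     callbacks=[xgb.callback.EarlyStopping(rounds=3)])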
16 changes: 16 additions & 0 deletions src/c_api/c_api.cc
@@ -730,6 +730,22 @@ XGB_DLL int XGBoosterSaveRabitCheckpoint(BoosterHandle handle) {
API_END();
}

XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer,
                           int end_layer, int step,
                           BoosterHandle *out) {
  API_BEGIN();
  CHECK_HANDLE();
  auto* learner = static_cast<Learner*>(handle);
  bool out_of_bound = false;
  auto p_out = learner->Slice(begin_layer, end_layer, step, &out_of_bound);
  if (out_of_bound) {
    return -2;
  }
  CHECK(p_out);
  *out = p_out;
  API_END();
}

inline void XGBoostDumpModelImpl(BoosterHandle handle, const FeatureMap &fmap,
int with_stats, const char *format,
xgboost::bst_ulong *len,
