Support slicing tree model #6302
Changes from 8 commits
182ad8f
e173173
8723f80
a33657a
c9ddda0
40b5562
5cb2d5e
60368c8
ef48197
52948c6
7e30509
a087793
4dc4dd3
ae9ee40
e83b6b4
0a5ebcb
accc0f1
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
##### | ||
Model | ||
##### | ||
|
||
Slice tree model | ||
---------------- | ||
|
||
When ``booster`` is set to ``gbtree`` or ``dart``, XGBoost builds a tree model, which is a | ||
list of trees and can be sliced into multiple sub-models. | ||
|
||
.. code-block:: python | ||
import xgboost as xgb | ||
from sklearn.datasets import make_classification | ||
num_classes = 3 | ||
X, y = make_classification(n_samples=1000, n_informative=5, | ||
n_classes=num_classes) | ||
dtrain = xgb.DMatrix(data=X, label=y) | ||
num_parallel_tree = 4 | ||
num_boost_round = 16 | ||
total_trees = num_parallel_tree * num_classes * num_boost_round | ||
|
||
# We build a boosted random forest for classification here. | ||
booster = xgb.train({ | ||
'num_parallel_tree': num_parallel_tree, 'subsample': 0.5, 'num_class': num_classes}, | ||
num_boost_round=num_boost_round, dtrain=dtrain) | ||
|
||
# This is the sliced model, containing the forests in layers [3, 7). | ||
# A step is also supported, with some limitations; for example, a negative step is invalid. | ||
sliced: xgb.Booster = booster[3:7] | ||
|
||
# Access individual tree layers by iterating over the booster | ||
trees = [_ for _ in booster] | ||
assert len(trees) == num_boost_round | ||
|
||
|
||
The sliced model is a copy of the selected trees, which means the original model is left | ||
unchanged by slicing. This feature is the basis of the ``save_best`` option in the early | ||
stopping callback. |
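To make the layer arithmetic in the documentation above concrete, here is a minimal pure-Python sketch that does not call XGBoost at all. It assumes trees are stored layer by layer, with each layer (one boosting round) holding `num_parallel_tree * num_classes` trees. The helper name `slice_layers` and the flat list layout are illustrative assumptions, not the library's internals.

```python
from typing import List


def slice_layers(trees: List[str], trees_per_layer: int,
                 begin: int, end: int, step: int = 1) -> List[str]:
    """Return a copy of the trees in layers [begin, end), taking every `step`-th layer."""
    if step < 1:
        # Mirrors the documented limitation that a negative step is invalid.
        raise ValueError("step must be a positive integer")
    n_layers = len(trees) // trees_per_layer
    if begin < 0 or end > n_layers or begin >= end:
        raise IndexError("layer index out of bound")
    selected: List[str] = []
    for layer in range(begin, end, step):
        start = layer * trees_per_layer
        # Copy the trees of this layer; the input list is never modified.
        selected.extend(trees[start:start + trees_per_layer])
    return selected


num_parallel_tree, num_classes, num_boost_round = 4, 3, 16
trees_per_layer = num_parallel_tree * num_classes
model = [f"tree-{i}" for i in range(trees_per_layer * num_boost_round)]

# Layers 3, 4, 5 and 6, i.e. four forests of twelve trees each.
sliced = slice_layers(model, trees_per_layer, 3, 7)
```

Under these assumptions the slice holds 48 trees, and `model` still holds all 192, which is the "copy, not mutate" behaviour the paragraph describes.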
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -580,6 +580,21 @@ XGB_DLL int XGBoosterCreate(const DMatrixHandle dmats[], | |||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
XGB_DLL int XGBoosterFree(BoosterHandle handle); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/*! | ||||||||||||||||||||||||||||||||||
* \brief Slice a model according to layers. | ||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* \param handle Booster to be sliced. | ||||||||||||||||||||||||||||||||||
* \param begin_layer start of the slice | ||||||||||||||||||||||||||||||||||
* \param end_layer end of the slice | ||||||||||||||||||||||||||||||||||
* \param step step size of the slice | ||||||||||||||||||||||||||||||||||
* \param out Sliced booster. | ||||||||||||||||||||||||||||||||||
Suggested change
Added comments. |
||||||||||||||||||||||||||||||||||
* | ||||||||||||||||||||||||||||||||||
* \return 0 on success, -1 on failure, -2 when the index is out of bounds. | ||||||||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||||||||
XGB_DLL int XGBoosterSlice(BoosterHandle handle, int begin_layer, | ||||||||||||||||||||||||||||||||||
int end_layer, int step, | ||||||||||||||||||||||||||||||||||
BoosterHandle *out); | ||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||
/*! | ||||||||||||||||||||||||||||||||||
* \brief set parameters | ||||||||||||||||||||||||||||||||||
* \param handle handle | ||||||||||||||||||||||||||||||||||
|
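The doc comment for `XGBoosterSlice` distinguishes three return codes. A language binding could translate them roughly as in the sketch below. This is an illustration only: `check_slice_status` and `SliceError` are hypothetical names, and the sketch does not call the C API.

```python
class SliceError(RuntimeError):
    """Hypothetical error type for a generic failure reported by the C API."""


def check_slice_status(status: int) -> None:
    """Translate a slice return code into a Python exception (illustrative only)."""
    if status == 0:
        # Success: nothing to report.
        return
    if status == -2:
        # Out-of-bound layers map naturally onto Python's IndexError.
        raise IndexError("layer index out of bound")
    # -1, or any other non-zero code, is treated as a generic failure.
    raise SliceError(f"slice failed with status {status}")


check_slice_status(0)  # returns silently on success
```

Mapping `-2` to `IndexError` matches how Python sequences signal out-of-range indexing, which is why a dedicated code separate from the generic `-1` is useful to a caller.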
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -60,6 +60,16 @@ class GradientBooster : public Model, public Configurable { | |||||||||||||||||||||||||||
* \param fo output stream | ||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||
virtual void Save(dmlc::Stream* fo) const = 0; | ||||||||||||||||||||||||||||
/*! | ||||||||||||||||||||||||||||
* \brief Slice the model. | ||||||||||||||||||||||||||||
* \param layer_begin Beginning of boosted tree layer used for prediction. | ||||||||||||||||||||||||||||
* \param layer_end End of booster layer. 0 means do not limit trees. | ||||||||||||||||||||||||||||
* \param out Output gradient booster | ||||||||||||||||||||||||||||
*/ | ||||||||||||||||||||||||||||
Suggested change
Added comments. |
||||||||||||||||||||||||||||
virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step, | ||||||||||||||||||||||||||||
GradientBooster *out, bool* out_of_bound) const { | ||||||||||||||||||||||||||||
LOG(FATAL) << "Slice is not supported by current booster."; | ||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||
/*! | ||||||||||||||||||||||||||||
* \brief whether the model allow lazy checkpoint | ||||||||||||||||||||||||||||
* return true if model is only updated in DoBoost | ||||||||||||||||||||||||||||
|
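The hunk above gives `GradientBooster::Slice` a default body that fails, so only boosters that override it can be sliced. The Python sketch below mirrors that pattern under stated assumptions: all class names are hypothetical, layers are plain strings, and the `out_of_bound` output pointer is modelled as a returned flag. The rule that `layer_end == 0` means "no limit" is taken from the doc comment.

```python
from typing import List, Tuple


class BoosterSketch:
    """Base class: slicing is rejected unless a subclass overrides it."""

    def slice(self, layer_begin: int, layer_end: int,
              step: int) -> Tuple["BoosterSketch", bool]:
        raise NotImplementedError("Slice is not supported by current booster.")


class UnsupportedSketch(BoosterSketch):
    """A booster that keeps the default, unsupported behaviour."""


class TreeSketch(BoosterSketch):
    """A booster made of layers, which overrides slice()."""

    def __init__(self, layers: List[str]) -> None:
        self.layers = list(layers)

    def slice(self, layer_begin: int, layer_end: int,
              step: int) -> Tuple["TreeSketch", bool]:
        if step < 1:
            raise ValueError("step must be a positive integer")
        # Assumption from the doc comment: layer_end == 0 means "do not limit".
        end = len(self.layers) if layer_end == 0 else layer_end
        if layer_begin < 0 or layer_begin >= end or end > len(self.layers):
            # Report the problem through the flag instead of raising.
            return TreeSketch([]), True
        # Slicing a list copies it, so the original booster is left untouched.
        return TreeSketch(self.layers[layer_begin:end:step]), False


booster = TreeSketch(["layer-0", "layer-1", "layer-2", "layer-3"])
tail, out_of_bound = booster.slice(1, 0, 1)
```

Returning a flag rather than raising keeps the out-of-bound case distinguishable from other failures, which lines up with the separate `-2` return code in the C API hunk.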
This variable is not used anywhere in the code snippet.
Converted into a comment.