forked from dmlc/xgboost
/
gbm.h
234 lines (218 loc) · 8.92 KB
/
gbm.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
/*!
* Copyright 2014-2021 by Contributors
* \file gbm.h
* \brief Interface of gradient booster,
* that learns through gradient statistics.
* \author Tianqi Chen
*/
#ifndef XGBOOST_GBM_H_
#define XGBOOST_GBM_H_
#include <dmlc/registry.h>
#include <dmlc/any.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/host_device_vector.h>
#include <xgboost/model.h>
#include <vector>
#include <utility>
#include <string>
#include <functional>
#include <unordered_map>
#include <memory>
namespace xgboost {
// Forward declarations of types declared elsewhere in the project. Within this header
// the ones that are referenced are used only through pointers or references.
class Json;
class FeatureMap;
class ObjFunction;
struct GenericParameter;
struct LearnerModelParam;
struct PredictionCacheEntry;
class PredictionContainer;
/*!
 * \brief Interface of gradient boosting model.
 *
 * Abstract base class; concrete boosters implement the pure virtual methods declared
 * here. Use GradientBooster::Create to obtain a booster by name.
 */
class GradientBooster : public Model, public Configurable {
 protected:
  /*!
   * \brief Runtime parameters (non-owning pointer; never freed by this class).
   *  Initialised to nullptr so that the pointer never holds an indeterminate value
   *  before it is assigned.
   *  NOTE(review): presumably assigned from the generic_param argument of Create --
   *  the implementation is not part of this header, confirm.
   */
  GenericParameter const* generic_param_{nullptr};

 public:
  /*! \brief virtual destructor */
  ~GradientBooster() override = default;
  /*!
   * \brief Set the configuration of gradient boosting.
   *  User must call configure once before InitModel and Training.
   *
   * \param cfg configurations on both training and model parameters,
   *            given as (name, value) string pairs.
   */
  virtual void Configure(const std::vector<std::pair<std::string, std::string> >& cfg) = 0;
  /*!
   * \brief load model from stream
   * \param fi input stream.
   */
  virtual void Load(dmlc::Stream* fi) = 0;
  /*!
   * \brief save model to stream.
   * \param fo output stream
   */
  virtual void Save(dmlc::Stream* fo) const = 0;
  /*!
   * \brief Slice a model using boosting index. The slice m:n indicates taking all trees
   *        that were fit during the boosting rounds m, (m+1), (m+2), ..., (n-1).
   *
   *  The default implementation does not slice anything; it reports a fatal error
   *  through LOG(FATAL). Boosters that support slicing must override it.
   *
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   * \param step        NOTE(review): presumably the stride between selected layers --
   *                    not specified in this header, confirm against an override.
   * \param out         Output gradient booster
   * \param out_of_bound NOTE(review): presumably set by the override when the requested
   *                    range falls outside the model -- confirm against an override.
   */
  virtual void Slice(int32_t layer_begin, int32_t layer_end, int32_t step,
                     GradientBooster *out, bool* out_of_bound) const {
    LOG(FATAL) << "Slice is not supported by current booster.";
  }
  /*!
   * \brief whether the model allow lazy checkpoint
   *
   * \return true if model is only updated in DoBoost after all Allreduce calls.
   *         The default implementation returns false.
   */
  virtual bool AllowLazyCheckPoint() const {
    return false;
  }
  /*! \brief Return number of boosted rounds.
   */
  virtual int32_t BoostedRounds() const = 0;
  /*!
   * \brief perform update to the model(boosting)
   * \param p_fmat feature matrix that provide access to features
   * \param in_gpair address of the gradient pair statistics of the data;
   *                 the booster may change content of gpair
   * \param prediction The output prediction cache entry that needs to be updated
   *                   (third, unnamed parameter in this declaration).
   */
  virtual void DoBoost(DMatrix* p_fmat,
                       HostDeviceVector<GradientPair>* in_gpair,
                       PredictionCacheEntry*) = 0;
  /*!
   * \brief generate predictions for given feature matrix
   * \param dmat feature matrix
   * \param out_preds output prediction cache entry to hold the predictions
   * \param training Whether the prediction value is used for training.  For dart booster
   *                 drop out is performed during training.
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   */
  virtual void PredictBatch(DMatrix* dmat,
                            PredictionCacheEntry* out_preds,
                            bool training,
                            unsigned layer_begin,
                            unsigned layer_end) = 0;
  /*!
   * \brief Inplace prediction.
   *
   *  The default implementation does not predict; it reports a fatal error through
   *  LOG(FATAL). The parameters are left unnamed in this declaration because the
   *  default body does not use them; they are, in order:
   *
   * \param x                    A type erased data adapter.
   * \param (shared_ptr<DMatrix>) NOTE(review): not documented in this header;
   *                             presumably a DMatrix accompanying the adapter -- confirm.
   * \param missing              Missing value in the data.
   * \param [in,out] out_preds   The output preds.
   * \param layer_begin          Beginning of boosted tree layer used for prediction.
   * \param layer_end            End of booster layer. 0 means do not limit trees.
   */
  virtual void InplacePredict(dmlc::any const &, std::shared_ptr<DMatrix>, float,
                              PredictionCacheEntry*,
                              uint32_t,
                              uint32_t) const {
    LOG(FATAL) << "Inplace predict is not supported by current booster.";
  }
  /*!
   * \brief online prediction function, predict score for one instance at a time
   *  NOTE: use the batch prediction interface if possible, batch prediction is usually
   *        more efficient than online prediction
   *        This function is NOT threadsafe, make sure you only call from one thread
   *
   * \param inst the instance you want to predict
   * \param out_preds output vector to hold the predictions
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   * \sa PredictBatch
   */
  virtual void PredictInstance(const SparsePage::Inst& inst,
                               std::vector<bst_float>* out_preds,
                               unsigned layer_begin, unsigned layer_end) = 0;
  /*!
   * \brief predict the leaf index of each tree, the output will be nsample * ntree vector
   *        this is only valid in gbtree predictor
   * \param dmat feature matrix
   * \param out_preds output vector to hold the predictions
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   */
  virtual void PredictLeaf(DMatrix *dmat,
                           HostDeviceVector<bst_float> *out_preds,
                           unsigned layer_begin, unsigned layer_end) = 0;
  /*!
   * \brief feature contributions to individual predictions; the output will be a vector
   *        of length (nfeats + 1) * num_output_group * nsample, arranged in that order
   * \param dmat feature matrix
   * \param out_contribs output vector to hold the contributions
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   * \param approximate use a faster (inconsistent) approximation of SHAP values
   * \param condition condition on the condition_feature (0=no, -1=cond off, 1=cond on).
   * \param condition_feature feature to condition on (i.e. fix) during calculations
   */
  virtual void PredictContribution(DMatrix* dmat,
                                   HostDeviceVector<bst_float>* out_contribs,
                                   unsigned layer_begin, unsigned layer_end,
                                   bool approximate = false, int condition = 0,
                                   unsigned condition_feature = 0) = 0;
  /*!
   * \brief feature interaction contributions to individual predictions.
   *  NOTE(review): the layout and length of the output are not specified in this
   *  header -- confirm against the implementing booster.
   * \param dmat feature matrix
   * \param out_contribs output vector to hold the interaction contributions
   * \param layer_begin Beginning of boosted tree layer used for prediction.
   * \param layer_end   End of booster layer. 0 means do not limit trees.
   * \param approximate NOTE(review): presumably selects a faster approximation, as in
   *                    PredictContribution -- confirm.
   */
  virtual void PredictInteractionContributions(
      DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
      unsigned layer_begin, unsigned layer_end, bool approximate) = 0;
  /*!
   * \brief dump the model in the requested format
   * \param fmap feature map that may help give interpretations of feature
   * \param with_stats extra statistics while dumping model
   * \param format the format to dump the model in
   * \return a vector of dump for boosters.
   */
  virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                             bool with_stats,
                                             std::string format) const = 0;
  /*!
   * \brief Feature scores of the model.
   *
   *  The default implementation computes nothing; it reports a fatal error through
   *  LOG(FATAL). Boosters that provide feature scores must override it.
   *
   * \param importance_type name of the kind of score requested (accepted values are
   *                        defined by the overriding booster, not by this header)
   * \param features NOTE(review): presumably receives the indices of the scored
   *                 features -- confirm against an override.
   * \param scores   NOTE(review): presumably receives one score per entry of
   *                 `features` -- confirm against an override.
   */
  virtual void FeatureScore(std::string const &importance_type,
                            std::vector<bst_feature_t> *features,
                            std::vector<float> *scores) const {
    LOG(FATAL) << "`feature_score` is not implemented for current booster.";
  }
  /*!
   * \brief Whether the current booster uses GPU.
   */
  virtual bool UseGPU() const = 0;
  /*!
   * \brief create a gradient booster from given name
   * \param name name of gradient booster
   * \param generic_param Pointer to runtime parameters
   * \param learner_model_param pointer to global model parameters
   * \return The created booster, as a raw pointer.
   *         NOTE(review): ownership presumably passes to the caller -- confirm.
   */
  static GradientBooster* Create(
      const std::string& name,
      GenericParameter const* generic_param,
      LearnerModelParam const* learner_model_param);
};
/*!
* \brief Registry entry for gradient booster.
*
* The registered function body takes a pointer to the learner model parameters and
* returns a pointer to a GradientBooster.
*/
struct GradientBoosterReg
: public dmlc::FunctionRegEntryBase<
GradientBoosterReg,
std::function<GradientBooster* (LearnerModelParam const* learner_model_param)> > {
};
/*!
* \brief Macro to register gradient booster.
*
* Expands to a static reference to the GradientBoosterReg entry returned by
* dmlc::Registry<GradientBoosterReg>::Get()->__REGISTER__(Name). UniqueId is pasted
* into the name of that static variable.
*
* \code
* // example of registering a booster under the name "gbtree";
* // MyBooster stands for a hypothetical GradientBooster subclass.
* XGBOOST_REGISTER_GBM(GBTree, "gbtree")
* .describe("Boosting tree ensembles.")
* .set_body([](LearnerModelParam const* learner_model_param) {
* return new MyBooster(learner_model_param);
* });
* \endcode
*/
#define XGBOOST_REGISTER_GBM(UniqueId, Name) \
static DMLC_ATTRIBUTE_UNUSED ::xgboost::GradientBoosterReg & \
__make_ ## GradientBoosterReg ## _ ## UniqueId ## __ = \
::dmlc::Registry< ::xgboost::GradientBoosterReg>::Get()->__REGISTER__(Name)
} // namespace xgboost
#endif // XGBOOST_GBM_H_