-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
gbtree_model.h
142 lines (125 loc) · 4.64 KB
/
gbtree_model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
/*!
* Copyright 2017-2020 by Contributors
* \file gbtree_model.h
*/
#ifndef XGBOOST_GBM_GBTREE_MODEL_H_
#define XGBOOST_GBM_GBTREE_MODEL_H_
#include <dmlc/parameter.h>
#include <dmlc/io.h>
#include <xgboost/model.h>
#include <xgboost/tree_model.h>
#include <xgboost/parameter.h>
#include <xgboost/learner.h>
#include <cstdint>
#include <cstring>
#include <memory>
#include <utility>
#include <string>
#include <vector>
#include "../common/threading_utils.h"
namespace xgboost {
class Json;
namespace gbm {
/*!
 * \brief Model parameters of the tree booster.
 *
 * The struct has a fixed layout of (4 + 2 + 2 + 32) 32-bit words, which is
 * enforced by the static_assert in the constructor.  Do not add, remove or
 * reorder fields: the size check exists for 64/32 bit compatibility.
 */
struct GBTreeModelParam : public dmlc::Parameter<GBTreeModelParam> {
 public:
  /*! \brief number of trees */
  int32_t num_trees;
  /*! \brief (Deprecated) number of roots */
  int32_t deprecated_num_roots;
  /*! \brief (Deprecated) number of features to be used by trees */
  int32_t deprecated_num_feature;
  /*! \brief pad this space, for backward compatibility reason.*/
  int32_t pad_32bit;
  /*! \brief deprecated padding space. */
  int64_t deprecated_num_pbuffer;
  /*! \brief (Deprecated) use learner_model_param_->num_output_group instead. */
  int32_t deprecated_num_output_group;
  /*! \brief size of leaf vector needed in tree */
  int32_t size_leaf_vector;
  /*! \brief reserved parameters */
  int32_t reserved[32];

  /*! \brief constructor: zero every field, then set the number of roots to 1. */
  GBTreeModelParam() {
    static_assert(sizeof(GBTreeModelParam) == (4 + 2 + 2 + 32) * sizeof(int32_t),
                  "64/32 bit compatibility issue");
    // FIXME(trivialfis): Why?
    // NOTE(review): presumably the whole object is zeroed so that the reserved
    // words hold a deterministic value -- confirm against the (de)serialization
    // code.  The cast to void* tells the compiler that wiping a class type with
    // a user-provided constructor is intentional (avoids -Wclass-memaccess).
    std::memset(static_cast<void*>(this), 0, sizeof(GBTreeModelParam));
    deprecated_num_roots = 1;
  }

  // declare parameters, only declare those that need to be set.
  DMLC_DECLARE_PARAMETER(GBTreeModelParam) {
    DMLC_DECLARE_FIELD(num_trees)
        .set_lower_bound(0)
        .set_default(0)
        .describe("Number of trees in the model.");
    DMLC_DECLARE_FIELD(size_leaf_vector)
        .set_lower_bound(0)
        .set_default(0)
        .describe("Reserved option for vector tree.");
  }

  /*!
   * \brief Swap byte order for all fields.
   *
   * Useful for transporting models between machines with different
   * endianness (big endian vs little endian).
   *
   * \return a copy of this parameter with the bytes of every field reversed;
   *         the object itself is not modified.
   */
  inline GBTreeModelParam ByteSwap() const {
    GBTreeModelParam x = *this;
    dmlc::ByteSwap(&x.num_trees, sizeof(x.num_trees), 1);
    dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
    dmlc::ByteSwap(&x.deprecated_num_feature, sizeof(x.deprecated_num_feature), 1);
    dmlc::ByteSwap(&x.pad_32bit, sizeof(x.pad_32bit), 1);
    dmlc::ByteSwap(&x.deprecated_num_pbuffer, sizeof(x.deprecated_num_pbuffer), 1);
    dmlc::ByteSwap(&x.deprecated_num_output_group, sizeof(x.deprecated_num_output_group), 1);
    dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
    // The reserved array is swapped element by element in one call.
    dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
    return x;
  }
};
struct GBTreeModel : public Model {
public:
explicit GBTreeModel(LearnerModelParam const* learner_model) :
learner_model_param{learner_model} {}
void Configure(const Args& cfg) {
// initialize model parameters if not yet been initialized.
if (trees.size() == 0) {
param.UpdateAllowUnknown(cfg);
}
}
void InitTreesToUpdate() {
if (trees_to_update.size() == 0u) {
for (auto & tree : trees) {
trees_to_update.push_back(std::move(tree));
}
trees.clear();
param.num_trees = 0;
tree_info.clear();
}
}
void Load(dmlc::Stream* fi);
void Save(dmlc::Stream* fo) const;
void SaveModel(Json* p_out) const override;
void LoadModel(Json const& p_out) override;
std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats, int32_t n_threads,
std::string format) const {
std::vector<std::string> dump(trees.size());
common::ParallelFor(trees.size(), n_threads,
[&](size_t i) { dump[i] = trees[i]->DumpModel(fmap, with_stats, format); });
return dump;
}
void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
int bst_group) {
for (auto & new_tree : new_trees) {
trees.push_back(std::move(new_tree));
tree_info.push_back(bst_group);
}
param.num_trees += static_cast<int>(new_trees.size());
}
// base margin
LearnerModelParam const* learner_model_param;
// model parameter
GBTreeModelParam param;
/*! \brief vector of trees stored in the model */
std::vector<std::unique_ptr<RegTree> > trees;
/*! \brief for the update process, a place to keep the initial trees */
std::vector<std::unique_ptr<RegTree> > trees_to_update;
/*! \brief some information indicator of the tree, reserved */
std::vector<int> tree_info;
};
} // namespace gbm
} // namespace xgboost
#endif // XGBOOST_GBM_GBTREE_MODEL_H_