forked from dmlc/xgboost
/
tree_model.h
739 lines (687 loc) · 24.8 KB
/
tree_model.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
/*!
* Copyright 2014-2022 by Contributors
* \file tree_model.h
* \brief model structure for tree
* \author Tianqi Chen
*/
#ifndef XGBOOST_TREE_MODEL_H_
#define XGBOOST_TREE_MODEL_H_
#include <dmlc/io.h>
#include <dmlc/parameter.h>
#include <xgboost/base.h>
#include <xgboost/data.h>
#include <xgboost/logging.h>
#include <xgboost/feature_map.h>
#include <xgboost/model.h>
#include <limits>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm>
#include <tuple>
#include <stack>
#include "xgboost/host_device_vector.h"
namespace xgboost {
struct PathElement; // forward declaration
class Json;
// FIXME(trivialfis): Once binary IO is gone, make this parameter internal as it should
// not be configured by users.
/*! \brief meta parameters of the tree */
struct TreeParam : public dmlc::Parameter<TreeParam> {
/*! \brief (Deprecated) number of start root */
int deprecated_num_roots;
/*! \brief total number of nodes */
int num_nodes;
/*!\brief number of deleted nodes */
int num_deleted;
/*! \brief maximum depth, this is a statistics of the tree */
int deprecated_max_depth;
/*! \brief number of features used for tree construction */
bst_feature_t num_feature;
/*!
* \brief leaf vector size, used for vector tree
* used to store more than one dimensional information in tree
*/
int size_leaf_vector;
/*! \brief reserved part, make sure alignment works for 64bit */
int reserved[31];
/*! \brief constructor */
TreeParam() {
// assert compact alignment
static_assert(sizeof(TreeParam) == (31 + 6) * sizeof(int),
"TreeParam: 64 bit align");
std::memset(this, 0, sizeof(TreeParam));
num_nodes = 1;
deprecated_num_roots = 1;
}
// Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian)
inline TreeParam ByteSwap() const {
TreeParam x = *this;
dmlc::ByteSwap(&x.deprecated_num_roots, sizeof(x.deprecated_num_roots), 1);
dmlc::ByteSwap(&x.num_nodes, sizeof(x.num_nodes), 1);
dmlc::ByteSwap(&x.num_deleted, sizeof(x.num_deleted), 1);
dmlc::ByteSwap(&x.deprecated_max_depth, sizeof(x.deprecated_max_depth), 1);
dmlc::ByteSwap(&x.num_feature, sizeof(x.num_feature), 1);
dmlc::ByteSwap(&x.size_leaf_vector, sizeof(x.size_leaf_vector), 1);
dmlc::ByteSwap(x.reserved, sizeof(x.reserved[0]), sizeof(x.reserved) / sizeof(x.reserved[0]));
return x;
}
// declare the parameters
DMLC_DECLARE_PARAMETER(TreeParam) {
// only declare the parameters that can be set by the user.
// other arguments are set by the algorithm.
DMLC_DECLARE_FIELD(num_nodes).set_lower_bound(1).set_default(1);
DMLC_DECLARE_FIELD(num_feature)
.describe("Number of features used in tree construction.");
DMLC_DECLARE_FIELD(num_deleted);
DMLC_DECLARE_FIELD(size_leaf_vector).set_lower_bound(0).set_default(0)
.describe("Size of leaf vector, reserved for vector tree");
}
bool operator==(const TreeParam& b) const {
return num_nodes == b.num_nodes &&
num_deleted == b.num_deleted &&
num_feature == b.num_feature &&
size_leaf_vector == b.size_leaf_vector;
}
};
/*! \brief node statistics used in regression tree */
struct RTreeNodeStat {
/*! \brief loss change caused by current split */
bst_float loss_chg;
/*! \brief sum of hessian values, used to measure coverage of data */
bst_float sum_hess;
/*! \brief weight of current node */
bst_float base_weight;
/*! \brief number of child that is leaf node known up to now */
int leaf_child_cnt {0};
RTreeNodeStat() = default;
RTreeNodeStat(float loss_chg, float sum_hess, float weight) :
loss_chg{loss_chg}, sum_hess{sum_hess}, base_weight{weight} {}
bool operator==(const RTreeNodeStat& b) const {
return loss_chg == b.loss_chg && sum_hess == b.sum_hess &&
base_weight == b.base_weight && leaf_child_cnt == b.leaf_child_cnt;
}
// Swap byte order for all fields. Useful for transporting models between machines with different
// endianness (big endian vs little endian)
inline RTreeNodeStat ByteSwap() const {
RTreeNodeStat x = *this;
dmlc::ByteSwap(&x.loss_chg, sizeof(x.loss_chg), 1);
dmlc::ByteSwap(&x.sum_hess, sizeof(x.sum_hess), 1);
dmlc::ByteSwap(&x.base_weight, sizeof(x.base_weight), 1);
dmlc::ByteSwap(&x.leaf_child_cnt, sizeof(x.leaf_child_cnt), 1);
return x;
}
};
/*!
* \brief define regression tree to be the most common tree model.
* This is the data structure used in xgboost's major tree models.
*/
class RegTree : public Model {
public:
using SplitCondT = bst_float;
static constexpr bst_node_t kInvalidNodeId {-1};
static constexpr uint32_t kDeletedNodeMarker = std::numeric_limits<uint32_t>::max();
static constexpr bst_node_t kRoot { 0 };
/*! \brief tree node */
class Node {
public:
XGBOOST_DEVICE Node() {
// assert compact alignment
static_assert(sizeof(Node) == 4 * sizeof(int) + sizeof(Info),
"Node: 64 bit align");
}
Node(int32_t cleft, int32_t cright, int32_t parent,
uint32_t split_ind, float split_cond, bool default_left) :
parent_{parent}, cleft_{cleft}, cright_{cright} {
this->SetParent(parent_);
this->SetSplit(split_ind, split_cond, default_left);
}
/*! \brief index of left child */
XGBOOST_DEVICE int LeftChild() const {
return this->cleft_;
}
/*! \brief index of right child */
XGBOOST_DEVICE int RightChild() const {
return this->cright_;
}
/*! \brief index of default child when feature is missing */
XGBOOST_DEVICE int DefaultChild() const {
return this->DefaultLeft() ? this->LeftChild() : this->RightChild();
}
/*! \brief feature index of split condition */
XGBOOST_DEVICE unsigned SplitIndex() const {
return sindex_ & ((1U << 31) - 1U);
}
/*! \brief when feature is unknown, whether goes to left child */
XGBOOST_DEVICE bool DefaultLeft() const {
return (sindex_ >> 31) != 0;
}
/*! \brief whether current node is leaf node */
XGBOOST_DEVICE bool IsLeaf() const {
return cleft_ == kInvalidNodeId;
}
/*! \return get leaf value of leaf node */
XGBOOST_DEVICE bst_float LeafValue() const {
return (this->info_).leaf_value;
}
/*! \return get split condition of the node */
XGBOOST_DEVICE SplitCondT SplitCond() const {
return (this->info_).split_cond;
}
/*! \brief get parent of the node */
XGBOOST_DEVICE int Parent() const {
return parent_ & ((1U << 31) - 1);
}
/*! \brief whether current node is left child */
XGBOOST_DEVICE bool IsLeftChild() const {
return (parent_ & (1U << 31)) != 0;
}
/*! \brief whether this node is deleted */
XGBOOST_DEVICE bool IsDeleted() const {
return sindex_ == kDeletedNodeMarker;
}
/*! \brief whether current node is root */
XGBOOST_DEVICE bool IsRoot() const { return parent_ == kInvalidNodeId; }
/*!
* \brief set the left child
* \param nid node id to right child
*/
XGBOOST_DEVICE void SetLeftChild(int nid) {
this->cleft_ = nid;
}
/*!
* \brief set the right child
* \param nid node id to right child
*/
XGBOOST_DEVICE void SetRightChild(int nid) {
this->cright_ = nid;
}
/*!
* \brief set split condition of current node
* \param split_index feature index to split
* \param split_cond split condition
* \param default_left the default direction when feature is unknown
*/
XGBOOST_DEVICE void SetSplit(unsigned split_index, SplitCondT split_cond,
bool default_left = false) {
if (default_left) split_index |= (1U << 31);
this->sindex_ = split_index;
(this->info_).split_cond = split_cond;
}
/*!
* \brief set the leaf value of the node
* \param value leaf value
* \param right right index, could be used to store
* additional information
*/
XGBOOST_DEVICE void SetLeaf(bst_float value, int right = kInvalidNodeId) {
(this->info_).leaf_value = value;
this->cleft_ = kInvalidNodeId;
this->cright_ = right;
}
/*! \brief mark that this node is deleted */
XGBOOST_DEVICE void MarkDelete() {
this->sindex_ = kDeletedNodeMarker;
}
/*! \brief Reuse this deleted node. */
XGBOOST_DEVICE void Reuse() {
this->sindex_ = 0;
}
// set parent
XGBOOST_DEVICE void SetParent(int pidx, bool is_left_child = true) {
if (is_left_child) pidx |= (1U << 31);
this->parent_ = pidx;
}
bool operator==(const Node& b) const {
return parent_ == b.parent_ && cleft_ == b.cleft_ &&
cright_ == b.cright_ && sindex_ == b.sindex_ &&
info_.leaf_value == b.info_.leaf_value;
}
inline Node ByteSwap() const {
Node x = *this;
dmlc::ByteSwap(&x.parent_, sizeof(x.parent_), 1);
dmlc::ByteSwap(&x.cleft_, sizeof(x.cleft_), 1);
dmlc::ByteSwap(&x.cright_, sizeof(x.cright_), 1);
dmlc::ByteSwap(&x.sindex_, sizeof(x.sindex_), 1);
dmlc::ByteSwap(&x.info_, sizeof(x.info_), 1);
return x;
}
private:
/*!
* \brief in leaf node, we have weights, in non-leaf nodes,
* we have split condition
*/
union Info{
bst_float leaf_value;
SplitCondT split_cond;
};
// pointer to parent, highest bit is used to
// indicate whether it's a left child or not
int32_t parent_{kInvalidNodeId};
// pointer to left, right
int32_t cleft_{kInvalidNodeId}, cright_{kInvalidNodeId};
// split feature index, left split or right split depends on the highest bit
uint32_t sindex_{0};
// extra info
Info info_;
};
/*!
* \brief change a non leaf node to a leaf node, delete its children
* \param rid node id of the node
* \param value new leaf value
*/
void ChangeToLeaf(int rid, bst_float value) {
CHECK(nodes_[nodes_[rid].LeftChild() ].IsLeaf());
CHECK(nodes_[nodes_[rid].RightChild()].IsLeaf());
this->DeleteNode(nodes_[rid].LeftChild());
this->DeleteNode(nodes_[rid].RightChild());
nodes_[rid].SetLeaf(value);
}
/*!
* \brief collapse a non leaf node to a leaf node, delete its children
* \param rid node id of the node
* \param value new leaf value
*/
void CollapseToLeaf(int rid, bst_float value) {
if (nodes_[rid].IsLeaf()) return;
if (!nodes_[nodes_[rid].LeftChild() ].IsLeaf()) {
CollapseToLeaf(nodes_[rid].LeftChild(), 0.0f);
}
if (!nodes_[nodes_[rid].RightChild() ].IsLeaf()) {
CollapseToLeaf(nodes_[rid].RightChild(), 0.0f);
}
this->ChangeToLeaf(rid, value);
}
/*! \brief model parameter */
TreeParam param;
/*! \brief constructor */
RegTree() {
param.num_nodes = 1;
param.num_deleted = 0;
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
for (int i = 0; i < param.num_nodes; i ++) {
nodes_[i].SetLeaf(0.0f);
nodes_[i].SetParent(kInvalidNodeId);
}
}
/*! \brief get node given nid */
Node& operator[](int nid) {
return nodes_[nid];
}
/*! \brief get node given nid */
const Node& operator[](int nid) const {
return nodes_[nid];
}
/*! \brief get const reference to nodes */
const std::vector<Node>& GetNodes() const { return nodes_; }
/*! \brief get const reference to stats */
const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
/*! \brief get node statistics given nid */
RTreeNodeStat& Stat(int nid) {
return stats_[nid];
}
/*! \brief get node statistics given nid */
const RTreeNodeStat& Stat(int nid) const {
return stats_[nid];
}
/*!
* \brief load model from stream
* \param fi input stream
*/
void Load(dmlc::Stream* fi);
/*!
* \brief save model to stream
* \param fo output stream
*/
void Save(dmlc::Stream* fo) const;
void LoadModel(Json const& in) override;
void SaveModel(Json* out) const override;
bool operator==(const RegTree& b) const {
return nodes_ == b.nodes_ && stats_ == b.stats_ &&
deleted_nodes_ == b.deleted_nodes_ && param == b.param;
}
/* \brief Iterate through all nodes in this tree.
*
* \param Function that accepts a node index, and returns false when iteration should
* stop, otherwise returns true.
*/
template <typename Func> void WalkTree(Func func) const {
std::stack<bst_node_t> nodes;
nodes.push(kRoot);
auto &self = *this;
while (!nodes.empty()) {
auto nidx = nodes.top();
nodes.pop();
if (!func(nidx)) {
return;
}
auto left = self[nidx].LeftChild();
auto right = self[nidx].RightChild();
if (left != RegTree::kInvalidNodeId) {
nodes.push(left);
}
if (right != RegTree::kInvalidNodeId) {
nodes.push(right);
}
}
}
/*!
* \brief Compares whether 2 trees are equal from a user's perspective. The equality
* compares only non-deleted nodes.
*
* \param b The other tree.
*/
bool Equal(const RegTree& b) const;
/**
* \brief Expands a leaf node into two additional leaf nodes.
*
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_value The split condition.
* \param default_left True to default left.
* \param base_weight The base weight, before learning rate.
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change.
* \param sum_hess The sum hess.
* \param left_sum The sum hess of left leaf.
* \param right_sum The sum hess of right leaf.
* \param leaf_right_child The right child index of leaf, by default kInvalidNodeId,
* some updaters use the right child index of leaf as a marker
*/
void ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_value,
bool default_left, bst_float base_weight,
bst_float left_leaf_weight, bst_float right_leaf_weight,
bst_float loss_change, float sum_hess, float left_sum,
float right_sum,
bst_node_t leaf_right_child = kInvalidNodeId);
/**
* \brief Expands a leaf node with categories
*
* \param nid The node index to expand.
* \param split_index Feature index of the split.
* \param split_cat The bitset containing categories
* \param default_left True to default left.
* \param base_weight The base weight, before learning rate.
* \param left_leaf_weight The left leaf weight for prediction, modified by learning rate.
* \param right_leaf_weight The right leaf weight for prediction, modified by learning rate.
* \param loss_change The loss change.
* \param sum_hess The sum hess.
* \param left_sum The sum hess of left leaf.
* \param right_sum The sum hess of right leaf.
*/
void ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);
bool HasCategoricalSplit() const {
return !split_categories_.empty();
}
/*!
* \brief get current depth
* \param nid node id
*/
int GetDepth(int nid) const {
int depth = 0;
while (!nodes_[nid].IsRoot()) {
++depth;
nid = nodes_[nid].Parent();
}
return depth;
}
/*!
* \brief get maximum depth
* \param nid node id
*/
int MaxDepth(int nid) const {
if (nodes_[nid].IsLeaf()) return 0;
return std::max(MaxDepth(nodes_[nid].LeftChild())+1,
MaxDepth(nodes_[nid].RightChild())+1);
}
/*!
* \brief get maximum depth
*/
int MaxDepth() {
return MaxDepth(0);
}
/*! \brief number of extra nodes besides the root */
int NumExtraNodes() const {
return param.num_nodes - 1 - param.num_deleted;
}
/* \brief Count number of leaves in tree. */
bst_node_t GetNumLeaves() const;
bst_node_t GetNumSplitNodes() const;
/*!
* \brief dense feature vector that can be taken by RegTree
* and can be construct from sparse feature vector.
*/
struct FVec {
/*!
* \brief initialize the vector with size vector
* \param size The size of the feature vector.
*/
void Init(size_t size);
/*!
* \brief fill the vector with sparse vector
* \param inst The sparse instance to fill.
*/
void Fill(const SparsePage::Inst& inst);
/*!
* \brief drop the trace after fill, must be called after fill.
* \param inst The sparse instance to drop.
*/
void Drop(const SparsePage::Inst& inst);
/*!
* \brief returns the size of the feature vector
* \return the size of the feature vector
*/
size_t Size() const;
/*!
* \brief get ith value
* \param i feature index.
* \return the i-th feature value
*/
bst_float GetFvalue(size_t i) const;
/*!
* \brief check whether i-th entry is missing
* \param i feature index.
* \return whether i-th value is missing.
*/
bool IsMissing(size_t i) const;
bool HasMissing() const;
private:
/*!
* \brief a union value of value and flag
* when flag == -1, this indicate the value is missing
*/
union Entry {
bst_float fvalue;
int flag;
};
std::vector<Entry> data_;
bool has_missing_;
};
/*!
* \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param out_contribs output vector to hold the contributions
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
* \param condition_feature the index of the feature to fix
*/
void CalculateContributions(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs, int condition = 0,
unsigned condition_feature = 0) const;
/*!
* \brief Recursive function that computes the feature attributions for a single tree.
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param phi dense output vector of feature attributions
* \param node_index the index of the current node in the tree
* \param unique_depth how many unique features are above the current node in the tree
* \param parent_unique_path a vector of statistics about our current path through the tree
* \param parent_zero_fraction what fraction of the parent path weight is coming as 0 (integrated)
* \param parent_one_fraction what fraction of the parent path weight is coming as 1 (fixed)
* \param parent_feature_index what feature the parent node used to split
* \param condition fix one feature to either off (-1) on (1) or not fixed (0 default)
* \param condition_feature the index of the feature to fix
* \param condition_fraction what fraction of the current weight matches our conditioning feature
*/
void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
unsigned unique_depth, PathElement* parent_unique_path,
bst_float parent_zero_fraction, bst_float parent_one_fraction,
int parent_feature_index, int condition,
unsigned condition_feature, bst_float condition_fraction) const;
/*!
* \brief calculate the approximate feature contributions for the given root
* \param feat dense feature vector, if the feature is missing the field is set to NaN
* \param out_contribs output vector to hold the contributions
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs) const;
/*!
* \brief dump the model in the requested format as a text string
* \param fmap feature map that may help give interpretations of feature
* \param with_stats whether dump out statistics as well
* \param format the format to dump the model in
* \return the string of dumped model
*/
std::string DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const;
/*!
* \brief Get split type for a node.
* \param nidx Index of node.
* \return The type of this split. For leaf node it's always kNumerical.
*/
FeatureType NodeSplitType(bst_node_t nidx) const {
return split_types_.at(nidx);
}
/*!
* \brief Get split types for all nodes.
*/
std::vector<FeatureType> const &GetSplitTypes() const { return split_types_; }
common::Span<uint32_t const> GetSplitCategories() const { return split_categories_; }
/*!
* \brief Get the bit storage for categories
*/
common::Span<uint32_t const> NodeCats(bst_node_t nidx) const {
auto node_ptr = GetCategoriesMatrix().node_ptr;
auto categories = GetCategoriesMatrix().categories;
auto segment = node_ptr[nidx];
auto node_cats = categories.subspan(segment.beg, segment.size);
return node_cats;
}
auto const& GetSplitCategoriesPtr() const { return split_categories_segments_; }
// The fields of split_categories_segments_[i] are set such that
// the range split_categories_[beg:(beg+size)] stores the bitset for
// the matching categories for the i-th node.
struct Segment {
size_t beg {0};
size_t size {0};
};
struct CategoricalSplitMatrix {
common::Span<FeatureType const> split_type;
common::Span<uint32_t const> categories;
common::Span<Segment const> node_ptr;
};
CategoricalSplitMatrix GetCategoriesMatrix() const {
CategoricalSplitMatrix view;
view.split_type = common::Span<FeatureType const>(this->GetSplitTypes());
view.categories = this->GetSplitCategories();
view.node_ptr = common::Span<Segment const>(split_categories_segments_);
return view;
}
private:
template <bool typed>
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
// vector of nodes
std::vector<Node> nodes_;
// free node space, used during training process
std::vector<int> deleted_nodes_;
// stats of nodes
std::vector<RTreeNodeStat> stats_;
std::vector<FeatureType> split_types_;
// Categories for each internal node.
std::vector<uint32_t> split_categories_;
// Ptr to split categories of each node.
std::vector<Segment> split_categories_segments_;
// allocate a new node,
// !!!!!! NOTE: may cause BUG here, nodes.resize
bst_node_t AllocNode() {
if (param.num_deleted != 0) {
int nid = deleted_nodes_.back();
deleted_nodes_.pop_back();
nodes_[nid].Reuse();
--param.num_deleted;
return nid;
}
int nd = param.num_nodes++;
CHECK_LT(param.num_nodes, std::numeric_limits<int>::max())
<< "number of nodes in the tree exceed 2^31";
nodes_.resize(param.num_nodes);
stats_.resize(param.num_nodes);
split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
return nd;
}
// delete a tree node, keep the parent field to allow trace back
void DeleteNode(int nid) {
CHECK_GE(nid, 1);
auto pid = (*this)[nid].Parent();
if (nid == (*this)[pid].LeftChild()) {
(*this)[pid].SetLeftChild(kInvalidNodeId);
} else {
(*this)[pid].SetRightChild(kInvalidNodeId);
}
deleted_nodes_.push_back(nid);
nodes_[nid].MarkDelete();
++param.num_deleted;
}
};
inline void RegTree::FVec::Init(size_t size) {
Entry e; e.flag = -1;
data_.resize(size);
std::fill(data_.begin(), data_.end(), e);
has_missing_ = true;
}
inline void RegTree::FVec::Fill(const SparsePage::Inst& inst) {
size_t feature_count = 0;
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].fvalue = entry.fvalue;
++feature_count;
}
has_missing_ = data_.size() != feature_count;
}
inline void RegTree::FVec::Drop(const SparsePage::Inst& inst) {
for (auto const& entry : inst) {
if (entry.index >= data_.size()) {
continue;
}
data_[entry.index].flag = -1;
}
has_missing_ = true;
}
inline size_t RegTree::FVec::Size() const {
return data_.size();
}
inline bst_float RegTree::FVec::GetFvalue(size_t i) const {
return data_[i].fvalue;
}
inline bool RegTree::FVec::IsMissing(size_t i) const {
return data_[i].flag == -1;
}
inline bool RegTree::FVec::HasMissing() const {
return has_missing_;
}
} // namespace xgboost
#endif // XGBOOST_TREE_MODEL_H_