forked from dmlc/xgboost
/
adaptive.cc
127 lines (112 loc) · 3.91 KB
/
adaptive.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
/*!
* Copyright 2022 by XGBoost Contributors
*/
#include "adaptive.h"

#include <algorithm>  // std::copy, std::find_if, std::transform, std::unique
#include <cmath>      // std::isnan
#include <iterator>   // std::distance
#include <limits>
#include <vector>

#include "../common/common.h"
#include "../common/numeric.h"
#include "../common/stats.h"
#include "../common/threading_utils.h"
#include "xgboost/tree_model.h"
namespace xgboost {
namespace obj {
namespace detail {
/**
 * \brief Group row indices by the tree node recorded for them in `position` (host version).
 *
 * Rows are ordered by ArgSort of `position`. Rows whose position is negative are excluded
 * from the leaf segments. The code assumes ArgSort sorts ascending, so these rows come first.
 *
 * \param tree      Tree being encoded; walked once to collect the ids of all its leaves.
 * \param position  Per-row node index. Negative entries are skipped.
 * \param p_nptr    [out] Segment pointers into `*p_ridx`: the rows of node `nidx[k]` are
 *                  `ridx[nptr[k] .. nptr[k + 1])`. Cleared (left empty) when no row has a
 *                  non-negative position.
 * \param p_nidx    [out] One node id per segment. When no row has a non-negative position,
 *                  this is instead every leaf of `tree` (in WalkTree order).
 * \param p_ridx    [out] Row indices permuted by ArgSort of `position`.
 */
void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
                        std::vector<size_t>* p_nptr, std::vector<bst_node_t>* p_nidx,
                        std::vector<size_t>* p_ridx) {
  auto& nptr = *p_nptr;
  auto& nidx = *p_nidx;
  auto& ridx = *p_ridx;

  ridx = common::ArgSort<size_t>(position);
  // Apply the permutation so that rows sharing a node id become contiguous.
  std::vector<bst_node_t> sorted_pos(position.size());
  for (size_t i = 0; i < position.size(); ++i) {
    sorted_pos[i] = position[ridx[i]];
  }

  // Skip the leading rows with a negative position; find the first row with a valid one.
  auto const first_valid = std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
                                        [](bst_node_t pos) { return pos >= 0; });
  size_t begin_pos = std::distance(sorted_pos.cbegin(), first_valid);
  CHECK_LE(begin_pos, sorted_pos.size());

  // Collect every leaf of the tree so that leaves which received no rows can be detected.
  std::vector<bst_node_t> leaf;
  tree.WalkTree([&](bst_node_t node) {
    if (tree[node].IsLeaf()) {
      leaf.push_back(node);
    }
    return true;
  });

  if (begin_pos == sorted_pos.size()) {
    // No row has a valid position. Callers detect this case through an empty `nptr`, so
    // clear it explicitly instead of relying on the caller passing an empty vector.
    nptr.clear();
    nidx = leaf;
    return;
  }

  auto beg_it = sorted_pos.begin() + begin_pos;
  common::RunLengthEncode(beg_it, sorted_pos.end(), &nptr);
  CHECK_GT(nptr.size(), 0);
  // The run offsets are relative to `beg_it`; shift them past the skipped rows so they
  // index directly into `ridx`.
  std::transform(nptr.begin(), nptr.end(), nptr.begin(),
                 [begin_pos](size_t ptr) { return ptr + begin_pos; });

  size_t n_leaf = nptr.size() - 1;
  // Compact each run down to a single node id; there must be exactly one id per segment.
  auto const unique_end = std::unique(beg_it, sorted_pos.end());
  auto const n_unique = static_cast<size_t>(std::distance(beg_it, unique_end));
  CHECK_EQ(n_unique, n_leaf);
  nidx.assign(beg_it, unique_end);

  if (n_leaf != leaf.size()) {
    // Some leaves of the tree received no rows. FillMissingLeaf reconciles `nidx`/`nptr`
    // with the full leaf list (presumably adding empty segments — see its definition).
    FillMissingLeaf(leaf, &nidx, &nptr);
  }
}
/**
 * \brief Recompute leaf values as the `alpha`-quantile of the residuals of each leaf's rows.
 *
 * The residual of a row is `label - prediction`, using only the first label column.
 * Multi-target is not supported yet. The per-leaf quantiles are handed to UpdateLeafValues.
 *
 * \param ctx       Supplies the thread count for the per-leaf parallel loop.
 * \param position  Per-row node index, encoded by EncodeTreeLeafHost (negative entries skipped).
 * \param info      Provides labels, row count and (optional) sample weights.
 * \param predt     Current predictions, indexed by row.
 * \param alpha     Quantile level passed to common::Quantile / common::WeightedQuantile.
 * \param p_tree    [in,out] Tree whose leaves are updated.
 */
void UpdateTreeLeafHost(Context const* ctx, std::vector<bst_node_t> const& position,
                        MetaInfo const& info, HostDeviceVector<float> const& predt, float alpha,
                        RegTree* p_tree) {
  auto& tree = *p_tree;

  std::vector<bst_node_t> nidx;
  std::vector<size_t> nptr;
  std::vector<size_t> ridx;
  EncodeTreeLeafHost(*p_tree, position, &nptr, &nidx, &ridx);
  size_t n_leaf = nidx.size();
  if (nptr.empty()) {
    // No row has a valid position: pass an empty quantile vector and let
    // UpdateLeafValues decide how to handle the leaves.
    std::vector<float> quantiles;
    UpdateLeafValues(&quantiles, nidx, p_tree);
    return;
  }

  CHECK(!position.empty());
  std::vector<float> quantiles(n_leaf, 0);

  auto const& h_node_idx = nidx;
  auto const& h_node_ptr = nptr;
  CHECK_LE(h_node_ptr.back(), info.num_row_);

  // Build the host views once, outside the parallel region. They do not depend on the leaf,
  // and accessing a HostDeviceVector on the host may lazily sync device data. That sync
  // should not be triggered concurrently from every worker thread.
  // Multi-target not yet supported: only label column 0 is used.
  auto h_labels = info.labels.HostView().Slice(linalg::All(), 0);
  auto const& h_predt = predt.ConstHostVector();
  auto h_weights = linalg::MakeVec(&info.weights_);
  bool const has_weights = !info.weights_.Empty();

  // Loop over each leaf.
  common::ParallelFor(quantiles.size(), ctx->Threads(), [&](size_t k) {
    auto leaf_nidx = h_node_idx[k];
    CHECK(tree[leaf_nidx].IsLeaf());
    CHECK_LT(k + 1, h_node_ptr.size());
    size_t n = h_node_ptr[k + 1] - h_node_ptr[k];
    auto h_row_set = common::Span<size_t const>{ridx}.subspan(h_node_ptr[k], n);

    // Residual of the i-th row assigned to this leaf.
    auto iter = common::MakeIndexTransformIter([&](size_t i) -> float {
      auto row_idx = h_row_set[i];
      return h_labels(row_idx) - h_predt[row_idx];
    });
    // Weight of the i-th row assigned to this leaf (only read when weights are present).
    auto w_it = common::MakeIndexTransformIter([&](size_t i) -> float {
      auto row_idx = h_row_set[i];
      return h_weights(row_idx);
    });

    float q{0};
    if (has_weights) {
      q = common::WeightedQuantile(alpha, iter, iter + h_row_set.size(), w_it);
    } else {
      q = common::Quantile(alpha, iter, iter + h_row_set.size());
    }
    // A NaN quantile is only accepted for a leaf that received no rows.
    if (std::isnan(q)) {
      CHECK(h_row_set.empty());
    }
    quantiles.at(k) = q;
  });

  UpdateLeafValues(&quantiles, nidx, p_tree);
}
} // namespace detail
} // namespace obj
} // namespace xgboost