/
sparse_page_dmatrix.cc
147 lines (130 loc) · 5.36 KB
/
sparse_page_dmatrix.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
/*!
* Copyright 2014-2021 by Contributors
* \file sparse_page_dmatrix.cc
* \brief The external memory version of Page Iterator.
* \author Tianqi Chen
*/
#include "./sparse_page_dmatrix.h"
#include "./simple_batch_iterator.h"
#include "gradient_index.h"
namespace xgboost {
namespace data {
/*! \brief Mutable access to the meta information gathered in the constructor. */
MetaInfo &SparsePageDMatrix::Info() {
  return this->info_;
}
/*! \brief Read-only access to the meta information gathered in the constructor. */
MetaInfo const &SparsePageDMatrix::Info() const {
  return this->info_;
}
/*!
 * \brief Build an external-memory DMatrix driven by a user-supplied data iterator.
 *
 * The whole input is consumed once during construction. Each row page produced by the
 * page source is paired with the batch currently held by the proxy DMatrix. Meta info,
 * row count, max column count, non-zero count and batch count are therefore all
 * collected in a single pass.
 *
 * \param iter_handle  Opaque handle of the user data iterator.
 * \param proxy_handle Handle of the proxy DMatrix that is iterated alongside the pages.
 * \param reset        Reset callback of the user iterator.
 * \param next         Next callback of the user iterator.
 * \param missing      Missing-value marker forwarded to the page sources.
 * \param nthreads     Thread count forwarded to the page sources.
 * \param cache_prefix Cache file prefix. Becomes "DMatrix" when empty, and gets
 *                     "-r<rank>" appended when running distributed.
 */
SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy_handle,
                                     DataIterResetCallback *reset,
                                     XGDMatrixCallbackNext *next, float missing,
                                     int32_t nthreads, std::string cache_prefix)
    : proxy_{proxy_handle}, iter_{iter_handle}, reset_{reset}, next_{next}, missing_{missing},
      nthreads_{nthreads}, cache_prefix_{std::move(cache_prefix)} {
  if (cache_prefix_.empty()) {
    cache_prefix_ = "DMatrix";
  }
  // Give every worker its own cache files.
  if (rabit::IsDistributed()) {
    cache_prefix_ += "-r" + std::to_string(rabit::GetRank());
  }

  DMatrixProxy *const proxy = MakeProxy(proxy_);
  DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> user_iter{iter_, reset_, next_};

  // Shape queries against whatever batch the proxy currently holds.
  auto const batch_cols = [proxy]() {
    return HostAdapterDispatch(proxy, [](auto const &adapter) { return adapter.NumCols(); });
  };
  auto const batch_rows = [proxy]() {
    return HostAdapterDispatch(proxy, [](auto const &adapter) { return adapter.NumRows(); });
  };

  uint32_t batch_count = 0;
  size_t max_columns = 0;
  size_t row_count = 0;
  size_t nonzero_count = 0;

  // The proxy is advanced together with the sparse page source. After fetching a page,
  // the proxy therefore describes the same batch, so everything fits in one pass.
  for (auto const &page : this->GetRowBatchesImpl()) {
    this->info_.Extend(std::move(proxy->Info()), false, false);
    max_columns = std::max(max_columns, batch_cols());
    row_count += batch_rows();
    nonzero_count += page.data.Size();
    ++batch_count;
  }

  // Rewind the user iterator once the counting pass is done.
  user_iter.Reset();

  this->info_.num_row_ = row_count;
  this->info_.num_col_ = max_columns;
  this->info_.num_nonzero_ = nonzero_count;
  this->n_batches_ = batch_count;

  // Agree on the widest feature count across all workers; zero columns is an error.
  rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
  CHECK_NE(info_.num_col_, 0);
}
void SparsePageDMatrix::InitializeSparsePage() {
auto id = MakeCache(this, ".row.page", cache_prefix_, &cache_info_);
// Don't use proxy DMatrix once this is already initialized, this allows users to
// release the iterator and data.
if (cache_info_.at(id)->written) {
CHECK(sparse_page_source_);
sparse_page_source_->Reset();
return;
}
auto iter = DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{
iter_, reset_, next_};
DMatrixProxy *proxy = MakeProxy(proxy_);
sparse_page_source_.reset(); // clear before creating new one to prevent conflicts.
sparse_page_source_ = std::make_shared<SparsePageSource>(
iter, proxy, this->missing_, this->nthreads_, this->info_.num_col_,
this->n_batches_, cache_info_.at(id));
}
/*! \brief (Re)initialize the row-page source and expose it as a batch set. */
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatchesImpl() {
  InitializeSparsePage();
  BatchIterator<SparsePage> first(sparse_page_source_);
  return BatchSet<SparsePage>(first);
}
/*! \brief Row-major page batches; forwards to GetRowBatchesImpl(). */
BatchSet<SparsePage> SparsePageDMatrix::GetRowBatches() {
  return GetRowBatchesImpl();
}
/*!
 * \brief Column-major (CSC) page batches, cached under the ".col.page" suffix.
 *
 * The CSC source is built lazily on first use from the row-page source. Later calls
 * only rewind it.
 */
BatchSet<CSCPage> SparsePageDMatrix::GetColumnBatches() {
  auto const cache_id = MakeCache(this, ".col.page", cache_prefix_, &cache_info_);
  CHECK_NE(Info().num_col_, 0);
  InitializeSparsePage();
  if (column_source_) {
    column_source_->Reset();
  } else {
    column_source_ = std::make_shared<CSCPageSource>(
        missing_, nthreads_, Info().num_col_, n_batches_, cache_info_.at(cache_id),
        sparse_page_source_);
  }
  BatchIterator<CSCPage> first(column_source_);
  return BatchSet<CSCPage>(first);
}
/*!
 * \brief Sorted column-major page batches, cached under the ".sorted.col.page" suffix.
 *
 * The sorted CSC source is built lazily on first use from the row-page source. Later
 * calls only rewind it.
 */
BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {
  auto const cache_id = MakeCache(this, ".sorted.col.page", cache_prefix_, &cache_info_);
  CHECK_NE(Info().num_col_, 0);
  InitializeSparsePage();
  if (sorted_column_source_) {
    sorted_column_source_->Reset();
  } else {
    sorted_column_source_ = std::make_shared<SortedCSCPageSource>(
        missing_, nthreads_, Info().num_col_, n_batches_, cache_info_.at(cache_id),
        sparse_page_source_);
  }
  BatchIterator<SortedCSCPage> first(sorted_column_source_);
  return BatchSet<SortedCSCPage>(first);
}
/*!
 * \brief Gradient-histogram index batches for the given parameters.
 *
 * External memory is not supported here. A single GHistIndexMatrix is built over the
 * whole matrix and served through SimpleBatchIteratorImpl. It is rebuilt only when none
 * exists yet, or when a non-default param differs from the one used last time.
 *
 * \param param Batch parameters; param.max_bin must be at least 2.
 */
BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam& param) {
  CHECK_GE(param.max_bin, 2);
  bool const needs_rebuild =
      !ghist_index_source_ || (param != batch_param_ && param != BatchParam{});
  if (needs_rebuild) {
    InitializeSparsePage();
    ghist_index_source_.reset(new GHistIndexMatrix{this, param.max_bin});
    batch_param_ = param;
  }
  // Re-initialize the row-page source on both the rebuild and the cached path.
  InitializeSparsePage();
  BatchIterator<GHistIndexMatrix> first(
      new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_source_));
  return BatchSet<GHistIndexMatrix>(first);
}
#if !defined(XGBOOST_USE_CUDA)
/*!
 * \brief Ellpack batches stub for builds without XGBOOST_USE_CUDA.
 *
 * Calls common::AssertGPUSupport() first. In this non-CUDA build that is presumably
 * expected to fail, so the return statement mainly keeps the function well-formed.
 */
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
  common::AssertGPUSupport();
  BatchIterator<EllpackPage> first(ellpack_page_source_);
  return BatchSet<EllpackPage>(first);
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace data
} // namespace xgboost