forked from dmlc/xgboost
/
iterative_dmatrix.h
141 lines (122 loc) · 5.27 KB
/
iterative_dmatrix.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
/*!
* Copyright 2020-2022 by Contributors
* \file iterative_dmatrix.h
*/
#ifndef XGBOOST_DATA_ITERATIVE_DMATRIX_H_
#define XGBOOST_DATA_ITERATIVE_DMATRIX_H_
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"
#include "xgboost/base.h"
#include "xgboost/c_api.h"
#include "xgboost/data.h"
namespace xgboost {
namespace common {
class HistogramCuts;
}
namespace data {
/**
* \brief DMatrix type for `QuantileDMatrix`, the naming `IterativeDMatix` is due to its
* construction process.
*
* `QuantileDMatrix` is an intermediate storage for quantilization results including
* quantile cuts and histogram index. Quantilization is designed to be performed on stream
* of data (or batches of it). As a result, the `QuantileDMatrix` is also designed to work
* with batches of data. During initializaion, it walks through the data multiple times
* iteratively in order to perform quantilization. This design helps us reduce memory
* usage significantly by avoiding data concatenation along with removing the CSR matrix
* `SparsePage`. However, it has its limitation (can be fixed if needed):
*
* - It's only supported by hist tree method (both CPU and GPU) since approx requires a
* re-calculation of quantiles for each iteration. We can fix this by retaining a
* reference to the callback if there are feature requests.
*
* - The CPU format and the GPU format are different, the former uses a CSR + CSC for
* histogram index while the latter uses only Ellpack.
*/
class IterativeDMatrix : public DMatrix {
MetaInfo info_;
Context ctx_;
BatchParam batch_param_;
std::shared_ptr<EllpackPage> ellpack_;
std::shared_ptr<GHistIndexMatrix> ghist_;
DMatrixHandle proxy_;
DataIterResetCallback *reset_;
XGDMatrixCallbackNext *next_;
void CheckParam(BatchParam const ¶m) {
// FIXME(Jiamingy): https://github.com/dmlc/xgboost/issues/7976
if (param.max_bin != batch_param_.max_bin && param.max_bin != 0) {
LOG(WARNING) << "Inconsistent max_bin between Quantile DMatrix and Booster:" << param.max_bin
<< " vs. " << batch_param_.max_bin;
}
CHECK(!param.regen && param.hess.empty())
<< "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
}
template <typename Page>
static auto InvalidTreeMethod() {
LOG(FATAL) << "Only `hist` and `gpu_hist` tree method can use `QuantileDMatrix`.";
return BatchSet<Page>(BatchIterator<Page>(nullptr));
}
void InitFromCUDA(DataIterHandle iter, float missing, std::shared_ptr<DMatrix> ref);
void InitFromCPU(DataIterHandle iter_handle, float missing, std::shared_ptr<DMatrix> ref);
public:
explicit IterativeDMatrix(DataIterHandle iter_handle, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
bst_bin_t max_bin);
~IterativeDMatrix() override = default;
bool EllpackExists() const override { return static_cast<bool>(ellpack_); }
bool GHistIndexExists() const override { return static_cast<bool>(ghist_); }
bool SparsePageExists() const override { return false; }
DMatrix *Slice(common::Span<int32_t const>) override {
LOG(FATAL) << "Slicing DMatrix is not supported for Quantile DMatrix.";
return nullptr;
}
BatchSet<SparsePage> GetRowBatches() override {
LOG(FATAL) << "Not implemented.";
return BatchSet<SparsePage>(BatchIterator<SparsePage>(nullptr));
}
BatchSet<CSCPage> GetColumnBatches() override { return InvalidTreeMethod<CSCPage>(); }
BatchSet<SortedCSCPage> GetSortedColumnBatches() override {
return InvalidTreeMethod<SortedCSCPage>();
}
BatchSet<GHistIndexMatrix> GetGradientIndex(BatchParam const ¶m) override;
BatchSet<EllpackPage> GetEllpackBatches(const BatchParam ¶m) override;
bool SingleColBlock() const override { return true; }
MetaInfo &Info() override { return info_; }
MetaInfo const &Info() const override { return info_; }
Context const *Ctx() const override { return &ctx_; }
};
/**
* \brief Get quantile cuts from reference Quantile DMatrix.
*/
void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
common::HistogramCuts *p_cuts);
/**
* \brief Get quantile cuts from ellpack page.
*/
void GetCutsFromEllpack(EllpackPage const &page, common::HistogramCuts *cuts);
#if !defined(XGBOOST_USE_CUDA)
inline void IterativeDMatrix::InitFromCUDA(DataIterHandle iter, float missing,
std::shared_ptr<DMatrix> ref) {
// silent the warning about unused variables.
(void)(proxy_);
(void)(reset_);
(void)(next_);
common::AssertGPUSupport();
}
inline BatchSet<EllpackPage> IterativeDMatrix::GetEllpackBatches(const BatchParam ¶m) {
common::AssertGPUSupport();
auto begin_iter = BatchIterator<EllpackPage>(new SimpleBatchIteratorImpl<EllpackPage>(ellpack_));
return BatchSet<EllpackPage>(BatchIterator<EllpackPage>(begin_iter));
}
inline void GetCutsFromEllpack(EllpackPage const &, common::HistogramCuts *) {
common::AssertGPUSupport();
}
#endif // !defined(XGBOOST_USE_CUDA)
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_ITERATIVE_DMATRIX_H_