-
-
Notifications
You must be signed in to change notification settings - Fork 8.7k
/
gradient_index_page_source.h
43 lines (38 loc) · 1.42 KB
/
gradient_index_page_source.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
/*!
* Copyright 2021 by XGBoost Contributors
*/
#ifndef XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
#define XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_
#include <cmath>
#include <cstdint>
#include <memory>
#include <utility>

#include "gradient_index.h"
#include "sparse_page_source.h"
namespace xgboost {
namespace data {
/**
 * \brief Page source specialised on GHistIndexMatrix (via PageSourceIncMixIn).
 *
 * The constructor captures everything the out-of-line Fetch() may need:
 * - histogram cuts,
 * - the density flag,
 * - the per-feature bin limit,
 * - the feature types,
 * - the sparse threshold.
 *
 * It also stores the upstream SparsePageSource in the base-class member source_.
 * NOTE(review): presumably each GHistIndexMatrix page is derived from the corresponding page of
 * that upstream source. Confirm in the definition of Fetch().
 */
class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
  // Histogram cut points; owned by this object (moved in from the constructor argument).
  common::HistogramCuts cuts_;
  bool is_dense_;
  int32_t max_bin_per_feat_;
  // Stored view of the caller's feature types.
  // NOTE(review): assumes common::Span is non-owning (like std::span). If so, the caller must
  // keep the underlying storage alive for as long as this source is used. Confirm at call sites.
  common::Span<FeatureType const> feature_types_;
  // May be NaN; the NaN-ness is also forwarded to the base class (see constructor).
  float sparse_thresh_;

 public:
  /**
   * \brief Builds the source and immediately invokes Fetch() once.
   *
   * \param missing, nthreads, n_features, n_batches, cache
   *        Forwarded unchanged to PageSourceIncMixIn.
   * \param param            Not read by this constructor; kept for interface compatibility.
   * \param cuts             Histogram cuts; taken by value and moved into cuts_.
   * \param is_dense         Stored in is_dense_.
   * \param max_bin_per_feat Stored in max_bin_per_feat_.
   * \param feature_types    Stored as a view in feature_types_ (see lifetime note on the member).
   * \param sparse_thresh    Stored in sparse_thresh_. `!std::isnan(sparse_thresh)` is passed to
   *                         the base class as its last constructor argument.
   *                         NOTE(review): confirm what that flag controls in PageSourceIncMixIn.
   * \param source           Upstream sparse page source. Ownership is shared with the base-class
   *                         member source_.
   */
  GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
                          std::shared_ptr<Cache> cache, BatchParam param,
                          common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
                          common::Span<FeatureType const> feature_types, float sparse_thresh,
                          std::shared_ptr<SparsePageSource> source)
      : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
                           !std::isnan(sparse_thresh)),
        cuts_{std::move(cuts)},
        is_dense_{is_dense},
        max_bin_per_feat_{max_bin_per_feat},
        feature_types_{feature_types},
        sparse_thresh_{sparse_thresh} {
    // `source` is a by-value sink parameter. Moving it avoids a redundant atomic
    // refcount increment/decrement pair that a copy would cost.
    this->source_ = std::move(source);
    // Called from the constructor. The dynamic type during construction is this class, so the
    // call resolves to GradientIndexPageSource::Fetch.
    this->Fetch();
  }

  // Defined out of line (not in this header). Marked final, so no subclass can override it.
  void Fetch() final;
};
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_