External memory support for hist tree method.
Rewrite approx.

Save cuts.

Prototype on fetching.

Copy the code.

Simple test.

Add gpair to batch parameter.

Add hessian to batch parameter.

Move.

Pass hessian into sketching.

Extract a push page function.

Make private.

Lint.

Revert debug.

Simple DMatrix.

Regenerate the index.

ama.

Clang tidy.

Retain page.

Fix.

Lint.

Tidy.

Integer backed enum.

Convert to uint32_t.

Prototype for saving gidx.

Save cuts.

Prototype on fetching.

Copy the code.

Simple test.

Add gpair to batch parameter.

Add hessian to batch parameter.

Move.

Pass hessian into sketching.

Extract a push page function.

Make private.

Lint.

Revert debug.

Simple DMatrix.

Initial port.

Pass in hessian.

Init column sampler.

Unused code.

Use ctx.

Merge sampling.

Use ctx in partition.

Fix init root.

Force regenerate the sketch.

Create a ctx.

Get it to compile.

Don't use const method.

Use page id.

Pass in base row id.

Pass the cut instead.

Small fixes.

Debug.

Fix bin size.

Debug.

Fixes.

Debug.

Fix empty partition.

Remove comment.

Lint.

Fix tests compilation.

Remove check.

Merge some fixes.

fix.

Fix fetching.

lint.

Extract expand entry.

Lint.

Fix unittests.

Fix windows build.

Fix comparison.

Make const.

Note.

const.

Fix reduce hist.

Fix sparse data.

Avoid implicit conversion.

private.

mem leak.

Remove skip initialization.

Use maximum space.

demo.

lint.

File link tags.

ama.

Fix redefinition.

Fix ranking.

use npy.

Comment.

Tune it down.

Specify the tree method.

Get rid of the duplicated partitioner.

Allocate task.

Tests.

make batches.

Log.

Remove span.

Revert "make batches."

This reverts commit 33f7072.

small cleanup.

Lint.

Revert demo.

Better make batches.

Demo.

Test for grow policy.

Test feature weights.

small cleanup.

Remove iterator in evaluation.

Fix dask test.

Pass n_threads.

Start implementation for categorical data.

Fix.

Add apply split.

Enumerate splits.

Enable sklearn.

Works.

d_step.

update.

Pass feature types into index.

Search cut.

Add test.

As cat.

Fix cut.

Extract some tests.

Fix.

Interesting case.

Add Python tests.

Cleanup.

Revert "Interesting case."

This reverts commit 6bbaac2.

Bin.

Fix.

Dispatch.

Remove subtraction trick.

Lint

Use multiple buffers.

Revert "Use multiple buffers."

This reverts commit 2849f57.

Test for external memory.

Format.

Partition based categorical split.

Remove debug code.

Fix.

Lint.

Fix test.

Fix demo.

Fix.

Add test.

Remove use of omp func.

name.

Fix.

test.

Make LCG impl compliant with std.

Fix test.

Constexpr.

Use unsigned type.

osx

More test.

Rebase error.

Rebase error.

Rebase error.

Reverse unused changes.

Config.

Remove weird set thread.

External memory test.

Revert changes.

Cleanup.

wording.

Fix doc.

Test monotone constraint.

Extract test for gamma.

typo.

Safe guard.

Cleanup && comments.

Update Python documents.

Add push col page.

hack.

Port the sketch.

Opt search bin.

Cleanup.

Reduce the gap.

Fix sum hessian.

Start cleaning up.

Duplicated.

Cleanup.

lint.

Test.

Port the changes.

test.

Port the changes.

Fixes && cleanup.

Decide whether the sorted sketch should be used.

tests.

Use regen.

Lint.

Revert.

init.

empty dataset.

Handle empty dataset directly in quantile.

empty.

Update tests.

Implement external memory support for hist with dense data.

Extract row partitioner.

Work on et.

Remove test.

base rowid.

Fix.

Fix reduce grad.

Generate column matrix.

Port the changes from updated driver.

test sample.

Cleanup.

Fixes.

fix clang.

debug.

Fix test.

Revert changes.

Lint.

Initial commit for sparse page.

fixes.

fix tests.

Remove column matrix.

Make sure ref is used.

Remove any_missing & gmat.

Remove part builder.

Fix approx test.

Remove thread test.

Fix sketch tests.

Avoid a loop.

fix evaluation tests.

fix ghist index test.

fix approx test.

Fix histogram test.

Note.

start working on io.

IO.

Fix empty.

Print time message.

Remove the need to load sparse page.

benchmark the external memory. [don't upload]

Revert "benchmark the external memory. [don't upload]"

This reverts commit 7fe631cd359cf6eb256b3aa08a39a2917203e045.

log info.

Fix rebase.

fix rebase.

fix.

Cleanup & more tests.

lint.

fixes

ellpack.

ellpack.

spec.

Add tests.

type.

apple.

s390x

s390x.

fix rebase.

remove renamed file.
trivialfis committed Feb 15, 2022
1 parent 93eebe8 commit bb9ee02
Showing 26 changed files with 696 additions and 847 deletions.
2 changes: 1 addition & 1 deletion demo/guide-python/external_memory.py
@@ -79,7 +79,7 @@ def main(tmpdir: str) -> xgboost.Booster:

# Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some
# caveats. This is still an experimental feature.
- booster = xgboost.train({"tree_method": "approx"}, Xy, evals=[(Xy, "Train")])
+ booster = xgboost.train({"tree_method": "hist", "max_depth": 2}, Xy, evals=[(Xy, "Train")], num_boost_round=1)
return booster


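For context on the demo change above: external memory in XGBoost is driven by the DataIter interface, which streams batches from disk while XGBoost builds its page cache. Below is a minimal Python sketch of such an iterator, in the spirit of this demo; the file layout and names are illustrative assumptions, not part of this commit.

import os
from typing import Callable, List, Tuple

import numpy as np
import xgboost


class Iterator(xgboost.DataIter):
    """Yield (X, y) batches stored on disk as .npy files."""

    def __init__(self, file_paths: List[Tuple[str, str]]) -> None:
        self._file_paths = file_paths
        self._it = 0
        # cache_prefix tells XGBoost where to write its page cache.
        super().__init__(cache_prefix=os.path.join(".", "cache"))

    def next(self, input_data: Callable) -> int:
        # Feed one batch to XGBoost; return 0 once the stream is exhausted.
        if self._it == len(self._file_paths):
            return 0
        X_path, y_path = self._file_paths[self._it]
        input_data(data=np.load(X_path), label=np.load(y_path))
        self._it += 1
        return 1

    def reset(self) -> None:
        # Rewind to the first batch; called at the start of each iteration.
        self._it = 0


# Assuming `files` is a list of (X_path, y_path) pairs written with np.save:
#   Xy = xgboost.DMatrix(Iterator(files))
#   booster = xgboost.train({"tree_method": "hist"}, Xy, evals=[(Xy, "Train")])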
2 changes: 1 addition & 1 deletion demo/guide-python/feature_weights.py
@@ -27,7 +27,7 @@ def main(args):
dtrain.set_info(feature_weights=fw)

bst = xgboost.train({'tree_method': 'hist',
- 'colsample_bynode': 0.5},
+ 'colsample_bynode': 0.2},
dtrain, num_boost_round=10,
evals=[(dtrain, 'd')])
feature_map = bst.get_fscore()
1 change: 1 addition & 0 deletions include/xgboost/data.h
@@ -240,6 +240,7 @@ struct BatchParam {
if (hess.empty() && other.hess.empty()) {
return gpu_id != other.gpu_id || max_bin != other.max_bin;
}
+ // fixme: sparse_thresh
return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data();
}
bool operator==(BatchParam const& other) const {
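The operator above decides whether cached gradient-index pages can be reused or must be regenerated; as the added fixme notes, sparse_thresh is not yet part of the comparison. A rough Python model of the logic, using plain dicts purely for illustration:

def batch_param_differs(a: dict, b: dict) -> bool:
    # Mirrors BatchParam::operator!=: a change in gpu_id or max_bin always
    # forces regeneration; the hessian is compared by buffer identity, and
    # only when at least one side actually carries a hessian.
    if not a["hess"] and not b["hess"]:
        return a["gpu_id"] != b["gpu_id"] or a["max_bin"] != b["max_bin"]
    return (a["gpu_id"] != b["gpu_id"]
            or a["max_bin"] != b["max_bin"]
            or a["hess"] is not b["hess"])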
293 changes: 171 additions & 122 deletions src/common/column_matrix.h

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions src/common/hist_util.cc
@@ -165,6 +165,7 @@ void BuildHistKernel(const std::vector<GradientPair> &gpair,
any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
const size_t icol_end =
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
+ CHECK_LE(icol_end, gmat.index.Size());

const size_t row_size = icol_end - icol_start;
const size_t idx_gh = two * rid[i];
40 changes: 21 additions & 19 deletions src/common/partition_builder.h
@@ -1,5 +1,5 @@
/*!
- * Copyright 2021 by Contributors
+ * Copyright 2021-2022 by Contributors
* \file row_set.h
* \brief Quick Utility to compute subset of rows
* \author Philip Cho, Tianqi Chen
@@ -48,16 +48,20 @@ class PartitionBuilder {
// Analog of std::stable_partition, but in no-inplace manner
template <bool default_left, bool any_missing, typename ColumnType>
inline std::pair<size_t, size_t> PartitionKernel(const ColumnType& column,
- common::Span<const size_t> rid_span, const int32_t split_cond,
- common::Span<size_t> left_part, common::Span<size_t> right_part) {
+ common::Span<const size_t> rid_span,
+ const int32_t split_cond,
+ common::Span<size_t> left_part,
+ common::Span<size_t> right_part,
+ size_t base_rowid) {
size_t* p_left_part = left_part.data();
size_t* p_right_part = right_part.data();
size_t nleft_elems = 0;
size_t nright_elems = 0;
- auto state = column.GetInitialState(rid_span.front());
+ auto state = column.GetInitialState(rid_span.front() - base_rowid);

for (auto rid : rid_span) {
- const int32_t bin_id = column.GetBinIdx(rid, &state);
+ CHECK_GE(rid, base_rowid);
+ const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state);
if (any_missing && bin_id == ColumnType::kMissingId) {
if (default_left) {
p_left_part[nleft_elems++] = rid;
@@ -97,13 +101,11 @@

template <typename BinIdxType, bool any_missing>
void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
- const int32_t split_cond,
- const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
+ const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree,
+ const size_t* rid, size_t base_rowid) {
common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
- common::Span<size_t> left = GetLeftBuffer(node_in_set,
- range.begin(), range.end());
- common::Span<size_t> right = GetRightBuffer(node_in_set,
- range.begin(), range.end());
+ common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
+ common::Span<size_t> right = GetRightBuffer(node_in_set, range.begin(), range.end());
const bst_uint fid = tree[nid].SplitIndex();
const bool default_left = tree[nid].DefaultLeft();
const auto column_ptr = column_matrix.GetColumn<BinIdxType, any_missing>(fid);
@@ -114,22 +116,22 @@
const common::DenseColumn<BinIdxType, any_missing>& column =
static_cast<const common::DenseColumn<BinIdxType, any_missing>& >(*(column_ptr.get()));
if (default_left) {
- child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
- split_cond, left, right);
+ child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, split_cond, left,
+ right, base_rowid);
} else {
- child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
- split_cond, left, right);
+ child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, split_cond, left,
+ right, base_rowid);
}
} else {
CHECK_EQ(any_missing, true);
const common::SparseColumn<BinIdxType>& column
= static_cast<const common::SparseColumn<BinIdxType>& >(*(column_ptr.get()));
if (default_left) {
- child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span,
- split_cond, left, right);
+ child_nodes_sizes = PartitionKernel<true, any_missing>(column, rid_span, split_cond, left,
+ right, base_rowid);
} else {
- child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span,
- split_cond, left, right);
+ child_nodes_sizes = PartitionKernel<false, any_missing>(column, rid_span, split_cond, left,
+ right, base_rowid);
}
}

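The essential change in this file is threading base_rowid through the partition kernel: with external memory, rows keep their global ids while each page's column matrix is indexed from the page's first row, so lookups subtract the offset. A pure-Python sketch of that logic, with bin_of standing in for the page-local column and the split semantics simplified:

def partition_kernel(bin_of, rid_span, split_cond, base_rowid,
                     default_left=True, missing_id=-1):
    # Rows carry global ids; the page-local column is indexed with
    # rid - base_rowid (mirrors CHECK_GE(rid, base_rowid) above).
    left, right = [], []
    for rid in rid_span:
        assert rid >= base_rowid
        bin_idx = bin_of[rid - base_rowid]
        if bin_idx == missing_id:
            (left if default_left else right).append(rid)
        elif bin_idx <= split_cond:
            left.append(rid)
        else:
            right.append(rid)
    return left, right


# A page whose first row has global id 100:
#   partition_kernel([3, 7, 1], [100, 101, 102], split_cond=3, base_rowid=100)
#   -> ([100, 102], [101])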
9 changes: 8 additions & 1 deletion src/data/ellpack_page_source.cu
@@ -1,5 +1,5 @@
/*!
- * Copyright 2019-2021 XGBoost contributors
+ * Copyright 2019-2022 XGBoost contributors
*/
#include <memory>
#include <utility>
@@ -12,6 +12,13 @@ namespace data {
void EllpackPageSource::Fetch() {
dh::safe_cuda(cudaSetDevice(param_.gpu_id));
if (!this->ReadCache()) {
+ if (count_ != 0 && !sync_) {
+ // source is initialized to be the 0th page during construction, so when count_ is 0
+ // there's no need to increment the source.
+ ++(*source_);
+ }
+ // This is not read from cache so we still need it to be synced with sparse page source.
+ CHECK_EQ(count_, source_->Iter());
auto const &csr = source_->Page();
this->page_.reset(new EllpackPage{});
auto *impl = this->page_->Impl();
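The guard added here, and mirrored in GradientIndexPageSource::Fetch further down, keeps a derived page source aligned with the underlying sparse page source: when the mix-in is not advancing both in lockstep (the new sync flag passed to PageSourceIncMixIn), the derived source must step the upstream iterator itself on a cache miss. A rough Python model of that control flow; the class and function names are illustrative only:

class UpstreamSource:
    """Stand-in for the sparse page source; sits on page 0 after construction."""

    def __init__(self, pages):
        self._pages = pages
        self._iter = 0

    def advance(self):
        self._iter += 1

    def iter(self):
        return self._iter

    def page(self):
        return self._pages[self._iter]


def fetch(count, sync, upstream, read_cache, transform):
    # On a cache hit nothing needs to move. On a miss, a non-synced source
    # advances the upstream itself; count 0 is exempt because the upstream
    # already sits on page 0.
    cached = read_cache(count)
    if cached is not None:
        return cached
    if count != 0 and not sync:
        upstream.advance()           # mirrors ++(*source_)
    assert count == upstream.iter()  # mirrors CHECK_EQ(count_, source_->Iter())
    return transform(upstream.page())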
22 changes: 12 additions & 10 deletions src/data/ellpack_page_source.h
@@ -1,5 +1,5 @@
/*!
- * Copyright 2019-2021 by XGBoost Contributors
+ * Copyright 2019-2022 by XGBoost Contributors
*/

#ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_
@@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn<EllpackPage> {
std::unique_ptr<common::HistogramCuts> cuts_;

public:
- EllpackPageSource(
- float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
- std::shared_ptr<Cache> cache, BatchParam param,
- std::unique_ptr<common::HistogramCuts> cuts, bool is_dense,
- size_t row_stride, common::Span<FeatureType const> feature_types,
- std::shared_ptr<SparsePageSource> source)
- : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
- is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)},
- feature_types_{feature_types}, cuts_{std::move(cuts)} {
+ EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
+ std::shared_ptr<Cache> cache, BatchParam param,
+ std::unique_ptr<common::HistogramCuts> cuts, bool is_dense, size_t row_stride,
+ common::Span<FeatureType const> feature_types,
+ std::shared_ptr<SparsePageSource> source)
+ : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false),
+ is_dense_{is_dense},
+ row_stride_{row_stride},
+ param_{std::move(param)},
+ feature_types_{feature_types},
+ cuts_{std::move(cuts)} {
this->source_ = source;
this->Fetch();
}
28 changes: 27 additions & 1 deletion src/data/gradient_index.cc
@@ -147,7 +147,6 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh,
hit_count.resize(nbins, 0);
hit_count_tloc_.resize(n_threads * nbins, 0);

- this->p_fmat = p_fmat;
size_t new_size = 1;
for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
new_size += batch.Size();
@@ -167,6 +166,16 @@
prev_sum = row_ptr[rbegin + batch.Size()];
rbegin += batch.Size();
}
+ this->columns_ = std::make_unique<common::ColumnMatrix>();

+ // hessian is empty when hist tree method is used or when dataset is empty
+ if (hess.empty() && !std::isnan(sparse_thresh)) {
+ // hist
+ CHECK(!sorted_sketch);
+ for (auto const &page : p_fmat->GetBatches<SparsePage>()) {
+ this->columns_->Init(page, *this, sparse_thresh, n_threads);
+ }
+ }
}

void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
@@ -190,6 +199,10 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType co
size_t prev_sum = 0;

this->PushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads);
+ this->columns_ = std::make_unique<common::ColumnMatrix>();
+ if (!std::isnan(sparse_thresh)) {
+ this->columns_->Init(batch, *this, sparse_thresh, n_threads);
+ }
}

void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
@@ -206,4 +219,17 @@ void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
index.Resize((sizeof(uint32_t)) * n_index);
}
}

+ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
+ CHECK(columns_);
+ return *columns_;
+ }

+ bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
+ return this->columns_->Read(fi, this->cut.Ptrs().data());
+ }

+ size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const {
+ return this->columns_->Write(fo);
+ }
} // namespace xgboost
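Note the convention behind the guard above: approx regenerates the index from a hessian and passes NaN for sparse_thresh, whereas hist supplies a concrete threshold, so the column matrix is built only on the hist path. A small Python restatement of that predicate, for illustration only:

import math


def should_build_column_matrix(hess_empty: bool, sparse_thresh: float) -> bool:
    # hist: no hessian at sketching time and a real sparse_thresh.
    # approx: a hessian is supplied and sparse_thresh is NaN.
    return hess_empty and not math.isnan(sparse_thresh)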
7 changes: 5 additions & 2 deletions src/data/gradient_index.h
@@ -33,7 +33,6 @@ class GHistIndexMatrix {
std::vector<size_t> hit_count;
/*! \brief The corresponding cuts */
common::HistogramCuts cut;
- DMatrix* p_fmat;
/*! \brief max_bin for each feature. */
size_t max_num_bins;
/*! \brief base row index for current page (used by external memory) */
@@ -108,8 +107,12 @@
return row_ptr.empty() ? 0 : row_ptr.size() - 1;
}

+ bool ReadColumnPage(dmlc::SeekStream* fi);
+ size_t WriteColumnPage(dmlc::Stream* fo) const;

+ common::ColumnMatrix const& Transpose() const;

private:
- // unused at the moment: https://github.com/dmlc/xgboost/pull/7531
std::unique_ptr<common::ColumnMatrix> columns_;
std::vector<size_t> hit_count_tloc_;
bool isDense_;
8 changes: 6 additions & 2 deletions src/data/gradient_index_format.cc
@@ -1,13 +1,13 @@
/*!
- * Copyright 2021 XGBoost contributors
+ * Copyright 2021-2022 XGBoost contributors
*/
#include "sparse_page_writer.h"
#include "gradient_index.h"
#include "histogram_cut_format.h"

namespace xgboost {
namespace data {

+ // fixme: io for column matrix.
class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
public:
bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override {
@@ -55,6 +55,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
return false;
}
page->SetDense(is_dense);

+ page->ReadColumnPage(fi);
return true;
}

@@ -93,6 +95,8 @@ class GHistIndexRawFormat : public SparsePageFormat<GHistIndexMatrix> {
bytes += sizeof(page.base_rowid);
fo->Write(page.IsDense());
bytes += sizeof(page.IsDense());

+ bytes += page.WriteColumnPage(fo);
return bytes;
}
};
7 changes: 7 additions & 0 deletions src/data/gradient_index_page_source.cc
@@ -7,6 +7,13 @@ namespace xgboost {
namespace data {
void GradientIndexPageSource::Fetch() {
if (!this->ReadCache()) {
+ if (count_ != 0 && !sync_) {
+ // source is initialized to be the 0th page during construction, so when count_ is 0
+ // there's no need to increment the source.
+ ++(*source_);
+ }
+ // This is not read from cache so we still need it to be synced with sparse page source.
+ CHECK_EQ(count_, source_->Iter());
auto const& csr = source_->Page();
this->page_.reset(new GHistIndexMatrix());
CHECK_NE(cuts_.Values().size(), 0);
7 changes: 4 additions & 3 deletions src/data/gradient_index_page_source.h
@@ -22,13 +22,14 @@ class GradientIndexPageSource : public PageSourceIncMixIn<GHistIndexMatrix> {
public:
GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache, BatchParam param,
- common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat,
+ common::HistogramCuts cuts, bool is_dense,
common::Span<FeatureType const> feature_types,
std::shared_ptr<SparsePageSource> source)
- : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache),
+ : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache,
+ std::isnan(param.sparse_thresh)),
cuts_{std::move(cuts)},
is_dense_{is_dense},
- max_bin_per_feat_{max_bin_per_feat},
+ max_bin_per_feat_{param.max_bin},
feature_types_{feature_types},
sparse_thresh_{param.sparse_thresh} {
this->source_ = source;
22 changes: 3 additions & 19 deletions src/data/sparse_page_dmatrix.cc
@@ -159,21 +159,6 @@ BatchSet<SortedCSCPage> SparsePageDMatrix::GetSortedColumnBatches() {

BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam &param) {
CHECK_GE(param.max_bin, 2);
- if (param.hess.empty() && !param.regen) {
- // hist method doesn't support full external memory implementation, so we concatenate
- // all index here.
- if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) {
- this->InitializeSparsePage();
- ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh,
- param.regen, ctx_.Threads()});
- this->InitializeSparsePage();
- batch_param_ = param;
- }
- auto begin_iter = BatchIterator<GHistIndexMatrix>(
- new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_index_page_));
- return BatchSet<GHistIndexMatrix>(begin_iter);
- }

auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_);
this->InitializeSparsePage();
if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) {
@@ -190,10 +175,9 @@ BatchSet<GHistIndexMatrix> SparsePageDMatrix::GetGradientIndex(const BatchParam
ghist_index_source_.reset();
CHECK_NE(cuts.Values().size(), 0);
auto ft = this->info_.feature_types.ConstHostSpan();
- ghist_index_source_.reset(
- new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_,
- this->n_batches_, cache_info_.at(id), param, std::move(cuts),
- this->IsDense(), param.max_bin, ft, sparse_page_source_));
+ ghist_index_source_.reset(new GradientIndexPageSource(
+ this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_,
+ cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_));
} else {
CHECK(ghist_index_source_);
ghist_index_source_->Reset();
3 changes: 3 additions & 0 deletions src/data/sparse_page_dmatrix.cu
@@ -11,6 +11,9 @@ namespace data {
BatchSet<EllpackPage> SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) {
CHECK_GE(param.gpu_id, 0);
CHECK_GE(param.max_bin, 2);
+ if (!(batch_param_ != BatchParam{})) {
+ CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
+ }
auto id = MakeCache(this, ".ellpack.page", cache_prefix_, &cache_info_);
size_t row_stride = 0;
this->InitializeSparsePage();
