From 8d2f0b5043a0f13a1e07f64da2e112070220412e Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 29 Dec 2021 18:06:10 +0800 Subject: [PATCH] Remove the need to load sparse page. --- src/data/gradient_index_page_source.cc | 4 ++++ src/data/gradient_index_page_source.h | 9 ++++---- src/data/sparse_page_source.h | 29 ++++++++++++++++++++------ 3 files changed, 32 insertions(+), 10 deletions(-) diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index b1404765fcff..5e987cb29aee 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -7,6 +7,10 @@ namespace xgboost { namespace data { void GradientIndexPageSource::Fetch() { if (!this->ReadCache()) { + if (count_ != 0) { + ++(*source_); + } + CHECK_EQ(count_, source_->Iter()); auto const& csr = source_->Page(); this->page_.reset(new GHistIndexMatrix()); CHECK_NE(cuts_.Values().size(), 0); diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index 92177ab2582e..f98f5f0a7333 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -7,8 +7,8 @@ #include #include -#include "sparse_page_source.h" #include "gradient_index.h" +#include "sparse_page_source.h" namespace xgboost { namespace data { @@ -25,7 +25,8 @@ class GradientIndexPageSource : public PageSourceIncMixIn { common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat, common::Span feature_types, float sparse_thresh, std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, + !std::isnan(sparse_thresh)), cuts_{std::move(cuts)}, is_dense_{is_dense}, max_bin_per_feat_{max_bin_per_feat}, @@ -37,6 +38,6 @@ class GradientIndexPageSource : public PageSourceIncMixIn { void Fetch() final; }; -} // namespace data -} // namespace xgboost +} // namespace data +} // namespace xgboost #endif // XGBOOST_DATA_GRADIENT_INDEX_PAGE_SOURCE_H_ diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 5684118523a4..a9fbcec728cf 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -119,9 +119,12 @@ class SparsePageSourceImpl : public BatchIteratorImpl { size_t n_prefetch_batches = std::min(kPreFetch, n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; size_t fetch_it = count_; + for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring - if (ring_->at(fetch_it).valid()) { continue; } + if (ring_->at(fetch_it).valid()) { + continue; + } auto const *self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() { @@ -139,8 +142,7 @@ class SparsePageSourceImpl : public BatchIteratorImpl { return page; }); } - CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), - [](auto const &f) { return f.valid(); }), + CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; page_ = (*ring_)[count_].get(); @@ -289,15 +291,28 @@ template class PageSourceIncMixIn : public SparsePageSourceImpl { protected: std::shared_ptr source_; + using Super = SparsePageSourceImpl; + bool sync_{true}; // synchronize the row page. public: using SparsePageSourceImpl::SparsePageSourceImpl; + PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + std::shared_ptr cache, bool sync) + : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {} + PageSourceIncMixIn& operator++() final { TryLockGuard guard{this->single_threaded_}; - ++(*source_); + if (sync_) { + ++(*source_); + } ++this->count_; - this->at_end_ = source_->AtEnd(); + this->at_end_ = this->count_ == this->n_batches_; + if (this->at_end_) { + CHECK_EQ(this->count_, this->n_batches_); + } else { + CHECK_LT(this->count_, this->n_batches_); + } if (this->at_end_) { this->cache_info_->Commit(); @@ -308,7 +323,9 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { } else { this->Fetch(); } - CHECK_EQ(source_->Iter(), this->count_); + if (sync_) { + CHECK_EQ(source_->Iter(), this->count_); + } return *this; } };