Skip to content

Commit

Permalink
Comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 7, 2021
1 parent 950cd65 commit 711e389
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/data/sparse_page_dmatrix.cc
Expand Up @@ -38,7 +38,8 @@ SparsePageDMatrix::SparsePageDMatrix(DataIterHandle iter_handle, DMatrixHandle p
return HostAdapterDispatch(
proxy, [](auto const &value) { return value.NumCols(); });
};

// the proxy is iterated together with the sparse page source so we can obtain all
// information in 1 pass.
for (auto const &page : this->GetRowBatchesImpl()) {
this->info_.Extend(std::move(proxy->Info()), false, false);
n_features = std::max(n_features, num_cols());
Expand Down
4 changes: 2 additions & 2 deletions src/data/sparse_page_dmatrix.h
Expand Up @@ -35,9 +35,9 @@ class SparsePageDMatrix : public DMatrix {
int nthreads_;
std::string cache_prefix_;
size_t n_batches_ {0};

// sparse page is the source to other page types, we make a special member function.
void InitializeSparsePage();
// Non virtual version that can be used in constructor
// Non-virtual version that can be used in constructor
BatchSet<SparsePage> GetRowBatchesImpl();

public:
Expand Down
14 changes: 12 additions & 2 deletions src/data/sparse_page_source.h
Expand Up @@ -82,9 +82,11 @@ inline void TryDeleteCacheFile(const std::string& file) {
}

struct Cache {
// whether the write to the cache is complete
bool written;
std::string name;
std::string format;
// offset into binary cache file.
std::vector<size_t> offset;

Cache(bool w, std::string n, std::string fmt)
Expand All @@ -101,6 +103,7 @@ struct Cache {
return ShardName(this->name, this->format);
}

// The write is completed.
void Commit() {
if (!written) {
std::partial_sum(offset.begin(), offset.end(), offset.begin());
Expand All @@ -109,6 +112,7 @@ struct Cache {
}
};

// Prevents multi-threaded call.
class TryLockGuard {
std::mutex& lock_;

Expand All @@ -124,6 +128,7 @@ class TryLockGuard {
template <typename S>
class SparsePageSourceImpl : public BatchIteratorImpl<S> {
protected:
// Prevents calling this iterator from multiple places(or threads).
std::mutex single_threaded_;

std::shared_ptr<S> page_;
Expand All @@ -141,6 +146,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
std::unique_ptr<dmlc::Stream> fo_;

using Ring = std::vector<std::future<std::shared_ptr<S>>>;
// A ring storing futures to data. Since the DMatrix iterator is forward only, so we
// can pre-fetch data in a ring.
std::unique_ptr<Ring> ring_{new Ring};

bool ReadCache() {
Expand All @@ -152,13 +159,15 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
fo_.reset();
ring_->resize(n_batches_);
}
size_t constexpr kPreFetch = 4; // an heuristic for number of pre-fetched batches.
// An heuristic for number of pre-fetched batches. We can make it part of BatchParam
// to let user adjust number of pre-fetched batches when needed.
size_t constexpr kPreFetch = 4;

size_t n_prefetch_batches = std::min(kPreFetch, n_batches_);
CHECK_GT(n_prefetch_batches, 0) << n_batches_;
size_t fetch_it = count_;
for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) {
fetch_it %= n_batches_;
fetch_it %= n_batches_; // ring
if (ring_->at(fetch_it).valid()) { continue; }
auto const *self = this; // make sure it's const
CHECK_LT(fetch_it, cache_info_->offset.size());
Expand Down Expand Up @@ -313,6 +322,7 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
}
};

// A mixin for advancing the iterator.
template <typename S>
class PageSourceIncMixIn : public SparsePageSourceImpl<S> {
protected:
Expand Down

0 comments on commit 711e389

Please sign in to comment.