From 7cc858e2ad3dc6da89ba2ef631c545b40f6d28c0 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 27 Oct 2021 22:13:33 +0800 Subject: [PATCH 1/3] Support external memory in CPU histogram building. Port the rest. Comment. Start test. Debug. Remove. test. Format. Fix test.. --- src/common/hist_util.cc | 171 +++++++++++--------- src/common/hist_util.h | 18 +-- src/tree/hist/histogram.h | 125 ++++++++------ src/tree/updater_quantile_hist.cc | 27 +++- tests/cpp/tree/hist/test_evaluate_splits.cc | 2 +- tests/cpp/tree/hist/test_histogram.cc | 123 ++++++++++++-- 6 files changed, 308 insertions(+), 158 deletions(-) diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index 40392a2bcfaa..fbe1ee4dc1d7 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -133,74 +133,108 @@ struct Prefetch { constexpr size_t Prefetch::kNoPrefetchSize; - -template -void BuildHistKernel(const std::vector& gpair, +template +void BuildHistKernel(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRow hist) { + const GHistIndexMatrix &gmat, GHistRow hist) { const size_t size = row_indices.Size(); - const size_t* rid = row_indices.begin; - const float* pgh = reinterpret_cast(gpair.data()); - const BinIdxType* gradient_index = gmat.index.data(); - const size_t* row_ptr = gmat.row_ptr.data(); - const uint32_t* offsets = gmat.index.Offset(); - const size_t n_features = row_ptr[row_indices.begin[0]+1] - row_ptr[row_indices.begin[0]]; - FPType* hist_data = reinterpret_cast(hist.data()); - const uint32_t two {2}; // Each element from 'gpair' and 'hist' contains - // 2 FP values: gradient and hessian. - // So we need to multiply each row-index/bin-index by 2 - // to work with gradient pairs as a singe row FP array + const size_t *rid = row_indices.begin; + auto const *pgh = reinterpret_cast(gpair.data()); + const BinIdxType *gradient_index = gmat.index.data(); + + auto const &row_ptr = gmat.row_ptr.data(); + auto base_rowid = gmat.base_rowid; + const uint32_t *offsets = gmat.index.Offset(); + auto get_row_ptr = [&](size_t ridx) { + return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; + }; + auto get_rid = [&](size_t ridx) { + return first_page ? ridx : (ridx - base_rowid); + }; + + const size_t n_features = + get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]); + auto hist_data = reinterpret_cast(hist.data()); + const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains + // 2 FP values: gradient and hessian. + // So we need to multiply each row-index/bin-index by 2 + // to work with gradient pairs as a singe row FP array for (size_t i = 0; i < size; ++i) { - const size_t icol_start = any_missing ? row_ptr[rid[i]] : rid[i] * n_features; - const size_t icol_end = any_missing ? row_ptr[rid[i]+1] : icol_start + n_features; + const size_t icol_start = + any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features; + const size_t icol_end = + any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features; + const size_t row_size = icol_end - icol_start; const size_t idx_gh = two * rid[i]; if (do_prefetch) { - const size_t icol_start_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]] : - rid[i + Prefetch::kPrefetchOffset] * n_features; - const size_t icol_end_prefetch = any_missing ? row_ptr[rid[i+Prefetch::kPrefetchOffset]+1] : - icol_start_prefetch + n_features; + const size_t icol_start_prefetch = + any_missing + ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset]) + : get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features; + const size_t icol_end_prefetch = + any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1) + : icol_start_prefetch + n_features; PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); for (size_t j = icol_start_prefetch; j < icol_end_prefetch; - j+=Prefetch::GetPrefetchStep()) { + j += Prefetch::GetPrefetchStep()) { PREFETCH_READ_T0(gradient_index + j); } } - const BinIdxType* gr_index_local = gradient_index + icol_start; + const BinIdxType *gr_index_local = gradient_index + icol_start; for (size_t j = 0; j < row_size; ++j) { - const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + ( - any_missing ? 0 : offsets[j])); - - hist_data[idx_bin] += pgh[idx_gh]; - hist_data[idx_bin+1] += pgh[idx_gh+1]; + const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + + (any_missing ? 0 : offsets[j])); + hist_data[idx_bin] += pgh[idx_gh]; + hist_data[idx_bin + 1] += pgh[idx_gh + 1]; } } } -template -void BuildHistDispatch(const std::vector& gpair, +template +void BuildHistDispatch(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, GHistRow hist) { - switch (gmat.index.GetBinTypeSize()) { + const GHistIndexMatrix &gmat, GHistRow hist) { + auto first_page = gmat.base_rowid == 0; + if (first_page) { + switch (gmat.index.GetBinTypeSize()) { case kUint8BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); + BuildHistKernel( + gpair, row_indices, gmat, hist); break; case kUint16BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); + BuildHistKernel( + gpair, row_indices, gmat, hist); break; case kUint32BinsTypeSize: - BuildHistKernel(gpair, row_indices, - gmat, hist); + BuildHistKernel( + gpair, row_indices, gmat, hist); break; default: CHECK(false); // no default behavior + } + } else { + switch (gmat.index.GetBinTypeSize()) { + case kUint8BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + case kUint16BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + case kUint32BinsTypeSize: + BuildHistKernel( + gpair, row_indices, gmat, hist); + break; + default: + CHECK(false); // no default behavior + } } } @@ -208,73 +242,52 @@ template template void GHistBuilder::BuildHist( const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRowT hist) { + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRowT hist) const { const size_t nrows = row_indices.Size(); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); // if need to work with all rows from bin-matrix (e.g. root node) - const bool contiguousBlock = (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); + const bool contiguousBlock = + (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); if (contiguousBlock) { // contiguous memory access, built-in HW prefetching is enough - BuildHistDispatch(gpair, row_indices, gmat, hist); + BuildHistDispatch(gpair, row_indices, + gmat, hist); } else { - const RowSetCollection::Elem span1(row_indices.begin, row_indices.end - no_prefetch_size); - const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); + const RowSetCollection::Elem span1(row_indices.begin, + row_indices.end - no_prefetch_size); + const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, + row_indices.end); - BuildHistDispatch(gpair, span1, gmat, hist); + BuildHistDispatch(gpair, span1, gmat, + hist); // no prefetching to avoid loading extra memory - BuildHistDispatch(gpair, span2, gmat, hist); + BuildHistDispatch(gpair, span2, gmat, + hist); } } + template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); + GHistRow hist) const; template void GHistBuilder::BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist); - -template -void GHistBuilder::SubtractionTrick(GHistRowT self, - GHistRowT sibling, - GHistRowT parent) { - const size_t size = self.size(); - CHECK_EQ(sibling.size(), size); - CHECK_EQ(parent.size(), size); - - const size_t block_size = 1024; // aproximatly 1024 values per block - size_t n_blocks = size/block_size + !!(size%block_size); - - ParallelFor(omp_ulong(n_blocks), [&](omp_ulong iblock) { - const size_t ibegin = iblock*block_size; - const size_t iend = (((iblock+1)*block_size > size) ? size : ibegin + block_size); - SubtractionHist(self, parent, sibling, ibegin, iend); - }); -} -template -void GHistBuilder::SubtractionTrick(GHistRow self, - GHistRow sibling, - GHistRow parent); -template -void GHistBuilder::SubtractionTrick(GHistRow self, - GHistRow sibling, - GHistRow parent); - + GHistRow hist) const; } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 05d8c2eac1d8..c2ff58593f50 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -460,7 +460,7 @@ class ParallelGHistBuilder { } // Reduce following bins (begin, end] for nid-node in dst across threads - void ReduceHist(size_t nid, size_t begin, size_t end) { + void ReduceHist(size_t nid, size_t begin, size_t end) const { CHECK_GT(end, begin); CHECK_LT(nid, nodes_); @@ -486,7 +486,6 @@ class ParallelGHistBuilder { } } - protected: void MatchThreadsToNodes(const BlockedSpace2d& space) { const size_t space_size = space.Size(); const size_t chunck_size = space_size / nthreads_ + !!(space_size % nthreads_); @@ -533,6 +532,7 @@ class ParallelGHistBuilder { } } + private: void MatchNodeNidPairToHist() { size_t hist_allocated_additionally = 0; @@ -586,26 +586,18 @@ class GHistBuilder { using GHistRowT = GHistRow; GHistBuilder() = default; - GHistBuilder(size_t nthread, uint32_t nbins) : nthread_{nthread}, nbins_{nbins} {} + explicit GHistBuilder(uint32_t nbins): nbins_{nbins} {} // construct a histogram via histogram aggregation template - void BuildHist(const std::vector& gpair, + void BuildHist(const std::vector &gpair, const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRowT hist); - // construct a histogram via subtraction trick - void SubtractionTrick(GHistRowT self, - GHistRowT sibling, - GHistRowT parent); - + const GHistIndexMatrix &gmat, GHistRowT hist) const; uint32_t GetNumBins() const { return nbins_; } private: - /*! \brief number of threads for parallel computation */ - size_t nthread_ { 0 }; /*! \brief number of all bins over all features */ uint32_t nbins_ { 0 }; }; diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 70c756e765e6..b92c4cf89102 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -11,6 +11,8 @@ #include "rabit/rabit.h" #include "xgboost/tree_model.h" #include "../../common/hist_util.h" +#include "../../data/gradient_index.h" +#include "../../common/observer.h" namespace xgboost { namespace tree { @@ -25,8 +27,9 @@ template class HistogramBuilder { common::GHistBuilder builder_; common::ParallelGHistBuilder buffer_; rabit::Reducer reducer_; - int32_t max_bin_ {-1}; + BatchParam param_; int32_t n_threads_ {-1}; + size_t n_batches_ {0}; // Whether XGBoost is running in distributed environment. bool is_distributed_ {false}; @@ -39,59 +42,58 @@ template class HistogramBuilder { * \param is_distributed Mostly used for testing to allow injecting parameters instead * of using global rabit variable. */ - void Reset(uint32_t total_bins, int32_t max_bin_per_feat, int32_t n_threads, - bool is_distributed = rabit::IsDistributed()) { + void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, + size_t n_batches, bool is_distributed) { CHECK_GE(n_threads, 1); n_threads_ = n_threads; - CHECK_GE(max_bin_per_feat, 2); - max_bin_ = max_bin_per_feat; + n_batches_ = n_batches; + param_ = p; hist_.Init(total_bins); hist_local_worker_.Init(total_bins); buffer_.Init(total_bins); - builder_ = common::GHistBuilder(n_threads, total_bins); + builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; } template - void - BuildLocalHistograms(DMatrix *p_fmat, - std::vector nodes_for_explicit_hist_build, - common::RowSetCollection const &row_set_collection, - const std::vector &gpair_h) { + void BuildLocalHistograms( + size_t page_idx, + common::BlockedSpace2d space, + GHistIndexMatrix const &gidx, + std::vector const &nodes_for_explicit_hist_build, + common::RowSetCollection const &row_set_collection, + const std::vector &gpair_h) { const size_t n_nodes = nodes_for_explicit_hist_build.size(); - - // create space of size (# rows in each node) - common::BlockedSpace2d space( - n_nodes, - [&](size_t node) { - const int32_t nid = nodes_for_explicit_hist_build[node].nid; - return row_set_collection[nid].Size(); - }, - 256); + CHECK_GT(n_nodes, 0); std::vector target_hists(n_nodes); for (size_t i = 0; i < n_nodes; ++i) { const int32_t nid = nodes_for_explicit_hist_build[i].nid; target_hists[i] = hist_[nid]; } - buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + if (page_idx == 0) { + // FIXME(jiamingy): Handle different size of space. Right now we use the maximum + // partition size for the buffer, which might not be efficient if partition sizes + // has significant variance. + buffer_.Reset(this->n_threads_, n_nodes, space, target_hists); + } // Parallel processing by nodes and data in each node - for (auto const &gmat : p_fmat->GetBatches( - BatchParam{GenericParameter::kCpuId, max_bin_})) { - common::ParallelFor2d( - space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { - const auto tid = static_cast(omp_get_thread_num()); - const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; - - auto start_of_row_set = row_set_collection[nid].begin; - auto rid_set = common::RowSetCollection::Elem( - start_of_row_set + r.begin(), start_of_row_set + r.end(), nid); - builder_.template BuildHist( - gpair_h, rid_set, gmat, - buffer_.GetInitializedHist(tid, nid_in_set)); - }); - } + common::ParallelFor2d( + space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { + const auto tid = static_cast(omp_get_thread_num()); + const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; + auto elem = row_set_collection[nid]; + auto start_of_row_set = std::min(r.begin(), elem.Size()); + auto end_of_row_set = std::min(r.end(), elem.Size()); + auto rid_set = common::RowSetCollection::Elem( + elem.begin + start_of_row_set, elem.begin + end_of_row_set, nid); + auto hist = buffer_.GetInitializedHist(tid, nid_in_set); + if (rid_set.Size() != 0) { + builder_.template BuildHist(gpair_h, rid_set, gidx, + hist); + } + }); } void @@ -110,24 +112,36 @@ template class HistogramBuilder { } } - /* Main entry point of this class, build histogram for tree nodes. */ - void BuildHist(DMatrix *p_fmat, RegTree *p_tree, + /** Main entry point of this class, build histogram for tree nodes. */ + void BuildHist(size_t page_id, + common::BlockedSpace2d space, + GHistIndexMatrix const& gidx, RegTree *p_tree, common::RowSetCollection const &row_set_collection, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick, std::vector const &gpair) { int starting_index = std::numeric_limits::max(); int sync_count = 0; - this->AddHistRows(&starting_index, &sync_count, - nodes_for_explicit_hist_build, - nodes_for_subtraction_trick, p_tree); - if (p_fmat->IsDense()) { - BuildLocalHistograms(p_fmat, nodes_for_explicit_hist_build, - row_set_collection, gpair); + if (page_id == 0) { + this->AddHistRows(&starting_index, &sync_count, + nodes_for_explicit_hist_build, + nodes_for_subtraction_trick, p_tree); + } + if (gidx.IsDense()) { + this->BuildLocalHistograms(page_id, space, gidx, + nodes_for_explicit_hist_build, + row_set_collection, gpair); } else { - BuildLocalHistograms(p_fmat, nodes_for_explicit_hist_build, - row_set_collection, gpair); + this->BuildLocalHistograms(page_id, space, gidx, + nodes_for_explicit_hist_build, + row_set_collection, gpair); } + + CHECK_GE(n_batches_, 1); + if (page_id != n_batches_ - 1) { + return; + } + if (is_distributed_) { this->SyncHistogramDistributed(p_tree, nodes_for_explicit_hist_build, nodes_for_subtraction_trick, @@ -138,6 +152,25 @@ template class HistogramBuilder { sync_count); } } + /** same as the other build hist but handles only single batch data (in-core) */ + void BuildHist(size_t page_id, GHistIndexMatrix const &gidx, RegTree *p_tree, + common::RowSetCollection const &row_set_collection, + std::vector const &nodes_for_explicit_hist_build, + std::vector const &nodes_for_subtraction_trick, + std::vector const &gpair) { + const size_t n_nodes = nodes_for_explicit_hist_build.size(); + // create space of size (# rows in each node) + common::BlockedSpace2d space( + n_nodes, + [&](size_t nidx_in_set) { + const int32_t nidx = nodes_for_explicit_hist_build[nidx_in_set].nid; + return row_set_collection[nidx].Size(); + }, + 256); + this->BuildHist(page_id, space, gidx, p_tree, row_set_collection, + nodes_for_explicit_hist_build, nodes_for_subtraction_trick, + gpair); + } void SyncHistogramDistributed( RegTree *p_tree, diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 1207a57102f5..51d228357719 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -130,9 +130,14 @@ void QuantileHistMaker::Builder::InitRoot( nodes_for_subtraction_trick_.clear(); nodes_for_explicit_hist_build_.push_back(node); - this->histogram_builder_->BuildHist(p_fmat, p_tree, row_set_collection_, - nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); + size_t page_id = 0; + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin})) { + this->histogram_builder_->BuildHist( + page_id, gidx, p_tree, row_set_collection_, + nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); + ++page_id; + } { auto nid = RegTree::kRoot; @@ -262,9 +267,15 @@ void QuantileHistMaker::Builder::ExpandTree( SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); if (depth < param_.max_depth) { - this->histogram_builder_->BuildHist( - p_fmat, p_tree, row_set_collection_, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, gpair_h); + size_t i = 0; + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, param_.max_bin})) { + this->histogram_builder_->BuildHist( + i, gidx, p_tree, row_set_collection_, + nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, + gpair_h); + ++i; + } } else { int starting_index = std::numeric_limits::max(); int sync_count = 0; @@ -435,7 +446,9 @@ void QuantileHistMaker::Builder::InitData( }); } exc.Rethrow(); - this->histogram_builder_->Reset(nbins, param_.max_bin, this->nthread_); + this->histogram_builder_->Reset( + nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin}, + this->nthread_, 1, rabit::IsDistributed()); std::vector& row_indices = *row_set_collection_.Data(); row_indices.resize(info.num_row_); diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 115dcb0297dd..59d1ef2328d5 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -36,7 +36,7 @@ template void TestEvaluateSplits() { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - auto hist_builder = GHistBuilder(omp_get_max_threads(), gmat.cut.Ptrs().back()); + auto hist_builder = GHistBuilder(gmat.cut.Ptrs().back()); hist.Init(gmat.cut.Ptrs().back()); hist.AddHistRow(0); hist.AllocateAllData(); diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index f257a683405e..ff721cf129bd 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -8,6 +8,16 @@ namespace xgboost { namespace tree { +namespace { +void InitRowPartitionForTest(RowSetCollection *row_set, size_t n_samples, + size_t base_rowid = 0) { + auto &row_indices = *row_set->Data(); + row_indices.resize(n_samples); + std::iota(row_indices.begin(), row_indices.end(), base_rowid); + row_set->Init(); +} +} // anonymous namespace + template void TestAddHistRows(bool is_distributed) { std::vector nodes_for_explicit_hist_build_; @@ -35,8 +45,9 @@ void TestAddHistRows(bool is_distributed) { nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); HistogramBuilder histogram_builder; - histogram_builder.Reset(gmat.cut.TotalBins(), kMaxBins, omp_get_max_threads(), - is_distributed); + histogram_builder.Reset(gmat.cut.TotalBins(), + {GenericParameter::kCpuId, kMaxBins}, + omp_get_max_threads(), 1, is_distributed); histogram_builder.AddHistRows(&starting_index, &sync_count, nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, &tree); @@ -81,7 +92,8 @@ void TestSyncHist(bool is_distributed) { HistogramBuilder histogram; uint32_t total_bins = gmat.cut.Ptrs().back(); - histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed); + histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins}, + omp_get_max_threads(), 1, is_distributed); RowSetCollection row_set_collection_; { @@ -247,22 +259,26 @@ void TestBuildHistogram(bool is_distributed) { bst_node_t nid = 0; HistogramBuilder histogram; - histogram.Reset(total_bins, kMaxBins, omp_get_max_threads(), is_distributed); + histogram.Reset(total_bins, {GenericParameter::kCpuId, kMaxBins}, + omp_get_max_threads(), 1, is_distributed); RegTree tree; - RowSetCollection row_set_collection_; - row_set_collection_.Clear(); - std::vector &row_indices = *row_set_collection_.Data(); + RowSetCollection row_set_collection; + row_set_collection.Clear(); + std::vector &row_indices = *row_set_collection.Data(); row_indices.resize(kNRows); std::iota(row_indices.begin(), row_indices.end(), 0); - row_set_collection_.Init(); + row_set_collection.Init(); CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); - std::vector nodes_for_explicit_hist_build_; - nodes_for_explicit_hist_build_.push_back(node); - histogram.BuildHist(p_fmat.get(), &tree, row_set_collection_, - nodes_for_explicit_hist_build_, {}, gpair); + std::vector nodes_for_explicit_hist_build; + nodes_for_explicit_hist_build.push_back(node); + for (auto const &gidx : p_fmat->GetBatches( + {GenericParameter::kCpuId, kMaxBins})) { + histogram.BuildHist(0, gidx, &tree, row_set_collection, + nodes_for_explicit_hist_build, {}, gpair); + } // Check if number of histogram bins is correct ASSERT_EQ(histogram.Histogram()[nid].size(), gmat.cut.Ptrs().back()); @@ -294,5 +310,88 @@ TEST(CPUHistogram, BuildHist) { TestBuildHistogram(false); TestBuildHistogram(false); } + +TEST(CPUHistogram, ExternalMemory) { + size_t constexpr kEntries = 1 << 16; + + int32_t constexpr kBins = 32; + auto m = CreateSparsePageDMatrix(kEntries, "cache"); + std::vector partition_size(1, 0); + size_t total_bins{0}; + size_t n_samples{0}; + + auto gpair = GenerateRandomGradients(m->Info().num_row_, 0.0, 1.0); + auto const &h_gpair = gpair.HostVector(); + + RegTree tree; + std::vector nodes; + nodes.emplace_back(0, tree.GetDepth(0), 0.0f); + + GHistRow multi_page; + HistogramBuilder multi_build; + { + /** + * Multi page + */ + std::vector rows_set; + std::vector hess(m->Info().num_row_, 1.0); + for (auto const &page : m->GetBatches( + {GenericParameter::kCpuId, kBins, hess})) { + CHECK_LT(page.base_rowid, m->Info().num_row_); + auto n_rows_in_node = page.Size(); + partition_size[0] = std::max(partition_size[0], n_rows_in_node); + total_bins = page.cut.TotalBins(); + n_samples += n_rows_in_node; + + rows_set.emplace_back(); + InitRowPartitionForTest(&rows_set.back(), n_rows_in_node, page.base_rowid); + } + ASSERT_EQ(n_samples, m->Info().num_row_); + + common::BlockedSpace2d space{ + 1, [&](size_t nidx_in_set) { return partition_size.at(nidx_in_set); }, + 256}; + + multi_build.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), rows_set.size(), false); + + size_t page_idx{0}; + for (auto const &page : m->GetBatches( + {GenericParameter::kCpuId, kBins, hess})) { + multi_build.BuildHist(page_idx, space, page, &tree, + rows_set.at(page_idx), nodes, {}, h_gpair); + ++page_idx; + } + ASSERT_EQ(page_idx, 2); + multi_page = multi_build.Histogram()[0]; + } + + HistogramBuilder single_build; + GHistRow single_page; + { + /** + * Single page + */ + RowSetCollection row_set_collection; + InitRowPartitionForTest(&row_set_collection, n_samples); + + single_build.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), 1, false); + size_t n_batches{0}; + for (auto const &page : + m->GetBatches({GenericParameter::kCpuId, kBins})) { + single_build.BuildHist(0, page, &tree, row_set_collection, nodes, {}, + h_gpair); + n_batches ++; + } + ASSERT_EQ(n_batches, 1); + single_page = single_build.Histogram()[0]; + } + + for (size_t i = 0; i < single_page.size(); ++i) { + ASSERT_NEAR(single_page[i].GetGrad(), multi_page[i].GetGrad(), kRtEps); + ASSERT_NEAR(single_page[i].GetHess(), multi_page[i].GetHess(), kRtEps); + } +} } // namespace tree } // namespace xgboost From 5588d0bef764039247e5383951ea1f2529c2471c Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 10 Nov 2021 21:50:40 +0800 Subject: [PATCH 2/3] Remove header. --- src/tree/hist/histogram.h | 1 - 1 file changed, 1 deletion(-) diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index b92c4cf89102..34113ec4f9a8 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -12,7 +12,6 @@ #include "xgboost/tree_model.h" #include "../../common/hist_util.h" #include "../../data/gradient_index.h" -#include "../../common/observer.h" namespace xgboost { namespace tree { From 88fc3355dd8fd9889845c5466419827db70b354f Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 10 Nov 2021 21:52:01 +0800 Subject: [PATCH 3/3] Format. --- src/tree/hist/histogram.h | 56 +++++++++++++++++---------------------- 1 file changed, 25 insertions(+), 31 deletions(-) diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 34113ec4f9a8..242825b25bb5 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -27,10 +27,10 @@ template class HistogramBuilder { common::ParallelGHistBuilder buffer_; rabit::Reducer reducer_; BatchParam param_; - int32_t n_threads_ {-1}; - size_t n_batches_ {0}; + int32_t n_threads_{-1}; + size_t n_batches_{0}; // Whether XGBoost is running in distributed environment. - bool is_distributed_ {false}; + bool is_distributed_{false}; public: /** @@ -41,8 +41,8 @@ template class HistogramBuilder { * \param is_distributed Mostly used for testing to allow injecting parameters instead * of using global rabit variable. */ - void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, - size_t n_batches, bool is_distributed) { + void Reset(uint32_t total_bins, BatchParam p, int32_t n_threads, size_t n_batches, + bool is_distributed) { CHECK_GE(n_threads, 1); n_threads_ = n_threads; n_batches_ = n_batches; @@ -55,13 +55,11 @@ template class HistogramBuilder { } template - void BuildLocalHistograms( - size_t page_idx, - common::BlockedSpace2d space, - GHistIndexMatrix const &gidx, - std::vector const &nodes_for_explicit_hist_build, - common::RowSetCollection const &row_set_collection, - const std::vector &gpair_h) { + void BuildLocalHistograms(size_t page_idx, common::BlockedSpace2d space, + GHistIndexMatrix const &gidx, + std::vector const &nodes_for_explicit_hist_build, + common::RowSetCollection const &row_set_collection, + const std::vector &gpair_h) { const size_t n_nodes = nodes_for_explicit_hist_build.size(); CHECK_GT(n_nodes, 0); @@ -78,21 +76,19 @@ template class HistogramBuilder { } // Parallel processing by nodes and data in each node - common::ParallelFor2d( - space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { - const auto tid = static_cast(omp_get_thread_num()); - const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; - auto elem = row_set_collection[nid]; - auto start_of_row_set = std::min(r.begin(), elem.Size()); - auto end_of_row_set = std::min(r.end(), elem.Size()); - auto rid_set = common::RowSetCollection::Elem( - elem.begin + start_of_row_set, elem.begin + end_of_row_set, nid); - auto hist = buffer_.GetInitializedHist(tid, nid_in_set); - if (rid_set.Size() != 0) { - builder_.template BuildHist(gpair_h, rid_set, gidx, - hist); - } - }); + common::ParallelFor2d(space, this->n_threads_, [&](size_t nid_in_set, common::Range1d r) { + const auto tid = static_cast(omp_get_thread_num()); + const int32_t nid = nodes_for_explicit_hist_build[nid_in_set].nid; + auto elem = row_set_collection[nid]; + auto start_of_row_set = std::min(r.begin(), elem.Size()); + auto end_of_row_set = std::min(r.end(), elem.Size()); + auto rid_set = common::RowSetCollection::Elem(elem.begin + start_of_row_set, + elem.begin + end_of_row_set, nid); + auto hist = buffer_.GetInitializedHist(tid, nid_in_set); + if (rid_set.Size() != 0) { + builder_.template BuildHist(gpair_h, rid_set, gidx, hist); + } + }); } void @@ -112,10 +108,8 @@ template class HistogramBuilder { } /** Main entry point of this class, build histogram for tree nodes. */ - void BuildHist(size_t page_id, - common::BlockedSpace2d space, - GHistIndexMatrix const& gidx, RegTree *p_tree, - common::RowSetCollection const &row_set_collection, + void BuildHist(size_t page_id, common::BlockedSpace2d space, GHistIndexMatrix const &gidx, + RegTree *p_tree, common::RowSetCollection const &row_set_collection, std::vector const &nodes_for_explicit_hist_build, std::vector const &nodes_for_subtraction_trick, std::vector const &gpair) {