From e5d01ad04994d3e8cc3ee8d4289ba66a26d88568 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 6 Jul 2021 18:16:26 +0800 Subject: [PATCH 1/6] Support hessian in sketch container. * Define a parallel for function that accepts schedule. * Support hessian with group weights. --- include/xgboost/generic_parameters.h | 4 + src/common/hist_util.h | 7 +- src/common/quantile.cc | 141 +++++++++++++++++++-------- src/common/quantile.h | 6 +- src/common/threading_utils.h | 90 ++++++++++++++++- src/data/data.cc | 4 +- src/learner.cc | 4 + src/metric/rank_metric.cc | 2 +- src/tree/updater_colmaker.cc | 15 ++- tests/cpp/common/test_hist_util.cc | 20 ++++ tests/cpp/common/test_quantile.cc | 4 +- 11 files changed, 234 insertions(+), 63 deletions(-) diff --git a/include/xgboost/generic_parameters.h b/include/xgboost/generic_parameters.h index cee660e1dc67..fd74abf2cca7 100644 --- a/include/xgboost/generic_parameters.h +++ b/include/xgboost/generic_parameters.h @@ -39,6 +39,10 @@ struct GenericParameter : public XGBoostParameter { * \param require_gpu Whether GPU is explicitly required from user. */ void ConfigureGpuId(bool require_gpu); + /*! + * Return automatically chosen threads. + */ + int32_t Threads() const; // declare parameters DMLC_DECLARE_PARAMETER(GenericParameter) { diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 041faf2a1b53..be42e197f3eb 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -110,7 +110,8 @@ class HistogramCuts { } }; -inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) { +inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, + std::vector const &hessian = {}) { HistogramCuts out; auto const& info = m->Info(); const auto threads = omp_get_max_threads(); @@ -127,9 +128,9 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) { } } HostSketchContainer container(reduced, max_bins, - HostSketchContainer::UseGroup(info)); + HostSketchContainer::UseGroup(info), threads); for (auto const &page : m->GetBatches()) { - container.PushRowPage(page, info); + container.PushRowPage(page, info, hessian); } container.MakeCuts(&out); return out; diff --git a/src/common/quantile.cc b/src/common/quantile.cc index e67d7daec115..56a70bae16dc 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -10,19 +10,21 @@ namespace xgboost { namespace common { HostSketchContainer::HostSketchContainer(std::vector columns_size, - int32_t max_bins, bool use_group) + int32_t max_bins, bool use_group, + int32_t n_threads) : columns_size_{std::move(columns_size)}, max_bins_{max_bins}, - use_group_ind_{use_group} { + use_group_ind_{use_group}, n_threads_{n_threads} { monitor_.Init(__func__); CHECK_NE(columns_size_.size(), 0); sketches_.resize(columns_size_.size()); - for (size_t i = 0; i < sketches_.size(); ++i) { + CHECK_GE(n_threads_, 1); + ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) { auto n_bins = std::min(static_cast(max_bins_), columns_size_[i]); n_bins = std::max(n_bins, static_cast(1)); auto eps = 1.0 / (static_cast(n_bins) * WQSketch::kFactor); sketches_[i].Init(columns_size_[i], eps); sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2); - } + }); } std::vector @@ -89,40 +91,94 @@ std::vector HostSketchContainer::LoadBalance( return cols_ptr; } -void HostSketchContainer::PushRowPage(SparsePage const &page, - MetaInfo const &info) { +namespace { +// Function to merge hessian and sample weights +std::vector MergeWeights(MetaInfo const &info, + std::vector const &hessian, + bool use_group, int32_t n_threads) { + CHECK_EQ(hessian.size(), info.num_row_); + std::vector results(hessian.size()); + auto const &group_ptr = info.group_ptr_; + if (use_group) { + auto const &group_weights = info.weights_.HostVector(); + CHECK_GE(group_ptr.size(), 2); + CHECK_EQ(group_ptr.back(), hessian.size()); + size_t cur_group = 0; + for (size_t i = 0; i < hessian.size(); ++i) { + results[i] = hessian[i] * group_weights[cur_group]; + if (i == group_ptr[cur_group + 1]) { + cur_group++; + } + } + } else { + auto const &sample_weights = info.weights_.HostVector(); + ParallelFor(hessian.size(), n_threads, Sched::Auto(), + [&](auto i) { results[i] = hessian[i] * sample_weights[i]; }); + } + return results; +} + +std::vector UnrollGroupWeights(MetaInfo const &info) { + std::vector const &group_weights = info.weights_.HostVector(); + if (group_weights.empty()) { + return group_weights; + } + + size_t n_samples = info.num_row_; + auto const &group_ptr = info.group_ptr_; + std::vector results(n_samples); + CHECK_GE(group_ptr.size(), 2); + CHECK_EQ(group_ptr.back(), n_samples); + size_t cur_group = 0; + for (size_t i = 0; i < n_samples; ++i) { + results[i] = group_weights[cur_group]; + if (i == group_ptr[cur_group + 1]) { + cur_group++; + } + } + return results; +} +} // anonymous namespace + +void HostSketchContainer::PushRowPage( + SparsePage const &page, MetaInfo const &info, std::vector const &hessian) { monitor_.Start(__func__); - int nthread = omp_get_max_threads(); - CHECK_EQ(sketches_.size(), info.num_col_); + bst_feature_t n_columns = info.num_col_; + auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_; + CHECK_GE(n_threads_, 1); + CHECK_EQ(sketches_.size(), n_columns); + + // glue these conditions using ternary operator to avoid making data copies. + auto const &weights = + hessian.empty() + ? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight + : info.weights_.HostVector()) // use sample weight + : MergeWeights( + info, hessian, use_group_ind_, + n_threads_); // use hessian merged with group/sample weights + if (!weights.empty()) { + CHECK_EQ(weights.size(), info.num_row_); + } - // Data groups, used in ranking. - std::vector const &group_ptr = info.group_ptr_; - // Use group index for weights? auto batch = page.GetView(); // Parallel over columns. Each thread owns a set of consecutive columns. - auto const ncol = static_cast(info.num_col_); - auto const is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_; - auto thread_columns_ptr = LoadBalance(page, info.num_col_, nthread); + auto const ncol = static_cast(info.num_col_); + auto thread_columns_ptr = LoadBalance(page, info.num_col_, n_threads_); dmlc::OMPException exc; -#pragma omp parallel num_threads(nthread) +#pragma omp parallel num_threads(n_threads_) { exc.Run([&]() { auto tid = static_cast(omp_get_thread_num()); auto const begin = thread_columns_ptr[tid]; auto const end = thread_columns_ptr[tid + 1]; - size_t group_ind = 0; // do not iterate if no columns are assigned to the thread if (begin < end && end <= ncol) { for (size_t i = 0; i < batch.Size(); ++i) { size_t const ridx = page.base_rowid + i; SparsePage::Inst const inst = batch[i]; - if (use_group_ind_) { - group_ind = this->SearchGroupIndFromRow(group_ptr, i + page.base_rowid); - } - size_t w_idx = use_group_ind_ ? group_ind : ridx; - auto w = info.GetWeight(w_idx); + auto w = weights.empty() ? 1.0f : weights[ridx]; auto p_inst = inst.data(); if (is_dense) { for (size_t ii = begin; ii < end; ii++) { @@ -201,6 +257,8 @@ void HostSketchContainer::AllReduce( monitor_.Start(__func__); auto& num_cuts = *p_num_cuts; CHECK_EQ(num_cuts.size(), 0); + num_cuts.resize(sketches_.size()); + auto &reduced = *p_reduced; reduced.resize(sketches_.size()); @@ -212,25 +270,23 @@ void HostSketchContainer::AllReduce( std::vector global_column_size(columns_size_); rabit::Allreduce(global_column_size.data(), global_column_size.size()); -size_t nbytes = 0; - for (size_t i = 0; i < sketches_.size(); ++i) { - int32_t intermediate_num_cuts = static_cast(std::min( - global_column_size[i], static_cast(max_bins_ * WQSketch::kFactor))); + ParallelFor(sketches_.size(), n_threads_, [&](size_t i) { + int32_t intermediate_num_cuts = static_cast( + std::min(global_column_size[i], + static_cast(max_bins_ * WQSketch::kFactor))); if (global_column_size[i] != 0) { WQSketch::SummaryContainer out; sketches_[i].GetSummary(&out); reduced[i].Reserve(intermediate_num_cuts); CHECK(reduced[i].data); reduced[i].SetPrune(out, intermediate_num_cuts); - nbytes = std::max( - WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts), - nbytes); } + num_cuts[i] = intermediate_num_cuts; + }); - num_cuts.push_back(intermediate_num_cuts); - } auto world = rabit::GetWorldSize(); if (world == 1) { + monitor_.Stop(__func__); return; } @@ -242,7 +298,7 @@ size_t nbytes = 0; &global_sketches); std::vector final_sketches(n_columns); - ParallelFor(omp_ulong(n_columns), [&](omp_ulong fidx) { + ParallelFor(omp_ulong(n_columns), n_threads_, [&](omp_ulong fidx) { int32_t intermediate_num_cuts = num_cuts[fidx]; auto nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts); @@ -276,7 +332,7 @@ void AddCutPoint(WQuantileSketch::SummaryContainer const &summary, auto& cut_values = cuts->cut_values_.HostVector(); for (size_t i = 1; i < required_cuts; ++i) { bst_float cpt = summary.data[i].value; - if (i == 1 || cpt > cuts->cut_values_.ConstHostVector().back()) { + if (i == 1 || cpt > cut_values.back()) { cut_values.push_back(cpt); } } @@ -289,23 +345,28 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) { this->AllReduce(&reduced, &num_cuts); cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f); + std::vector final_summaries(reduced.size()); - for (size_t fid = 0; fid < reduced.size(); ++fid) { - WQSketch::SummaryContainer a; - size_t max_num_bins = std::min(num_cuts[fid], max_bins_); + ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) { + WQSketch::SummaryContainer &a = final_summaries[fidx]; + size_t max_num_bins = std::min(num_cuts[fidx], max_bins_); a.Reserve(max_num_bins + 1); CHECK(a.data); - if (num_cuts[fid] != 0) { - a.SetPrune(reduced[fid], max_num_bins + 1); - CHECK(a.data && reduced[fid].data); + if (num_cuts[fidx] != 0) { + a.SetPrune(reduced[fidx], max_num_bins + 1); + CHECK(a.data && reduced[fidx].data); const bst_float mval = a.data[0].value; - cuts->min_vals_.HostVector()[fid] = mval - fabs(mval) - 1e-5f; + cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f; } else { // Empty column. const float mval = 1e-5f; - cuts->min_vals_.HostVector()[fid] = mval; + cuts->min_vals_.HostVector()[fidx] = mval; } + }); + for (size_t fid = 0; fid < reduced.size(); ++fid) { + size_t max_num_bins = std::min(num_cuts[fid], max_bins_); + WQSketch::SummaryContainer const& a = final_summaries[fid]; AddCutPoint(a, max_num_bins, cuts); // push a value that is greater than anything const bst_float cpt diff --git a/src/common/quantile.h b/src/common/quantile.h index a70bf809ea28..4167adbe81c7 100644 --- a/src/common/quantile.h +++ b/src/common/quantile.h @@ -710,6 +710,7 @@ class HostSketchContainer { std::vector columns_size_; int32_t max_bins_; bool use_group_ind_{false}; + int32_t n_threads_; Monitor monitor_; public: @@ -720,7 +721,7 @@ class HostSketchContainer { * \param use_group whether is assigned to group to data instance. */ HostSketchContainer(std::vector columns_size, int32_t max_bins, - bool use_group); + bool use_group, int32_t n_threads); static bool UseGroup(MetaInfo const &info) { size_t const num_groups = @@ -758,7 +759,8 @@ class HostSketchContainer { std::vector* p_num_cuts); /* \brief Push a CSR matrix. */ - void PushRowPage(SparsePage const& page, MetaInfo const& info); + void PushRowPage(SparsePage const &page, MetaInfo const &info, + std::vector const &hessian = {}); void MakeCuts(HistogramCuts* cuts); }; diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h index 21ba4b41bddb..ab3765f501fe 100644 --- a/src/common/threading_utils.h +++ b/src/common/threading_utils.h @@ -9,6 +9,7 @@ #include #include #include +#include // std::is_signed #include "xgboost/logging.h" namespace xgboost { @@ -133,19 +134,92 @@ void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) { exc.Rethrow(); } +/** + * OpenMP schedule + */ +struct Sched { + enum { + kAuto, + kDynamic, + kStatic, + kGuided, + } sched; + size_t chunk{0}; + + Sched static Auto() { return Sched{kAuto}; } + Sched static Dyn(size_t n = 0) { return Sched{kDynamic, n}; } + Sched static Static(size_t n = 0) { return Sched{kStatic, n}; } + Sched static Guided() { return Sched{kGuided}; } +}; + template -void ParallelFor(Index size, size_t nthreads, Func fn) { +void ParallelFor(Index size, size_t n_threads, Sched sched, Func fn) { +#if defined(_MSC_VER) + // msvc doesn't support unsigned integer as openmp index. + using OmpInd = std::conditional_t::value, Index, omp_ulong>; +#else + using OmpInd = Index; +#endif + OmpInd length = static_cast(size); + dmlc::OMPException exc; -#pragma omp parallel for num_threads(nthreads) schedule(static) - for (Index i = 0; i < size; ++i) { - exc.Run(fn, i); + switch (sched.sched) { + case Sched::kAuto: { +#pragma omp parallel for num_threads(n_threads) + for (OmpInd i = 0; i < length; ++i) { + exc.Run(fn, i); + } + break; + } + case Sched::kDynamic: { + if (sched.chunk == 0) { +#pragma omp parallel for num_threads(n_threads) schedule(dynamic) + for (OmpInd i = 0; i < length; ++i) { + exc.Run(fn, i); + } + } else { +#pragma omp parallel for num_threads(n_threads) schedule(dynamic, sched.chunk) + for (OmpInd i = 0; i < length; ++i) { + exc.Run(fn, i); + } + } + break; + } + case Sched::kStatic: { + if (sched.chunk == 0) { +#pragma omp parallel for num_threads(n_threads) schedule(static) + for (OmpInd i = 0; i < length; ++i) { + exc.Run(fn, i); + } + } else { +#pragma omp parallel for num_threads(n_threads) schedule(static, sched.chunk) + for (OmpInd i = 0; i < length; ++i) { + exc.Run(fn, i); + } + } + break; + } + case Sched::kGuided: { +#pragma omp parallel for num_threads(n_threads) schedule(guided) + for (OmpInd i = 0; i < length; ++i) { + exc.Run(fn, i); + } + break; + } } exc.Rethrow(); } +template +void ParallelFor(Index size, size_t n_threads, Func fn) { + ParallelFor(size, n_threads, Sched::Static(), fn); +} + +// FIXME(jiamingy): Remove this function to get rid of `omp_set_num_threads`, which sets a +// global variable in runtime and affects other programs in the same process. template void ParallelFor(Index size, Func fn) { - ParallelFor(size, omp_get_max_threads(), fn); + ParallelFor(size, omp_get_max_threads(), Sched::Static(), fn); } /* \brief Configure parallel threads. @@ -174,6 +248,12 @@ inline int32_t OmpSetNumThreadsWithoutHT(int32_t* p_threads) { return nthread_original; } +inline int32_t OmpGetNumThreads(int32_t n_threads) { + if (n_threads <= 0) { + n_threads = omp_get_num_procs(); + } + return n_threads; +} } // namespace common } // namespace xgboost diff --git a/src/data/data.cc b/src/data/data.cc index 536a836ecb1a..3829e642639a 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -835,7 +835,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { builder.InitBudget(num_columns, nthread); long batch_size = static_cast(this->Size()); // NOLINT(*) auto page = this->GetView(); - common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) + common::ParallelFor(batch_size, nthread, [&](long i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = page[i]; for (const auto& entry : inst) { @@ -843,7 +843,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { } }); builder.InitStorage(); - common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) + common::ParallelFor(batch_size, nthread, [&](long i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = page[i]; for (const auto& entry : inst) { diff --git a/src/learner.cc b/src/learner.cc index 5ffafa782caa..aca03fb9294a 100644 --- a/src/learner.cc +++ b/src/learner.cc @@ -238,6 +238,10 @@ void GenericParameter::ConfigureGpuId(bool require_gpu) { #endif // defined(XGBOOST_USE_CUDA) } +int32_t GenericParameter::Threads() const { + return common::OmpGetNumThreads(nthread); +} + using LearnerAPIThreadLocalStore = dmlc::ThreadLocalStore>; diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 193938c0f8e6..1f5f9f506474 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -112,7 +112,7 @@ struct EvalAMS : public Metric { PredIndPairContainer rec(ndata); const auto &h_preds = preds.ConstHostVector(); - common::ParallelFor(ndata, [&](bst_omp_uint i) { + common::ParallelFor(ndata, this->tparam_->Threads(), [&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); }); XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 93bd0189e26b..b5b6fe94963b 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -105,10 +105,8 @@ class ColMaker: public TreeUpdater { interaction_constraints_.Configure(param_, dmat->Info().num_row_); // build tree for (auto tree : trees) { - Builder builder( - param_, - colmaker_param_, - interaction_constraints_, column_densities_); + Builder builder(param_, colmaker_param_, interaction_constraints_, + column_densities_, this->tparam_->Threads()); builder.Update(gpair->ConstHostVector(), dmat, tree); } param_.learning_rate = lr; @@ -153,9 +151,10 @@ class ColMaker: public TreeUpdater { explicit Builder(const TrainParam& param, const ColMakerTrainParam& colmaker_train_param, FeatureInteractionConstraintHost _interaction_constraints, - const std::vector &column_densities) + const std::vector &column_densities, + int32_t n_threads) : param_(param), colmaker_train_param_{colmaker_train_param}, - nthread_(omp_get_max_threads()), + nthread_(n_threads), tree_evaluator_(param_, column_densities.size(), GenericParameter::kCpuId), interaction_constraints_{std::move(_interaction_constraints)}, column_densities_(column_densities) {} @@ -525,7 +524,7 @@ class ColMaker: public TreeUpdater { // so that they are ignored in future statistics collection const auto ndata = static_cast(p_fmat->Info().num_row_); - common::ParallelFor(ndata, [&](bst_omp_uint ridx) { + common::ParallelFor(ndata, nthread_, [&](bst_omp_uint ridx) { CHECK_LT(ridx, position_.size()) << "ridx exceed bound " << "ridx="<< ridx << " pos=" << position_.size(); const int nid = this->DecodePosition(ridx); @@ -571,7 +570,7 @@ class ColMaker: public TreeUpdater { for (auto fid : fsplits) { auto col = page[fid]; const auto ndata = static_cast(col.size()); - common::ParallelFor(ndata, [&](bst_omp_uint j) { + common::ParallelFor(ndata, this->nthread_, [&](bst_omp_uint j) { const bst_uint ridx = col[j].index; const int nid = this->DecodePosition(ridx); const bst_float fvalue = col[j].fvalue; diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 5a467fc316a4..054619fea539 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -226,6 +226,26 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) { } } +TEST(HistUtil, QuantileWithHessian) { + int bin_sizes[] = {2, 16, 256, 512}; + int sizes[] = {1000, 1500}; + int num_columns = 5; + for (auto num_rows : sizes) { + auto x = GenerateRandom(num_rows, num_columns); + auto dmat = GetDMatrixFromData(x, num_rows, num_columns); + auto w = GenerateRandomWeights(num_rows); + auto h = GenerateRandomWeights(num_rows); + dmat->Info().weights_.HostVector() = w; + for (auto num_bins : bin_sizes) { + HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, h); + ValidateCuts(cuts_hess, dmat.get(), num_bins); + + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); + ASSERT_NE(cuts.Values(), cuts_hess.Values()); + } + } +} + TEST(HistUtil, DenseCutsExternalMemory) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100, 1000, 1500}; diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc index fa748de1cc6c..fff74d8e06db 100644 --- a/tests/cpp/common/test_quantile.cc +++ b/tests/cpp/common/test_quantile.cc @@ -43,7 +43,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) { // Generate cuts for distributed environment. auto sparsity = 0.5f; auto rank = rabit::GetRank(); - HostSketchContainer sketch_distributed(column_size, n_bins, false); + HostSketchContainer sketch_distributed(column_size, n_bins, false, OmpGetNumThreads(0)); auto m = RandomDataGenerator{rows, cols, sparsity} .Seed(rank) .Lower(.0f) @@ -59,7 +59,7 @@ void TestDistributedQuantile(size_t rows, size_t cols) { rabit::Finalize(); CHECK_EQ(rabit::GetWorldSize(), 1); std::for_each(column_size.begin(), column_size.end(), [=](auto& size) { size *= world; }); - HostSketchContainer sketch_on_single_node(column_size, n_bins, false); + HostSketchContainer sketch_on_single_node(column_size, n_bins, false, OmpGetNumThreads(0)); for (auto rank = 0; rank < world; ++rank) { auto m = RandomDataGenerator{rows, cols, sparsity} .Seed(rank) From 462445ca69f4f661b7f5d1a0fb93402281c89ca6 Mon Sep 17 00:00:00 2001 From: fis Date: Tue, 6 Jul 2021 22:21:23 +0800 Subject: [PATCH 2/6] Remove unnecessary changes. --- src/metric/rank_metric.cc | 2 +- src/tree/updater_colmaker.cc | 15 ++++++++------- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/metric/rank_metric.cc b/src/metric/rank_metric.cc index 1f5f9f506474..193938c0f8e6 100644 --- a/src/metric/rank_metric.cc +++ b/src/metric/rank_metric.cc @@ -112,7 +112,7 @@ struct EvalAMS : public Metric { PredIndPairContainer rec(ndata); const auto &h_preds = preds.ConstHostVector(); - common::ParallelFor(ndata, this->tparam_->Threads(), [&](bst_omp_uint i) { + common::ParallelFor(ndata, [&](bst_omp_uint i) { rec[i] = std::make_pair(h_preds[i], i); }); XGBOOST_PARALLEL_SORT(rec.begin(), rec.end(), common::CmpFirst); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index b5b6fe94963b..93bd0189e26b 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -105,8 +105,10 @@ class ColMaker: public TreeUpdater { interaction_constraints_.Configure(param_, dmat->Info().num_row_); // build tree for (auto tree : trees) { - Builder builder(param_, colmaker_param_, interaction_constraints_, - column_densities_, this->tparam_->Threads()); + Builder builder( + param_, + colmaker_param_, + interaction_constraints_, column_densities_); builder.Update(gpair->ConstHostVector(), dmat, tree); } param_.learning_rate = lr; @@ -151,10 +153,9 @@ class ColMaker: public TreeUpdater { explicit Builder(const TrainParam& param, const ColMakerTrainParam& colmaker_train_param, FeatureInteractionConstraintHost _interaction_constraints, - const std::vector &column_densities, - int32_t n_threads) + const std::vector &column_densities) : param_(param), colmaker_train_param_{colmaker_train_param}, - nthread_(n_threads), + nthread_(omp_get_max_threads()), tree_evaluator_(param_, column_densities.size(), GenericParameter::kCpuId), interaction_constraints_{std::move(_interaction_constraints)}, column_densities_(column_densities) {} @@ -524,7 +525,7 @@ class ColMaker: public TreeUpdater { // so that they are ignored in future statistics collection const auto ndata = static_cast(p_fmat->Info().num_row_); - common::ParallelFor(ndata, nthread_, [&](bst_omp_uint ridx) { + common::ParallelFor(ndata, [&](bst_omp_uint ridx) { CHECK_LT(ridx, position_.size()) << "ridx exceed bound " << "ridx="<< ridx << " pos=" << position_.size(); const int nid = this->DecodePosition(ridx); @@ -570,7 +571,7 @@ class ColMaker: public TreeUpdater { for (auto fid : fsplits) { auto col = page[fid]; const auto ndata = static_cast(col.size()); - common::ParallelFor(ndata, this->nthread_, [&](bst_omp_uint j) { + common::ParallelFor(ndata, [&](bst_omp_uint j) { const bst_uint ridx = col[j].index; const int nid = this->DecodePosition(ridx); const bst_float fvalue = col[j].fvalue; From f9035bea2719aa95f2c966c810662bc990877efd Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 7 Jul 2021 14:53:22 +0800 Subject: [PATCH 3/6] Compare the results. --- src/common/quantile.cc | 2 +- tests/cpp/common/test_hist_util.cc | 20 ++++++++++++++++++-- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 56a70bae16dc..fcbd76e52b51 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -298,7 +298,7 @@ void HostSketchContainer::AllReduce( &global_sketches); std::vector final_sketches(n_columns); - ParallelFor(omp_ulong(n_columns), n_threads_, [&](omp_ulong fidx) { + ParallelFor(n_columns, n_threads_, [&](auto fidx) { int32_t intermediate_num_cuts = num_cuts[fidx]; auto nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts); diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 054619fea539..8ff96e065cf9 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -234,14 +234,30 @@ TEST(HistUtil, QuantileWithHessian) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto w = GenerateRandomWeights(num_rows); - auto h = GenerateRandomWeights(num_rows); + auto hessian = GenerateRandomWeights(num_rows); dmat->Info().weights_.HostVector() = w; + for (auto num_bins : bin_sizes) { - HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, h); + HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, hessian); ValidateCuts(cuts_hess, dmat.get(), num_bins); HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); + ValidateCuts(cuts, dmat.get(), num_bins); + ASSERT_NE(cuts.Values(), cuts_hess.Values()); + + for (size_t i = 0; i < w.size(); ++i) { + dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i]; + } + + HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins); + ValidateCuts(cuts_wh, dmat.get(), num_bins); + + for (size_t i = 0; i < cuts.Values().size(); ++i) { + ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps); + } + + dmat->Info().weights_.HostVector() = w; } } } From 743ee50ed38b4cc80907808c305b3577d2f67958 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 7 Jul 2021 15:03:58 +0800 Subject: [PATCH 4/6] Check size. --- tests/cpp/common/test_hist_util.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 8ff96e065cf9..b1ecf8070fb4 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -253,7 +253,8 @@ TEST(HistUtil, QuantileWithHessian) { HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins); ValidateCuts(cuts_wh, dmat.get(), num_bins); - for (size_t i = 0; i < cuts.Values().size(); ++i) { + ASSERT_EQ(cuts_hess.Values().size(), cuts_wh.Values().size()); + for (size_t i = 0; i < cuts_hess.Values().size(); ++i) { ASSERT_NEAR(cuts_wh.Values()[i], cuts_hess.Values()[i], kRtEps); } From 0ce38f5b83b4d338b36ea3482c9b575ea160cbd8 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 7 Jul 2021 15:04:29 +0800 Subject: [PATCH 5/6] Remove unnecessary changes. --- src/data/data.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/data/data.cc b/src/data/data.cc index 3829e642639a..536a836ecb1a 100644 --- a/src/data/data.cc +++ b/src/data/data.cc @@ -835,7 +835,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { builder.InitBudget(num_columns, nthread); long batch_size = static_cast(this->Size()); // NOLINT(*) auto page = this->GetView(); - common::ParallelFor(batch_size, nthread, [&](long i) { // NOLINT(*) + common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = page[i]; for (const auto& entry : inst) { @@ -843,7 +843,7 @@ SparsePage SparsePage::GetTranspose(int num_columns) const { } }); builder.InitStorage(); - common::ParallelFor(batch_size, nthread, [&](long i) { // NOLINT(*) + common::ParallelFor(batch_size, [&](long i) { // NOLINT(*) int tid = omp_get_thread_num(); auto inst = page[i]; for (const auto& entry : inst) { From 750c2aa61359398cc7308e724748dfc6d8641476 Mon Sep 17 00:00:00 2001 From: fis Date: Wed, 7 Jul 2021 15:30:28 +0800 Subject: [PATCH 6/6] Shuffle the hessian in test. --- tests/cpp/common/test_hist_util.cc | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index b1ecf8070fb4..3abf0d45dc81 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -235,20 +235,16 @@ TEST(HistUtil, QuantileWithHessian) { auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto w = GenerateRandomWeights(num_rows); auto hessian = GenerateRandomWeights(num_rows); + std::mt19937 rng(0); + std::shuffle(hessian.begin(), hessian.end(), rng); dmat->Info().weights_.HostVector() = w; for (auto num_bins : bin_sizes) { HistogramCuts cuts_hess = SketchOnDMatrix(dmat.get(), num_bins, hessian); - ValidateCuts(cuts_hess, dmat.get(), num_bins); - - HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); - ValidateCuts(cuts, dmat.get(), num_bins); - - ASSERT_NE(cuts.Values(), cuts_hess.Values()); - for (size_t i = 0; i < w.size(); ++i) { dmat->Info().weights_.HostVector()[i] = w[i] * hessian[i]; } + ValidateCuts(cuts_hess, dmat.get(), num_bins); HistogramCuts cuts_wh = SketchOnDMatrix(dmat.get(), num_bins); ValidateCuts(cuts_wh, dmat.get(), num_bins);