diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index b2808e377922..c084be85a89b 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -38,7 +38,7 @@ class PredictionContainer;
  */
 class GradientBooster : public Model, public Configurable {
  protected:
-  GenericParameter const* generic_param_;
+  GenericParameter const* ctx_;
 
  public:
   /*! \brief virtual destructor */
diff --git a/include/xgboost/linear_updater.h b/include/xgboost/linear_updater.h
index 39a0c324a958..1506093ee025 100644
--- a/include/xgboost/linear_updater.h
+++ b/include/xgboost/linear_updater.h
@@ -29,7 +29,7 @@ class GBLinearModel;
  */
 class LinearUpdater : public Configurable {
  protected:
-  GenericParameter const* learner_param_;
+  GenericParameter const* ctx_;
 
  public:
   /*! \brief virtual destructor */
diff --git a/src/gbm/gblinear.cc b/src/gbm/gblinear.cc
index e5f5916211cc..31f0d632fe26 100644
--- a/src/gbm/gblinear.cc
+++ b/src/gbm/gblinear.cc
@@ -86,7 +86,7 @@ class GBLinear : public GradientBooster {
     }
     param_.UpdateAllowUnknown(cfg);
     param_.CheckGPUSupport();
-    updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
+    updater_.reset(LinearUpdater::Create(param_.updater, ctx_));
     updater_->Configure(cfg);
     monitor_.Init("GBLinear");
   }
@@ -120,7 +120,7 @@ class GBLinear : public GradientBooster {
     CHECK_EQ(get<String>(in["name"]), "gblinear");
     FromJson(in["gblinear_train_param"], &param_);
     param_.CheckGPUSupport();
-    updater_.reset(LinearUpdater::Create(param_.updater, generic_param_));
+    updater_.reset(LinearUpdater::Create(param_.updater, ctx_));
     this->updater_->LoadConfig(in["updater"]);
   }
   void SaveConfig(Json* p_out) const override {
diff --git a/src/gbm/gbm.cc b/src/gbm/gbm.cc
index 87a6ded29042..87f0bc5b89b5 100644
--- a/src/gbm/gbm.cc
+++ b/src/gbm/gbm.cc
@@ -26,7 +26,7 @@ GradientBooster* GradientBooster::Create(
     LOG(FATAL) << "Unknown gbm type " << name;
   }
   auto p_bst = (e->body)(learner_model_param);
-  p_bst->generic_param_ = generic_param;
+  p_bst->ctx_ = generic_param;
   return p_bst;
 }
 }  // namespace xgboost
diff --git a/src/gbm/gbtree.cc b/src/gbm/gbtree.cc
index 569bf9991cbf..ff736b8ba260 100644
--- a/src/gbm/gbtree.cc
+++ b/src/gbm/gbtree.cc
@@ -49,14 +49,14 @@ void GBTree::Configure(const Args& cfg) {
   // configure predictors
   if (!cpu_predictor_) {
     cpu_predictor_ = std::unique_ptr<Predictor>(
-        Predictor::Create("cpu_predictor", this->generic_param_));
+        Predictor::Create("cpu_predictor", this->ctx_));
   }
   cpu_predictor_->Configure(cfg);
 #if defined(XGBOOST_USE_CUDA)
   auto n_gpus = common::AllVisibleGPUs();
   if (!gpu_predictor_ && n_gpus != 0) {
     gpu_predictor_ = std::unique_ptr<Predictor>(
-        Predictor::Create("gpu_predictor", this->generic_param_));
+        Predictor::Create("gpu_predictor", this->ctx_));
   }
   if (n_gpus != 0) {
     gpu_predictor_->Configure(cfg);
@@ -201,16 +201,16 @@ void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
 }
 #endif
-void CopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
+void CopyGradient(HostDeviceVector<GradientPair> const* in_gpair, int32_t n_threads,
                   bst_group_t n_groups, bst_group_t group_id,
-                  HostDeviceVector<GradientPair> *out_gpair) {
+                  HostDeviceVector<GradientPair>* out_gpair) {
   if (in_gpair->DeviceIdx() != GenericParameter::kCpuId) {
     GPUCopyGradient(in_gpair, n_groups, group_id, out_gpair);
   } else {
     std::vector<GradientPair> &tmp_h = out_gpair->HostVector();
     auto nsize = static_cast(out_gpair->Size());
     const auto &gpair_h = in_gpair->ConstHostVector();
-    common::ParallelFor(nsize, [&](bst_omp_uint i) {
+    common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
       tmp_h[i] = gpair_h[i * n_groups + group_id];
     });
   }
 }
@@ -228,7 +228,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
   // break a lots of existing code.
   auto device = tparam_.tree_method != TreeMethod::kGPUHist
                     ? GenericParameter::kCpuId
-                    : generic_param_->gpu_id;
+                    : ctx_->gpu_id;
   auto out = linalg::TensorView{
       device == GenericParameter::kCpuId ? predt->predictions.HostSpan()
                                          : predt->predictions.DeviceSpan(),
@@ -255,7 +255,7 @@ void GBTree::DoBoost(DMatrix* p_fmat,
                                      in_gpair->DeviceIdx());
     bool update_predict = true;
     for (int gid = 0; gid < ngroup; ++gid) {
-      CopyGradient(in_gpair, ngroup, gid, &tmp);
+      CopyGradient(in_gpair, ctx_->Threads(), ngroup, gid, &tmp);
       std::vector<std::unique_ptr<RegTree> > ret;
       BoostNewTrees(&tmp, p_fmat, gid, &ret);
       const size_t num_new_trees = ret.size();
@@ -310,7 +310,7 @@ void GBTree::InitUpdater(Args const& cfg) {
   // create new updaters
   for (const std::string& pstr : ups) {
     std::unique_ptr<TreeUpdater> up(
-        TreeUpdater::Create(pstr.c_str(), generic_param_, model_.learner_model_param->task));
+        TreeUpdater::Create(pstr.c_str(), ctx_, model_.learner_model_param->task));
     up->Configure(cfg);
     updaters_.push_back(std::move(up));
   }
 }
@@ -396,7 +396,7 @@ void GBTree::LoadConfig(Json const& in) {
   updaters_.clear();
   for (auto const& kv : j_updaters) {
     std::unique_ptr<TreeUpdater> up(
-        TreeUpdater::Create(kv.first, generic_param_, model_.learner_model_param->task));
+        TreeUpdater::Create(kv.first, ctx_, model_.learner_model_param->task));
     up->LoadConfig(kv.second);
     updaters_.push_back(std::move(up));
   }
 }
@@ -562,7 +562,7 @@ GBTree::GetPredictor(HostDeviceVector const *out_pred,
   auto on_device = is_ellpack || is_from_device;
 
   // Use GPU Predictor if data is already on device and gpu_id is set.
-  if (on_device && generic_param_->gpu_id >= 0) {
+  if (on_device && ctx_->gpu_id >= 0) {
 #if defined(XGBOOST_USE_CUDA)
     CHECK_GE(common::AllVisibleGPUs(), 1) << "No visible GPU is found for XGBoost.";
     CHECK(gpu_predictor_);
@@ -728,8 +728,8 @@ class Dart : public GBTree {
     auto n_groups = model_.learner_model_param->num_output_group;
 
     PredictionCacheEntry predts;  // temporary storage for prediction
-    if (generic_param_->gpu_id != GenericParameter::kCpuId) {
-      predts.predictions.SetDevice(generic_param_->gpu_id);
+    if (ctx_->gpu_id != GenericParameter::kCpuId) {
+      predts.predictions.SetDevice(ctx_->gpu_id);
     }
     predts.predictions.Resize(p_fmat->Info().num_row_ * n_groups, 0);
 
@@ -758,11 +758,10 @@ class Dart : public GBTree {
       } else {
         auto &h_out_predts = p_out_preds->predictions.HostVector();
         auto &h_predts = predts.predictions.HostVector();
-#pragma omp parallel for
-        for (omp_ulong ridx = 0; ridx < p_fmat->Info().num_row_; ++ridx) {
+        common::ParallelFor(p_fmat->Info().num_row_, ctx_->Threads(), [&](auto ridx) {
           const size_t offset = ridx * n_groups + group;
           h_out_predts[offset] += (h_predts[offset] * w);
-        }
+        });
       }
     }
   }
@@ -846,13 +845,11 @@ class Dart : public GBTree {
     if (device == GenericParameter::kCpuId) {
       auto &h_predts = predts.predictions.HostVector();
       auto &h_out_predts = out_preds->predictions.HostVector();
-#pragma omp parallel for
-      for (omp_ulong ridx = 0; ridx < n_rows; ++ridx) {
+      common::ParallelFor(n_rows, ctx_->Threads(), [&](auto ridx) {
         const size_t offset = ridx * n_groups + group;
         // Need to remove the base margin from individual tree.
-        h_out_predts[offset] +=
-            (h_predts[offset] - model_.learner_model_param->base_score) * w;
-      }
+        h_out_predts[offset] += (h_predts[offset] - model_.learner_model_param->base_score) * w;
+      });
     } else {
       out_preds->predictions.SetDevice(device);
       predts.predictions.SetDevice(device);
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index 4e508bbaec2f..a357889d7d9b 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -413,10 +413,9 @@ class GBTree : public GradientBooster {
                                            p_fmat, out_contribs, model_, tree_end, nullptr, approximate);
   }
 
-  std::vector<std::string> DumpModel(const FeatureMap& fmap,
-                                     bool with_stats,
+  std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats,
                                      std::string format) const override {
-    return model_.DumpModel(fmap, with_stats, format);
+    return model_.DumpModel(fmap, with_stats, this->ctx_->Threads(), format);
   }
 
  protected:
diff --git a/src/gbm/gbtree_model.h b/src/gbm/gbtree_model.h
index c5e05c0157af..6c13d8644a85 100644
--- a/src/gbm/gbtree_model.h
+++ b/src/gbm/gbtree_model.h
@@ -109,12 +109,11 @@ struct GBTreeModel : public Model {
   void SaveModel(Json* p_out) const override;
   void LoadModel(Json const& p_out) override;
 
-  std::vector<std::string> DumpModel(const FeatureMap &fmap, bool with_stats,
+  std::vector<std::string> DumpModel(const FeatureMap& fmap, bool with_stats, int32_t n_threads,
                                      std::string format) const {
     std::vector<std::string> dump(trees.size());
-    common::ParallelFor(static_cast(trees.size()), [&](size_t i) {
-      dump[i] = trees[i]->DumpModel(fmap, with_stats, format);
-    });
+    common::ParallelFor(trees.size(), n_threads,
+                        [&](size_t i) { dump[i] = trees[i]->DumpModel(fmap, with_stats, format); });
     return dump;
   }
   void CommitModel(std::vector<std::unique_ptr<RegTree> >&& new_trees,
diff --git a/src/linear/coordinate_common.h b/src/linear/coordinate_common.h
index d01bce826c2e..1f7c81d11975 100644
--- a/src/linear/coordinate_common.h
+++ b/src/linear/coordinate_common.h
@@ -149,21 +149,21 @@ GetGradientParallel(GenericParameter const *ctx, int group_idx, int num_group,
  */
 inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_group,
                                                          const std::vector<GradientPair> &gpair,
-                                                         DMatrix *p_fmat) {
-  double sum_grad = 0.0, sum_hess = 0.0;
+                                                         DMatrix *p_fmat, int32_t n_threads) {
   const auto ndata = static_cast(p_fmat->Info().num_row_);
-  dmlc::OMPException exc;
-#pragma omp parallel for schedule(static) reduction(+ : sum_grad, sum_hess)
-  for (bst_omp_uint i = 0; i < ndata; ++i) {
-    exc.Run([&]() {
-      auto &p = gpair[i * num_group + group_idx];
-      if (p.GetHess() >= 0.0f) {
-        sum_grad += p.GetGrad();
-        sum_hess += p.GetHess();
-      }
-    });
-  }
-  exc.Rethrow();
+  std::vector<double> sum_grad_tloc(n_threads, 0);
+  std::vector<double> sum_hess_tloc(n_threads, 0);
+
+  common::ParallelFor(ndata, n_threads, [&](auto i) {
+    auto tid = omp_get_thread_num();
+    auto &p = gpair[i * num_group + group_idx];
+    if (p.GetHess() >= 0.0f) {
+      sum_grad_tloc[tid] += p.GetGrad();
+      sum_hess_tloc[tid] += p.GetHess();
+    }
+  });
+  double sum_grad = std::accumulate(sum_grad_tloc.cbegin(), sum_grad_tloc.cend(), 0.0);
+  double sum_hess = std::accumulate(sum_hess_tloc.cbegin(), sum_hess_tloc.cend(), 0.0);
   return std::make_pair(sum_grad, sum_hess);
 }
 
@@ -179,23 +179,18 @@ inline std::pair<double, double> GetBiasGradientParallel(int group_idx, int num_
  */
 inline void UpdateResidualParallel(int fidx, int group_idx, int num_group, float dw,
                                    std::vector<GradientPair> *in_gpair,
-                                   DMatrix *p_fmat) {
+                                   DMatrix *p_fmat, int32_t n_threads) {
   if (dw == 0.0f) return;
   for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
     auto page = batch.GetView();
     auto col = page[fidx];
     // update grad value
     const auto num_row = static_cast(col.size());
-    dmlc::OMPException exc;
-#pragma omp parallel for schedule(static)
-    for (bst_omp_uint j = 0; j < num_row; ++j) {
-      exc.Run([&]() {
-        GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
-        if (p.GetHess() < 0.0f) return;
-        p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
-      });
-    }
-    exc.Rethrow();
+    common::ParallelFor(num_row, n_threads, [&](auto j) {
+      GradientPair &p = (*in_gpair)[col[j].index * num_group + group_idx];
+      if (p.GetHess() < 0.0f) return;
+      p += GradientPair(p.GetHess() * col[j].fvalue * dw, 0);
+    });
   }
 }
 
@@ -209,20 +204,15 @@ inline void UpdateResidualParallel(int fidx, int group_idx, int num_group,
  * \param p_fmat The input feature matrix.
  */
 inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias,
-                                       std::vector<GradientPair> *in_gpair,
-                                       DMatrix *p_fmat) {
+                                       std::vector<GradientPair> *in_gpair, DMatrix *p_fmat,
+                                       int32_t n_threads) {
   if (dbias == 0.0f) return;
   const auto ndata = static_cast(p_fmat->Info().num_row_);
-  dmlc::OMPException exc;
-#pragma omp parallel for schedule(static)
-  for (bst_omp_uint i = 0; i < ndata; ++i) {
-    exc.Run([&]() {
-      GradientPair &g = (*in_gpair)[i * num_group + group_idx];
-      if (g.GetHess() < 0.0f) return;
-      g += GradientPair(g.GetHess() * dbias, 0);
-    });
-  }
-  exc.Rethrow();
+  common::ParallelFor(ndata, n_threads, [&](auto i) {
+    GradientPair &g = (*in_gpair)[i * num_group + group_idx];
+    if (g.GetHess() < 0.0f) return;
+    g += GradientPair(g.GetHess() * dbias, 0);
+  });
 }
 
 /**
@@ -230,9 +220,13 @@ inline void UpdateBiasResidualParallel(int group_idx, int num_group, float dbias
  * in coordinate descent algorithms.
  */
 class FeatureSelector {
+ protected:
+  int32_t n_threads_{-1};
+
  public:
+  explicit FeatureSelector(int32_t n_threads) : n_threads_{n_threads} {}
   /*! \brief factory method */
-  static FeatureSelector *Create(int choice);
+  static FeatureSelector *Create(int choice, int32_t n_threads);
   /*! \brief virtual destructor */
   virtual ~FeatureSelector() = default;
   /**
@@ -274,6 +268,7 @@ class FeatureSelector {
  */
 class CyclicFeatureSelector : public FeatureSelector {
  public:
+  using FeatureSelector::FeatureSelector;
   int NextFeature(int iteration, const gbm::GBLinearModel &model,
                   int , const std::vector<GradientPair> &, DMatrix *,
                   float, float) override {
@@ -287,6 +282,7 @@ class CyclicFeatureSelector : public FeatureSelector {
  */
 class ShuffleFeatureSelector : public FeatureSelector {
  public:
+  using FeatureSelector::FeatureSelector;
   void Setup(const gbm::GBLinearModel &model,
              const std::vector<GradientPair>&,
              DMatrix *, float, float, int) override {
@@ -313,6 +309,7 @@ class ShuffleFeatureSelector : public FeatureSelector {
  */
 class RandomFeatureSelector : public FeatureSelector {
  public:
+  using FeatureSelector::FeatureSelector;
   int NextFeature(int, const gbm::GBLinearModel &model, int,
                   const std::vector<GradientPair> &, DMatrix *,
                   float, float) override {
@@ -331,6 +328,7 @@ class RandomFeatureSelector : public FeatureSelector {
  */
 class GreedyFeatureSelector : public FeatureSelector {
  public:
+  using FeatureSelector::FeatureSelector;
   void Setup(const gbm::GBLinearModel &model,
              const std::vector<GradientPair> &,
              DMatrix *, float, float, int param) override {
@@ -360,7 +358,7 @@ class GreedyFeatureSelector : public FeatureSelector {
     std::fill(gpair_sums_.begin(), gpair_sums_.end(), std::make_pair(0., 0.));
     for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
       auto page = batch.GetView();
-      common::ParallelFor(nfeat, [&](bst_omp_uint i) {
+      common::ParallelFor(nfeat, this->n_threads_, [&](bst_omp_uint i) {
         const auto col = page[i];
         const bst_uint ndata = col.size();
         auto &sums = gpair_sums_[group_idx * nfeat + i];
@@ -407,6 +405,7 @@ class GreedyFeatureSelector : public FeatureSelector {
  */
 class ThriftyFeatureSelector : public FeatureSelector {
  public:
+  using FeatureSelector::FeatureSelector;
   void Setup(const gbm::GBLinearModel &model,
              const std::vector<GradientPair> &gpair,
              DMatrix *p_fmat, float alpha, float lambda, int param) override {
@@ -426,7 +425,7 @@ class ThriftyFeatureSelector : public FeatureSelector {
     for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
       auto page = batch.GetView();
       // column-parallel is usually fastaer than row-parallel
-      common::ParallelFor(nfeat, [&](bst_omp_uint i) {
+      common::ParallelFor(nfeat, this->n_threads_, [&](auto i) {
         const auto col = page[i];
         const bst_uint ndata = col.size();
         for (bst_uint gid = 0u; gid < ngroup; ++gid) {
@@ -483,18 +482,18 @@ class ThriftyFeatureSelector : public FeatureSelector {
   std::vector<std::pair<double, double>> gpair_sums_;
 };
 
-inline FeatureSelector *FeatureSelector::Create(int choice) {
+inline FeatureSelector *FeatureSelector::Create(int choice, int32_t n_threads) {
   switch (choice) {
     case kCyclic:
-      return new CyclicFeatureSelector();
+      return new CyclicFeatureSelector(n_threads);
     case kShuffle:
-      return new ShuffleFeatureSelector();
+      return new ShuffleFeatureSelector(n_threads);
     case kThrifty:
-      return new ThriftyFeatureSelector();
+      return new ThriftyFeatureSelector(n_threads);
     case kGreedy:
-      return new GreedyFeatureSelector();
+      return new GreedyFeatureSelector(n_threads);
    case kRandom:
-      return new RandomFeatureSelector();
+      return new RandomFeatureSelector(n_threads);
     default:
       LOG(FATAL) << "unknown coordinate selector: " << choice;
   }
diff --git a/src/linear/linear_updater.cc b/src/linear/linear_updater.cc
index 95c9908e872a..4593d54f0907 100644
--- a/src/linear/linear_updater.cc
+++ b/src/linear/linear_updater.cc
@@ -17,7 +17,7 @@ LinearUpdater* LinearUpdater::Create(const std::string& name, GenericParameter c
     LOG(FATAL) << "Unknown linear updater " << name;
   }
   auto p_linear = (e->body)();
-  p_linear->learner_param_ = lparam;
+  p_linear->ctx_ = lparam;
   return p_linear;
 }
 
diff --git a/src/linear/updater_coordinate.cc b/src/linear/updater_coordinate.cc
index ae070f5d5700..29ba5451b94e 100644
--- a/src/linear/updater_coordinate.cc
+++ b/src/linear/updater_coordinate.cc
@@ -30,7 +30,7 @@ class CoordinateUpdater : public LinearUpdater {
       tparam_.UpdateAllowUnknown(args)
     };
     cparam_.UpdateAllowUnknown(rest);
-    selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
+    selector_.reset(FeatureSelector::Create(tparam_.feature_selector, ctx_->Threads()));
     monitor_.Init("CoordinateUpdater");
   }
 
@@ -51,13 +51,13 @@ class CoordinateUpdater : public LinearUpdater {
     const int ngroup = model->learner_model_param->num_output_group;
     // update bias
     for (int group_idx = 0; group_idx < ngroup; ++group_idx) {
-      auto grad = GetBiasGradientParallel(group_idx, ngroup,
-                                          in_gpair->ConstHostVector(), p_fmat);
+      auto grad = GetBiasGradientParallel(group_idx, ngroup, in_gpair->ConstHostVector(), p_fmat,
+                                          ctx_->Threads());
       auto dbias = static_cast(tparam_.learning_rate *
                                CoordinateDeltaBias(grad.first, grad.second));
       model->Bias()[group_idx] += dbias;
-      UpdateBiasResidualParallel(group_idx, ngroup,
-                                 dbias, &in_gpair->HostVector(), p_fmat);
+      UpdateBiasResidualParallel(group_idx, ngroup, dbias, &in_gpair->HostVector(), p_fmat,
+                                 ctx_->Threads());
     }
     // prepare for updating the weights
     selector_->Setup(*model, in_gpair->ConstHostVector(), p_fmat,
@@ -80,14 +80,15 @@ class CoordinateUpdater : public LinearUpdater {
                      DMatrix *p_fmat, gbm::GBLinearModel *model) {
     const int ngroup = model->learner_model_param->num_output_group;
     bst_float &w = (*model)[fidx][group_idx];
-    auto gradient = GetGradientParallel(learner_param_, group_idx, ngroup, fidx,
+    auto gradient = GetGradientParallel(ctx_, group_idx, ngroup, fidx,
                                         *in_gpair, p_fmat);
     auto dw = static_cast(
         tparam_.learning_rate *
        CoordinateDelta(gradient.first, gradient.second, w, tparam_.reg_alpha_denorm,
                        tparam_.reg_lambda_denorm));
     w += dw;
-    UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat);
+    UpdateResidualParallel(fidx, group_idx, ngroup, dw, in_gpair, p_fmat,
+                           ctx_->Threads());
   }
 
  private:
diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu
index 1634201780c8..4d2a8c5b0259 100644
--- a/src/linear/updater_gpu_coordinate.cu
+++ b/src/linear/updater_gpu_coordinate.cu
@@ -32,7 +32,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
   void Configure(Args const& args) override {
     tparam_.UpdateAllowUnknown(args);
     coord_param_.UpdateAllowUnknown(args);
-    selector_.reset(FeatureSelector::Create(tparam_.feature_selector));
+    selector_.reset(FeatureSelector::Create(tparam_.feature_selector, ctx_->Threads()));
     monitor_.Init("GPUCoordinateUpdater");
   }
 
@@ -48,7 +48,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
   }
 
   void LazyInitDevice(DMatrix *p_fmat, const LearnerModelParam &model_param) {
-    if (learner_param_->gpu_id < 0) return;
+    if (ctx_->gpu_id < 0) return;
 
     num_row_ = static_cast(p_fmat->Info().num_row_);
 
@@ -60,7 +60,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
       return;
     }
 
-    dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     // The begin and end indices for the section of each column associated with
     // this device
     std::vector> column_segments;
@@ -103,7 +103,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
     monitor_.Start("UpdateGpair");
     auto &in_gpair_host = in_gpair->ConstHostVector();
     // Update gpair
-    if (learner_param_->gpu_id >= 0) {
+    if (ctx_->gpu_id >= 0) {
       this->UpdateGpair(in_gpair_host);
     }
     monitor_.Stop("UpdateGpair");
@@ -134,7 +134,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
          ++group_idx) {
       // Get gradient
       auto grad = GradientPair(0, 0);
-      if (learner_param_->gpu_id >= 0) {
+      if (ctx_->gpu_id >= 0) {
        grad = GetBiasGradient(group_idx, model->learner_model_param->num_output_group);
       }
       auto dbias = static_cast(
@@ -143,7 +143,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
       model->Bias()[group_idx] += dbias;
 
       // Update residual
-      if (learner_param_->gpu_id >= 0) {
+      if (ctx_->gpu_id >= 0) {
         UpdateBiasResidual(dbias, group_idx, model->learner_model_param->num_output_group);
       }
     }
@@ -155,7 +155,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
     bst_float &w = (*model)[fidx][group_idx];
     // Get gradient
     auto grad = GradientPair(0, 0);
-    if (learner_param_->gpu_id >= 0) {
+    if (ctx_->gpu_id >= 0) {
       grad = GetGradient(group_idx, model->learner_model_param->num_output_group, fidx);
     }
     auto dw = static_cast(tparam_.learning_rate *
@@ -164,14 +164,14 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
                                          tparam_.reg_lambda_denorm));
     w += dw;
 
-    if (learner_param_->gpu_id >= 0) {
+    if (ctx_->gpu_id >= 0) {
       UpdateResidual(dw, group_idx, model->learner_model_param->num_output_group, fidx);
     }
   }
   // This needs to be public because of the __device__ lambda.
   GradientPair GetBiasGradient(int group_idx, int num_group) {
-    dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     auto counting = thrust::make_counting_iterator(0ull);
     auto f = [=] __device__(size_t idx) {
       return idx * num_group + group_idx;
     };
@@ -195,7 +195,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
 
   // This needs to be public because of the __device__ lambda.
   GradientPair GetGradient(int group_idx, int num_group, int fidx) {
-    dh::safe_cuda(cudaSetDevice(learner_param_->gpu_id));
+    dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
     common::Span d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
     size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
     common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
diff --git a/src/linear/updater_shotgun.cc b/src/linear/updater_shotgun.cc
index 70f4e98d04af..d8592f1cf994 100644
--- a/src/linear/updater_shotgun.cc
+++ b/src/linear/updater_shotgun.cc
@@ -21,7 +21,7 @@ class ShotgunUpdater : public LinearUpdater {
       LOG(FATAL) << "Unsupported feature selector for shotgun updater.\n"
                  << "Supported options are: {cyclic, shuffle}";
     }
-    selector_.reset(FeatureSelector::Create(param_.feature_selector));
+    selector_.reset(FeatureSelector::Create(param_.feature_selector, ctx_->Threads()));
   }
   void LoadConfig(Json const& in) override {
     auto const& config = get(in);
@@ -40,12 +40,13 @@ class ShotgunUpdater : public LinearUpdater {
 
     // update bias
     for (int gid = 0; gid < ngroup; ++gid) {
-      auto grad = GetBiasGradientParallel(gid, ngroup,
-                                          in_gpair->ConstHostVector(), p_fmat);
+      auto grad = GetBiasGradientParallel(gid, ngroup, in_gpair->ConstHostVector(), p_fmat,
+                                          ctx_->Threads());
       auto dbias = static_cast(param_.learning_rate *
                                CoordinateDeltaBias(grad.first, grad.second));
       model->Bias()[gid] += dbias;
-      UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat);
+      UpdateBiasResidualParallel(gid, ngroup, dbias, &in_gpair->HostVector(), p_fmat,
+                                 ctx_->Threads());
     }
 
     // lock-free parallel updates of weights
@@ -54,42 +55,35 @@ class ShotgunUpdater : public LinearUpdater {
     for (const auto &batch : p_fmat->GetBatches<CSCPage>()) {
       auto page = batch.GetView();
       const auto nfeat = static_cast(batch.Size());
-      dmlc::OMPException exc;
-#pragma omp parallel for schedule(static)
-      for (bst_omp_uint i = 0; i < nfeat; ++i) {
-        exc.Run([&]() {
-          int ii = selector_->NextFeature
-              (i, *model, 0, in_gpair->ConstHostVector(), p_fmat, param_.reg_alpha_denorm,
-               param_.reg_lambda_denorm);
-          if (ii < 0) return;
-          const bst_uint fid = ii;
-          auto col = page[ii];
-          for (int gid = 0; gid < ngroup; ++gid) {
-            double sum_grad = 0.0, sum_hess = 0.0;
-            for (auto& c : col) {
-              const GradientPair &p = gpair[c.index * ngroup + gid];
-              if (p.GetHess() < 0.0f) continue;
-              const bst_float v = c.fvalue;
-              sum_grad += p.GetGrad() * v;
-              sum_hess += p.GetHess() * v * v;
-            }
-            bst_float &w = (*model)[fid][gid];
-            auto dw = static_cast(
-                param_.learning_rate *
-                CoordinateDelta(sum_grad, sum_hess, w, param_.reg_alpha_denorm,
-                                param_.reg_lambda_denorm));
-            if (dw == 0.f) continue;
-            w += dw;
-            // update grad values
-            for (auto& c : col) {
-              GradientPair &p = gpair[c.index * ngroup + gid];
-              if (p.GetHess() < 0.0f) continue;
-              p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
-            }
-          }
-        });
-      }
-      exc.Rethrow();
+      common::ParallelFor(nfeat, ctx_->Threads(), [&](auto i) {
+        int ii = selector_->NextFeature(i, *model, 0, in_gpair->ConstHostVector(), p_fmat,
+                                        param_.reg_alpha_denorm, param_.reg_lambda_denorm);
+        if (ii < 0) return;
+        const bst_uint fid = ii;
+        auto col = page[ii];
+        for (int gid = 0; gid < ngroup; ++gid) {
+          double sum_grad = 0.0, sum_hess = 0.0;
+          for (auto &c : col) {
+            const GradientPair &p = gpair[c.index * ngroup + gid];
+            if (p.GetHess() < 0.0f) continue;
+            const bst_float v = c.fvalue;
+            sum_grad += p.GetGrad() * v;
+            sum_hess += p.GetHess() * v * v;
+          }
+          bst_float &w = (*model)[fid][gid];
+          auto dw = static_cast(
+              param_.learning_rate * CoordinateDelta(sum_grad, sum_hess, w, param_.reg_alpha_denorm,
+                                                     param_.reg_lambda_denorm));
+          if (dw == 0.f) continue;
+          w += dw;
+          // update grad values
+          for (auto &c : col) {
+            GradientPair &p = gpair[c.index * ngroup + gid];
+            if (p.GetHess() < 0.0f) continue;
+            p += GradientPair(p.GetHess() * c.fvalue * dw, 0);
+          }
+        }
+      });
     }
   }
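
Note (not part of the patch): the coordinate_common.h hunk above replaces an OpenMP `reduction(+ : sum_grad, sum_hess)` loop with `common::ParallelFor` plus per-thread partial sums combined via `std::accumulate`. The standalone sketch below illustrates only that pattern; the `ParallelFor` helper here is a hypothetical stand-in for xgboost's `common::ParallelFor(n, n_threads, fn)` and assumes an OpenMP toolchain.

// Illustrative sketch of the thread-local accumulation pattern used in the patch.
#include <omp.h>

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <vector>

// Hypothetical stand-in for common::ParallelFor(n, n_threads, fn).
template <typename Index, typename Fn>
void ParallelFor(Index n, int32_t n_threads, Fn&& fn) {
#pragma omp parallel for num_threads(n_threads) schedule(static)
  for (Index i = 0; i < n; ++i) {
    fn(i);
  }
}

int main() {
  std::vector<double> grad(1000, 0.5);
  int32_t n_threads = 4;

  // One partial sum per thread; omp_get_thread_num() selects the slot,
  // so no reduction clause or locking is needed.
  std::vector<double> sum_tloc(n_threads, 0.0);
  ParallelFor(grad.size(), n_threads, [&](std::size_t i) {
    sum_tloc[omp_get_thread_num()] += grad[i];
  });
  // Combine the per-thread partial sums sequentially.
  double sum = std::accumulate(sum_tloc.cbegin(), sum_tloc.cend(), 0.0);
  std::cout << sum << "\n";  // prints 500
  return 0;
}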