From 2b4cf67cd12672c0c8bca04549f0c65400ec6d0d Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 21 Apr 2022 03:10:55 -0700 Subject: [PATCH 01/14] Remove single_precision_histogram --- doc/gpu/index.rst | 4 +- doc/parameter.rst | 2 +- include/xgboost/tree_updater.h | 6 +- src/tree/tree_updater.cc | 3 +- src/tree/updater_approx.cc | 9 +- src/tree/updater_basemaker-inl.h | 7 +- src/tree/updater_colmaker.cc | 7 +- src/tree/updater_gpu_hist.cu | 162 +++++++--------------- src/tree/updater_histmaker.cc | 20 ++- src/tree/updater_prune.cc | 10 +- src/tree/updater_quantile_hist.cc | 4 +- src/tree/updater_quantile_hist.h | 3 +- src/tree/updater_refresh.cc | 13 +- src/tree/updater_sync.cc | 9 +- tests/cpp/tree/test_gpu_hist.cu | 16 +-- tests/python-gpu/test_gpu_basic_models.py | 4 +- tests/python-gpu/test_gpu_updaters.py | 3 +- 17 files changed, 107 insertions(+), 175 deletions(-) diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst index e36fc72a1746..049cf311dff2 100644 --- a/doc/gpu/index.rst +++ b/doc/gpu/index.rst @@ -59,13 +59,11 @@ Supported parameters +--------------------------------+--------------+ | ``interaction_constraints`` | |tick| | +--------------------------------+--------------+ -| ``single_precision_histogram`` | |tick| | +| ``single_precision_histogram`` | |cross| | +--------------------------------+--------------+ GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``. -The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures. - The device ordinal (which GPU to use if you have many of them) can be selected using the ``gpu_id`` parameter, which defaults to 0 (the first device reported by CUDA runtime). diff --git a/doc/parameter.rst b/doc/parameter.rst index 781150490082..4392b5bf7680 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -240,7 +240,7 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method * ``single_precision_histogram``, [default= ``false``] - - Use single precision to build histograms instead of double precision. + - Use single precision to build histograms instead of double precision. Currently disabled for ``gpu_hist``. * ``max_cat_to_onehot`` diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 6189221dc0bf..6248a65e270d 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -35,6 +35,7 @@ class TreeUpdater : public Configurable { GenericParameter const* ctx_ = nullptr; public: + explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {} /*! \brief virtual destructor */ ~TreeUpdater() override = default; /*! @@ -91,8 +92,9 @@ class TreeUpdater : public Configurable { * \brief Registry entry for tree updater. */ struct TreeUpdaterReg - : public dmlc::FunctionRegEntryBase > {}; + : public dmlc::FunctionRegEntryBase< + TreeUpdaterReg, + std::function > {}; /*! * \brief Macro to register tree updater. 
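For orientation, the new registration contract introduced above can be sketched as follows. NoopUpdater is a made-up example, not part of this patch; the point is that the factory lambda registered with XGBOOST_REGISTER_TREE_UPDATER now receives the GenericParameter context together with the ObjInfo task, and the updater forwards the context to the TreeUpdater base constructor instead of having ctx_ assigned after construction:

    // Hypothetical updater, shown only to illustrate the two-argument factory signature.
    class NoopUpdater : public TreeUpdater {
     public:
      explicit NoopUpdater(GenericParameter const* ctx) : TreeUpdater(ctx) {}
      void Configure(const Args&) override {}
      void LoadConfig(Json const&) override {}
      void SaveConfig(Json*) const override {}
      void Update(HostDeviceVector<GradientPair>*, DMatrix*,
                  const std::vector<RegTree*>&) override {}
      char const* Name() const override { return "noop"; }
    };

    XGBOOST_REGISTER_TREE_UPDATER(NoopUpdater, "noop")
        .describe("Does nothing; illustrates the new factory signature.")
        .set_body([](GenericParameter const* ctx, ObjInfo) { return new NoopUpdater(ctx); });

This mirrors the concrete registrations changed later in the patch (for example "sync" and "prune"), where the lambda simply passes ctx (and, where needed, task) through to the updater constructor.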
diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 05f6c4bb5fd6..ee5659636305 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -20,8 +20,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const if (e == nullptr) { LOG(FATAL) << "Unknown tree updater " << name; } - auto p_updater = (e->body)(task); - p_updater->ctx_ = tparam; + auto p_updater = (e->body)(tparam, task); return p_updater; } diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 3bad6f7da4cc..a06f195374b6 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -252,7 +252,10 @@ class GlobalApproxUpdater : public TreeUpdater { ObjInfo task_; public: - explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); } + explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task) + : task_{task}, TreeUpdater(ctx) { + monitor_.Init(__func__); + } void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); @@ -343,6 +346,8 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker") .describe( "Tree constructor that uses approximate histogram construction " "for each node.") - .set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); }); + .set_body([](GenericParameter const *ctx, ObjInfo task) { + return new GlobalApproxUpdater(ctx, task); + }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index da239b2090c7..7fc44a6d15fb 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -33,11 +33,10 @@ namespace tree { * \brief base tree maker class that defines common operation * needed in tree making */ -class BaseMaker: public TreeUpdater { +class BaseMaker : public TreeUpdater { public: - void Configure(const Args& args) override { - param_.UpdateAllowUnknown(args); - } + explicit BaseMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); } void LoadConfig(Json const& in) override { auto const& config = get(in); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index e3d716f2cba8..f4279a0a1c3b 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -57,7 +57,8 @@ DMLC_REGISTER_PARAMETER(ColMakerTrainParam); /*! 
\brief column-wise update to construct a tree */ class ColMaker: public TreeUpdater { public: - void Configure(const Args& args) override { + explicit ColMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); colmaker_param_.UpdateAllowUnknown(args); } @@ -614,8 +615,8 @@ class ColMaker: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker") .describe("Grow tree with parallelization over columns.") -.set_body([](ObjInfo) { - return new ColMaker(); +.set_body([](GenericParameter const* ctx, ObjInfo) { + return new ColMaker(ctx); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index cb7dd9b7e8e4..2cac6b6c4f4a 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -45,12 +45,9 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); // training parameters specific to this algorithm struct GPUHistMakerTrainParam : public XGBoostParameter { - bool single_precision_histogram; bool debug_synchronize; // declare parameters DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { - DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( - "Use single precision to build histograms."); DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe( "Check if all distributed tree are identical after tree construction."); } @@ -532,6 +529,13 @@ struct GPUHistMakerDevice { void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) { RegTree& tree = *p_tree; + + // Sanity check - have we created a leaf with no training instances? + if (!rabit::IsDistributed()) { + CHECK(row_partitioner->GetRows(candidate.nid).size() > 0) + << "No training instances in this leaf!"; + } + auto parent_sum = candidate.split.left_sum + candidate.split.right_sum; auto base_weight = candidate.base_weight; auto left_weight = candidate.left_weight * param.learning_rate; @@ -676,20 +680,35 @@ struct GPUHistMakerDevice { } }; -template -class GPUHistMakerSpecialised { +class GPUHistMaker : public TreeUpdater { + using GradientSumT = GradientPairPrecise; + public: - explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {}; - void Configure(const Args& args, GenericParameter const* generic_param) { + explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task) + : TreeUpdater(ctx), task_{task} {}; + void Configure(const Args& args) { + // Used in test to count how many configurations are performed + LOG(DEBUG) << "[GPU Hist]: Configure"; param_.UpdateAllowUnknown(args); - generic_param_ = generic_param; hist_maker_param_.UpdateAllowUnknown(args); dh::CheckComputeCapability(); monitor_.Init("updater_gpu_hist"); } - ~GPUHistMakerSpecialised() { // NOLINT + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_); + initialised_ = false; + FromJson(config.at("train_param"), ¶m_); + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["gpu_hist_train_param"] = ToJson(hist_maker_param_); + out["train_param"] = ToJson(param_); + } + + ~GPUHistMaker() { // NOLINT dh::GlobalMemoryLogger().Log(); } @@ -719,30 +738,24 @@ class GPUHistMakerSpecialised { } void InitDataOnce(DMatrix* dmat) { - device_ = generic_param_->gpu_id; - CHECK_GE(device_, 0) << "Must have at least one device"; + CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device"; info_ = &dmat->Info(); - reducer_.Init({device_}); // NOLINT + 
reducer_.Init({ctx_->gpu_id}); // NOLINT // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); BatchParam batch_param{ - device_, - param_.max_bin, + ctx_->gpu_id, + param_.max_bin, }; auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); - dh::safe_cuda(cudaSetDevice(device_)); - info_->feature_types.SetDevice(device_); - maker.reset(new GPUHistMakerDevice(device_, - page, - info_->feature_types.ConstDeviceSpan(), - info_->num_row_, - param_, - column_sampling_seed, - info_->num_col_, - batch_param)); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); + info_->feature_types.SetDevice(ctx_->gpu_id); + maker.reset(new GPUHistMakerDevice( + ctx_->gpu_id, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_, + column_sampling_seed, info_->num_col_, batch_param)); p_last_fmat_ = dmat; initialised_ = true; @@ -766,7 +779,7 @@ class GPUHistMakerSpecialised { } fs.Seek(0); rabit::Broadcast(&s_model, 0); - RegTree reference_tree {}; // rank 0 tree + RegTree reference_tree{}; // rank 0 tree reference_tree.Load(&fs); CHECK(*local_tree == reference_tree); } @@ -775,13 +788,11 @@ class GPUHistMakerSpecialised { monitor_.Start("InitData"); this->InitData(p_fmat); monitor_.Stop("InitData"); - - gpair->SetDevice(device_); + gpair->SetDevice(ctx_->gpu_id); maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_); } - bool UpdatePredictionCache(const DMatrix *data, - linalg::VectorView p_out_preds) { + bool UpdatePredictionCache(const DMatrix* data, linalg::VectorView p_out_preds) { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { return false; } @@ -791,107 +802,32 @@ class GPUHistMakerSpecialised { return true; } - TrainParam param_; // NOLINT - MetaInfo* info_{}; // NOLINT + TrainParam param_; // NOLINT + MetaInfo* info_{}; // NOLINT std::unique_ptr> maker; // NOLINT + char const* Name() const override { return "grow_gpu_hist"; } + private: - bool initialised_ { false }; + bool initialised_{false}; GPUHistMakerTrainParam hist_maker_param_; - GenericParameter const* generic_param_; dh::AllReducer reducer_; - DMatrix* p_last_fmat_ { nullptr }; - int device_{-1}; + DMatrix* p_last_fmat_{nullptr}; ObjInfo task_; common::Monitor monitor_; }; -class GPUHistMaker : public TreeUpdater { - public: - explicit GPUHistMaker(ObjInfo task) : task_{task} {} - void Configure(const Args& args) override { - // Used in test to count how many configurations are performed - LOG(DEBUG) << "[GPU Hist]: Configure"; - hist_maker_param_.UpdateAllowUnknown(args); - // The passed in args can be empty, if we simply purge the old maker without - // preserving parameters then we can't do Update on it. 
- TrainParam param; - if (float_maker_) { - param = float_maker_->param_; - } else if (double_maker_) { - param = double_maker_->param_; - } - if (hist_maker_param_.single_precision_histogram) { - float_maker_.reset(new GPUHistMakerSpecialised(task_)); - float_maker_->param_ = param; - float_maker_->Configure(args, ctx_); - } else { - double_maker_.reset(new GPUHistMakerSpecialised(task_)); - double_maker_->param_ = param; - double_maker_->Configure(args, ctx_); - } - } - - void LoadConfig(Json const& in) override { - auto const& config = get(in); - FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_); - if (hist_maker_param_.single_precision_histogram) { - float_maker_.reset(new GPUHistMakerSpecialised(task_)); - FromJson(config.at("train_param"), &float_maker_->param_); - } else { - double_maker_.reset(new GPUHistMakerSpecialised(task_)); - FromJson(config.at("train_param"), &double_maker_->param_); - } - } - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["gpu_hist_train_param"] = ToJson(hist_maker_param_); - if (hist_maker_param_.single_precision_histogram) { - out["train_param"] = ToJson(float_maker_->param_); - } else { - out["train_param"] = ToJson(double_maker_->param_); - } - } - - void Update(HostDeviceVector* gpair, DMatrix* dmat, - const std::vector& trees) override { - if (hist_maker_param_.single_precision_histogram) { - float_maker_->Update(gpair, dmat, trees); - } else { - double_maker_->Update(gpair, dmat, trees); - } - } - - bool - UpdatePredictionCache(const DMatrix *data, - linalg::VectorView p_out_preds) override { - if (hist_maker_param_.single_precision_histogram) { - return float_maker_->UpdatePredictionCache(data, p_out_preds); - } else { - return double_maker_->UpdatePredictionCache(data, p_out_preds); - } - } - - char const* Name() const override { - return "grow_gpu_hist"; - } - - private: - GPUHistMakerTrainParam hist_maker_param_; - ObjInfo task_; - std::unique_ptr> float_maker_; - std::unique_ptr> double_maker_; -}; - #if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") - .set_body([](ObjInfo task) { return new GPUHistMaker(task); }); + .set_body([](GenericParameter const* tparam, ObjInfo task) { + return new GPUHistMaker(tparam, task); + }); #endif // !defined(GTEST_TEST) } // namespace tree diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 0a85d2d73832..9d36e4d16c0a 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -24,9 +24,9 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker); class HistMaker: public BaseMaker { public: - void Update(HostDeviceVector *gpair, - DMatrix *p_fmat, - const std::vector &trees) override { + explicit HistMaker(GenericParameter const *ctx) : BaseMaker(ctx) {} + void Update(HostDeviceVector *gpair, DMatrix *p_fmat, + const std::vector &trees) override { interaction_constraints_.Configure(param_, p_fmat->Info().num_col_); // rescale learning rate according to size of trees float lr = param_.learning_rate; @@ -262,12 +262,10 @@ class HistMaker: public BaseMaker { } }; -class CQHistMaker: public HistMaker { +class CQHistMaker : public HistMaker { public: - CQHistMaker() = default; - char const* Name() const override { - return "grow_local_histmaker"; - } + explicit CQHistMaker(GenericParameter const *ctx) : HistMaker(ctx) {} + char const *Name() const override { return "grow_local_histmaker"; } protected: struct HistEntry { @@ -624,9 +622,7 @@ class CQHistMaker: public 
HistMaker { }; XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") -.describe("Tree constructor that uses approximate histogram construction.") -.set_body([](ObjInfo) { - return new CQHistMaker(); - }); + .describe("Tree constructor that uses approximate histogram construction.") + .set_body([](GenericParameter const *ctx, ObjInfo) { return new CQHistMaker(ctx); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index f71f1c698cb9..9e6fad883040 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -21,9 +21,9 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_prune); /*! \brief pruner that prunes a tree after growing finishes */ -class TreePruner: public TreeUpdater { +class TreePruner : public TreeUpdater { public: - explicit TreePruner(ObjInfo task) { + explicit TreePruner(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx) { syncher_.reset(TreeUpdater::Create("sync", ctx_, task)); pruner_monitor_.Init("TreePruner"); } @@ -112,9 +112,7 @@ class TreePruner: public TreeUpdater { }; XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") -.describe("Pruner that prune the tree according to statistics.") -.set_body([](ObjInfo task) { - return new TreePruner(task); - }); + .describe("Pruner that prune the tree according to statistics.") + .set_body([](GenericParameter const* ctx, ObjInfo task) { return new TreePruner(ctx, task); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 0e1b6db47691..dcbb3dbfba3e 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -390,6 +390,8 @@ template struct QuantileHistMaker::Builder; XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") - .set_body([](ObjInfo task) { return new QuantileHistMaker(task); }); + .set_body([](GenericParameter const *ctx, ObjInfo task) { + return new QuantileHistMaker(ctx, task); + }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 3c03a371ebfb..463c7a54ab39 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -225,7 +225,8 @@ inline BatchParam HistBatch(TrainParam const& param) { /*! \brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: - explicit QuantileHistMaker(ObjInfo task) : task_{task} {} + explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task) + : task_{task}, TreeUpdater(ctx) {} void Configure(const Args& args) override; void Update(HostDeviceVector* gpair, diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index d17c1e1444f7..6110e964f891 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -22,11 +22,10 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_refresh); /*! 
\brief pruner that prunes a tree after growing finishs */ -class TreeRefresher: public TreeUpdater { +class TreeRefresher : public TreeUpdater { public: - void Configure(const Args& args) override { - param_.UpdateAllowUnknown(args); - } + explicit TreeRefresher(GenericParameter const *ctx) : TreeUpdater(ctx) {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); } void LoadConfig(Json const& in) override { auto const& config = get(in); FromJson(config.at("train_param"), &this->param_); @@ -160,9 +159,7 @@ class TreeRefresher: public TreeUpdater { }; XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") -.describe("Refresher that refreshes the weight and statistics according to data.") -.set_body([](ObjInfo) { - return new TreeRefresher(); - }); + .describe("Refresher that refreshes the weight and statistics according to data.") + .set_body([](GenericParameter const *ctx, ObjInfo) { return new TreeRefresher(ctx); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 4f7c7a1a85a6..5a22675965dc 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -20,8 +20,9 @@ DMLC_REGISTRY_FILE_TAG(updater_sync); * \brief syncher that synchronize the tree in all distributed nodes * can implement various strategies, so far it is always set to node 0's tree */ -class TreeSyncher: public TreeUpdater { +class TreeSyncher : public TreeUpdater { public: + explicit TreeSyncher(GenericParameter const* tparam) : TreeUpdater(tparam) {} void Configure(const Args&) override {} void LoadConfig(Json const&) override {} @@ -52,9 +53,7 @@ class TreeSyncher: public TreeUpdater { }; XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") -.describe("Syncher that synchronize the tree in all distributed nodes.") -.set_body([](ObjInfo) { - return new TreeSyncher(); - }); + .describe("Syncher that synchronize the tree in all distributed nodes.") + .set_body([](GenericParameter const* tparam, ObjInfo) { return new TreeSyncher(tparam); }); } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 82f40465deb2..883537863307 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -275,8 +275,10 @@ void TestHistogramIndexImpl() { int constexpr kNRows = 1000, kNCols = 10; // Build 2 matrices and build a histogram maker with that - tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}}, - hist_maker_ext{ObjInfo{ObjInfo::kRegression}}; + + GenericParameter generic_param(CreateEmptyGenericParam(0)); + tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}}, + hist_maker_ext{&generic_param,ObjInfo{ObjInfo::kRegression}}; std::unique_ptr hist_maker_dmat( CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true)); @@ -289,10 +291,9 @@ void TestHistogramIndexImpl() { {"max_leaves", "0"} }; - GenericParameter generic_param(CreateEmptyGenericParam(0)); - hist_maker.Configure(training_params, &generic_param); + hist_maker.Configure(training_params); hist_maker.InitDataOnce(hist_maker_dmat.get()); - hist_maker_ext.Configure(training_params, &generic_param); + hist_maker_ext.Configure(training_params); hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get()); // Extract the device maker from the histogram makers and from that its compressed @@ -344,10 +345,9 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, {"sampling_method", sampling_method}, }; - tree::GPUHistMakerSpecialised 
hist_maker{ObjInfo{ObjInfo::kRegression}}; GenericParameter generic_param(CreateEmptyGenericParam(0)); - hist_maker.Configure(args, &generic_param); - + tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}}; + hist_maker.Configure(args); hist_maker.Update(gpair, dmat, {tree}); auto cache = linalg::VectorView{preds->DeviceSpan(), {preds->Size()}, 0}; hist_maker.UpdatePredictionCache(dmat, cache); diff --git a/tests/python-gpu/test_gpu_basic_models.py b/tests/python-gpu/test_gpu_basic_models.py index 06e63bdd56d9..9e955eac2931 100644 --- a/tests/python-gpu/test_gpu_basic_models.py +++ b/tests/python-gpu/test_gpu_basic_models.py @@ -16,11 +16,11 @@ class TestGPUBasicModels: cpu_test_bm = test_bm.TestModels() def run_cls(self, X, y): - cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True) + cls = xgb.XGBClassifier(tree_method='gpu_hist') cls.fit(X, y) cls.get_booster().save_model('test_deterministic_gpu_hist-0.json') - cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True) + cls = xgb.XGBClassifier(tree_method='gpu_hist') cls.fit(X, y) cls.get_booster().save_model('test_deterministic_gpu_hist-1.json') diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index a3427b566360..8f3cbcaac61f 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -3,7 +3,7 @@ import gc import pytest import xgboost as xgb -from hypothesis import given, strategies, assume, settings, note +from hypothesis import given, strategies, assume, settings, note, reproduce_failure sys.path.append("tests/python") import testing as tm @@ -15,7 +15,6 @@ 'max_leaves': strategies.integers(0, 256), 'max_bin': strategies.integers(2, 1024), 'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']), - 'single_precision_histogram': strategies.booleans(), 'min_child_weight': strategies.floats(0.5, 2.0), 'seed': strategies.integers(0, 10), # We cannot enable subsampling as the training loss can increase From f140ebcb2f0219486ad8702eaa322e0f9da624ea Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 25 Apr 2022 04:27:15 -0700 Subject: [PATCH 02/14] Batch nodes from driver --- src/tree/driver.h | 33 +++++++++--- src/tree/updater_approx.cc | 2 +- src/tree/updater_gpu_hist.cu | 69 +++++++++++++------------- src/tree/updater_quantile_hist.cc | 2 +- tests/cpp/tree/gpu_hist/test_driver.cu | 33 +++++++----- 5 files changed, 84 insertions(+), 55 deletions(-) diff --git a/src/tree/driver.h b/src/tree/driver.h index abb8afadcb8a..1e40cc32622f 100644 --- a/src/tree/driver.h +++ b/src/tree/driver.h @@ -33,9 +33,9 @@ class Driver { std::function>; public: - explicit Driver(TrainParam::TreeGrowPolicy policy) - : policy_(policy), - queue_(policy == TrainParam::kDepthWise ? DepthWise : + explicit Driver(TrainParam param) + : param_(param), + queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise : LossGuide) {} template void Push(EntryIterT begin, EntryIterT end) { @@ -55,16 +55,30 @@ class Driver { return queue_.empty(); } + // Can a child of this entry still be expanded? 
+ // can be used to avoid extra work + bool IsChildValid(ExpandEntryT const& parent_entry){ + if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false; + if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false; + return true; + } + // Return the set of nodes to be expanded // This set has no dependencies between entries so they may be expanded in // parallel or asynchronously std::vector Pop() { if (queue_.empty()) return {}; // Return a single entry for loss guided mode - if (policy_ == TrainParam::kLossGuide) { + if (param_.grow_policy == TrainParam::kLossGuide) { ExpandEntryT e = queue_.top(); queue_.pop(); - return {e}; + + if (e.IsValid(param_, num_leaves_)) { + num_leaves_++; + return {e}; + } else { + return {}; + } } // Return nodes on same level for depth wise std::vector result; @@ -72,7 +86,11 @@ class Driver { int level = e.depth; while (e.depth == level && !queue_.empty()) { queue_.pop(); - result.emplace_back(e); + if (e.IsValid(param_, num_leaves_)) { + num_leaves_++; + result.emplace_back(e); + } + if (!queue_.empty()) { e = queue_.top(); } @@ -81,7 +99,8 @@ class Driver { } private: - TrainParam::TreeGrowPolicy policy_; + TrainParam param_; + std::size_t num_leaves_=1; ExpandQueue queue_; }; } // namespace tree diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index a06f195374b6..1c6b195ab34b 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -169,7 +169,7 @@ class GloablApproxBuilder { p_last_tree_ = p_tree; this->InitData(p_fmat, hess); - Driver driver(static_cast(param_.grow_policy)); + Driver driver(param_); auto &tree = *p_tree; driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)}); bst_node_t num_leaves{1}; diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 2cac6b6c4f4a..2340687983a8 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -531,7 +531,7 @@ struct GPUHistMakerDevice { RegTree& tree = *p_tree; // Sanity check - have we created a leaf with no training instances? - if (!rabit::IsDistributed()) { + if (!rabit::IsDistributed() && row_partitioner) { CHECK(row_partitioner->GetRows(candidate.nid).size() > 0) << "No training instances in this leaf!"; } @@ -616,7 +616,7 @@ struct GPUHistMakerDevice { void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo task, RegTree* p_tree, dh::AllReducer* reducer) { auto& tree = *p_tree; - Driver driver(static_cast(param.grow_policy)); + Driver driver(param); monitor.Start("Reset"); this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task); @@ -626,48 +626,49 @@ struct GPUHistMakerDevice { driver.Push({ this->InitRoot(p_tree, task, reducer) }); monitor.Stop("InitRoot"); - auto num_leaves = 1; - // The set of leaves that can be expanded asynchronously auto expand_set = driver.Pop(); while (!expand_set.empty()) { + for(auto & candidate: expand_set){ + this->ApplySplit(candidate, p_tree); + } + // Get the candidates we are allowed to expand further + // e.g. 
We do not bother further processing nodes whose children are beyond max depth + std::vector filtered_expand_set; + std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set), + [&](const auto& e) { return driver.IsChildValid(e); }); + auto new_candidates = - pinned.GetSpan(expand_set.size() * 2, GPUExpandEntry()); + pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); + + for(const auto &e:filtered_expand_set){ + monitor.Start("UpdatePosition"); + // Update position is only run when child is valid, instead of right after apply + // split (as in approx tree method). Hense we have the finalise position call + // in GPU Hist. + this->UpdatePosition(e.nid, p_tree); + monitor.Stop("UpdatePosition"); + } - for (auto i = 0ull; i < expand_set.size(); i++) { + for (auto i = 0ull; i < filtered_expand_set.size(); i++) { auto candidate = expand_set.at(i); - if (!candidate.IsValid(param, num_leaves)) { - continue; - } - this->ApplySplit(candidate, p_tree); + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); - num_leaves++; + monitor.Start("BuildHist"); + this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); + monitor.Stop("BuildHist"); + } + for (auto i = 0ull; i < filtered_expand_set.size(); i++) { + auto candidate = expand_set.at(i); int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); - // Only create child entries if needed - if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx), - num_leaves)) { - monitor.Start("UpdatePosition"); - // Update position is only run when child is valid, instead of right after apply - // split (as in approx tree method). Hense we have the finalise position call - // in GPU Hist. 
- this->UpdatePosition(candidate.nid, p_tree); - monitor.Stop("UpdatePosition"); - - monitor.Start("BuildHist"); - this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); - monitor.Stop("BuildHist"); - - monitor.Start("EvaluateSplits"); - this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree, - new_candidates.subspan(i * 2, 2)); - monitor.Stop("EvaluateSplits"); - } else { - // Set default - new_candidates[i * 2] = GPUExpandEntry(); - new_candidates[i * 2 + 1] = GPUExpandEntry(); - } + + monitor.Start("EvaluateSplits"); + this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree, + new_candidates.subspan(i * 2, 2)); + monitor.Stop("EvaluateSplits"); } dh::DefaultStream().Sync(); driver.Push(new_candidates.begin(), new_candidates.end()); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index dcbb3dbfba3e..bdda543d75a7 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -174,7 +174,7 @@ void QuantileHistMaker::Builder::ExpandTree( DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { monitor_->Start(__func__); - Driver driver(static_cast(param_.grow_policy)); + Driver driver(param_); driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h)); bst_node_t num_leaves{1}; auto expand_set = driver.Pop(); diff --git a/tests/cpp/tree/gpu_hist/test_driver.cu b/tests/cpp/tree/gpu_hist/test_driver.cu index d35f3510f628..d7f8cc63869e 100644 --- a/tests/cpp/tree/gpu_hist/test_driver.cu +++ b/tests/cpp/tree/gpu_hist/test_driver.cu @@ -6,16 +6,21 @@ namespace xgboost { namespace tree { TEST(GpuHist, DriverDepthWise) { - Driver driver(TrainParam::kDepthWise); + TrainParam p; + p.InitAllowUnknown(Args{}); + p.grow_policy=TrainParam::kDepthWise; + Driver driver(p); EXPECT_TRUE(driver.Pop().empty()); DeviceSplitCandidate split; split.loss_chg = 1.0f; - GPUExpandEntry root(0, 0, split, .0f, .0f, .0f); + split.left_sum = {0.0f, 1.0f}; + split.right_sum = {0.0f, 1.0f}; + GPUExpandEntry root(0, 0, split, 2.0f, 1.0f, 1.0f); driver.Push({root}); EXPECT_EQ(driver.Pop().front().nid, 0); - driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}}); - driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}}); - driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}}); + driver.Push({GPUExpandEntry{1, 1, split, 2.0f, 1.0f, 1.0f}}); + driver.Push({GPUExpandEntry{2, 1, split, 2.0f, 1.0f, 1.0f}}); + driver.Push({GPUExpandEntry{3, 2, split, 2.0f, 1.0f, 1.0f}}); // Should return entries from level 1 auto res = driver.Pop(); EXPECT_EQ(res.size(), 2); @@ -29,18 +34,22 @@ TEST(GpuHist, DriverDepthWise) { TEST(GpuHist, DriverLossGuided) { DeviceSplitCandidate high_gain; + high_gain.left_sum = {0.0f, 1.0f}; + high_gain.right_sum = {0.0f, 1.0f}; high_gain.loss_chg = 5.0f; - DeviceSplitCandidate low_gain; + DeviceSplitCandidate low_gain = high_gain; low_gain.loss_chg = 1.0f; - Driver driver(TrainParam::kLossGuide); + TrainParam p; + p.grow_policy=TrainParam::kLossGuide; + Driver driver(p); EXPECT_TRUE(driver.Pop().empty()); - GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f); + GPUExpandEntry root(0, 0, high_gain, 2.0f, 1.0f, 1.0f ); driver.Push({root}); EXPECT_EQ(driver.Pop().front().nid, 0); // Select high gain first - driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}}); - driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}}); + driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}}); + driver.Push({GPUExpandEntry{2, 2, high_gain, 2.0f, 
1.0f, 1.0f}}); auto res = driver.Pop(); EXPECT_EQ(res.size(), 1); EXPECT_EQ(res[0].nid, 2); @@ -49,8 +58,8 @@ TEST(GpuHist, DriverLossGuided) { EXPECT_EQ(res[0].nid, 1); // If equal gain, use nid - driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}}); - driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}}); + driver.Push({GPUExpandEntry{2, 1, low_gain, 2.0f, 1.0f, 1.0f}}); + driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}}); res = driver.Pop(); EXPECT_EQ(res[0].nid, 1); res = driver.Pop(); From 80a3e78f9e1dcbf2a78f6572897453c61afd60b0 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Fri, 29 Apr 2022 02:22:26 -0700 Subject: [PATCH 03/14] Categoricals broken --- src/tree/gpu_hist/evaluate_splits.cuh | 2 +- src/tree/gpu_hist/evaluator.cu | 73 +++---- src/tree/gpu_hist/histogram.cu | 9 - src/tree/updater_gpu_hist.cu | 223 +++++++++++++--------- tests/cpp/tree/gpu_hist/test_histogram.cu | 1 - tests/cpp/tree/test_gpu_hist.cu | 45 +++-- 6 files changed, 186 insertions(+), 167 deletions(-) diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index b03fd7b41b51..7d792051e5be 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -92,7 +92,7 @@ class GPUHistEvaluator { } /** - * \brief Get sorted index storage based on the left node of inputs . + * \brief Get sorted index storage based on the left node of inputs. */ auto SortedIdx(EvaluateSplitInputs left) { if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) { diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index bc2027489131..381ef8fbb349 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -21,55 +21,36 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, int32_t device) { param_ = param; tree_evaluator_ = TreeEvaluator{param, n_features, device}; - if (cuts.HasCategorical() && !task.UseOneHot()) { + if (cuts.HasCategorical()) { dh::XGBCachingDeviceAllocator alloc; - auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); - auto beg = thrust::make_counting_iterator(1ul); - auto end = thrust::make_counting_iterator(ptrs.size()); - auto to_onehot = param.max_cat_to_onehot; - // This condition avoids sort-based split function calls if the users want - // onehot-encoding-based splits. - // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x. - has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) { - auto idx = i - 1; - if (common::IsCat(ft, idx)) { - auto n_bins = ptrs[i] - ptrs[idx]; - bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); - return use_sort; - } - return false; - }); - - if (has_sort_) { - auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); - CHECK_NE(bit_storage_size, 0); - // We need to allocate for all nodes since the updater can grow the tree layer by - // layer, all nodes in the same layer must be preserved until that layer is - // finished. We can allocate one layer at a time, but the best case is reducing the - // size of the bitset by about a half, at the cost of invoking CUDA malloc many more - // times than necessary. 
- split_cats_.resize(param.MaxNodes() * bit_storage_size); - h_split_cats_.resize(split_cats_.size()); - dh::safe_cuda( - cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); + auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); + CHECK_NE(bit_storage_size, 0); + // We need to allocate for all nodes since the updater can grow the tree layer by + // layer, all nodes in the same layer must be preserved until that layer is + // finished. We can allocate one layer at a time, but the best case is reducing the + // size of the bitset by about a half, at the cost of invoking CUDA malloc many more + // times than necessary. + split_cats_.resize(param.MaxNodes() * bit_storage_size); + h_split_cats_.resize(split_cats_.size()); + dh::safe_cuda( + cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); - cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. - sort_input_.resize(cat_sorted_idx_.size()); + cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. + sort_input_.resize(cat_sorted_idx_.size()); - /** - * cache feature index binary search result - */ - feature_idx_.resize(cat_sorted_idx_.size()); - auto d_fidxes = dh::ToSpan(feature_idx_); - auto it = thrust::make_counting_iterator(0ul); - auto values = cuts.cut_values_.ConstDeviceSpan(); - auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); - thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), - feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) { - auto fidx = dh::SegmentId(ptrs, i); - return fidx; - }); - } + /** + * cache feature index binary search result + */ + feature_idx_.resize(cat_sorted_idx_.size()); + auto d_fidxes = dh::ToSpan(feature_idx_); + auto it = thrust::make_counting_iterator(0ul); + auto values = cuts.cut_values_.ConstDeviceSpan(); + auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); + thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(), + [=] XGBOOST_DEVICE(size_t i) { + auto fidx = dh::SegmentId(ptrs, i); + return fidx; + }); } } diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 791363a05cdd..efb08d5e44e2 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -247,15 +247,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, dh::safe_cuda(cudaGetLastError()); } -template void BuildGradientHistogram( - EllpackDeviceAccessor const& matrix, - FeatureGroupsAccessor const& feature_groups, - common::Span gpair, - common::Span ridx, - common::Span histogram, - HistRounding rounding, - bool force_global_memory); - template void BuildGradientHistogram( EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 2340687983a8..2cd9d4babeb1 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -57,7 +57,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); #endif // !defined(GTEST_TEST) /** - * \struct DeviceHistogram + * \struct DeviceHistogramStorage * * \summary Data storage for node histograms on device. Automatically expands. * @@ -67,12 +67,18 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); * \author Rory * \date 28/07/2018 */ -template -class DeviceHistogram { +template +class DeviceHistogramStorage { private: /*! \brief Map nidx to starting index of its histogram. 
*/ std::map nidx_map_; + // Large buffer of zeroed memory, caches histograms dh::device_vector data_; + // If we run out of storage allocate one histogram at a time + // in overflow. Not cached, overwritten when a new histogram + // is requested + dh::device_vector overflow_; + std::map overflow_nidx_map_; int n_bins_; int device_id_; static constexpr size_t kNumItemsInGradientSum = @@ -81,6 +87,8 @@ class DeviceHistogram { "Number of items in gradient type should be 2."); public: + // Start with about 16mb + DeviceHistogramStorage() { data_.reserve(1 << 22); } void Init(int device_id, int n_bins) { this->n_bins_ = n_bins; this->device_id_ = device_id; @@ -91,21 +99,53 @@ class DeviceHistogram { dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); nidx_map_.clear(); + overflow_nidx_map_.clear(); } bool HistogramExists(int nidx) const { - return nidx_map_.find(nidx) != nidx_map_.cend(); + return nidx_map_.find(nidx) != nidx_map_.cend() || overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); } int Bins() const { return n_bins_; } - size_t HistogramSize() const { - return n_bins_ * kNumItemsInGradientSum; - } + size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } + dh::device_vector& Data() { return data_; } - dh::device_vector& Data() { - return data_; + void AllocateHistograms(const std::vector& new_nidxs) { + for (int nidx : new_nidxs) { + CHECK(!HistogramExists(nidx)); + } + // Number of items currently used in data + const size_t used_size = nidx_map_.size() * HistogramSize(); + const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size(); + if (used_size >= kStopGrowingSize) { + // Use overflow + // Delete previous entries + overflow_nidx_map_.clear(); + overflow_.resize(HistogramSize() * new_nidxs.size()); + // Zero memory + auto d_data = overflow_.data().get(); + dh::LaunchN(overflow_.size(), + [=] __device__(size_t idx) { d_data[idx] = 0.0; }); + // Append new histograms + for (int nidx : new_nidxs) { + overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize(); + } + } else { + CHECK_GE(data_.size(), used_size); + // Expand if necessary + if (data_.size() < new_used_size) { + data_.resize(std::max(data_.size() * 2, new_used_size)); + } + // Append new histograms + for (int nidx : new_nidxs) { + nidx_map_[nidx] = nidx_map_.size() * HistogramSize(); + } + } + + CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize()); } + /* void AllocateHistogram(int nidx) { if (HistogramExists(nidx)) return; // Number of items currently used in data @@ -139,6 +179,7 @@ class DeviceHistogram { CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize()); } + */ /** * \summary Return pointer to histogram memory for a given node. 
@@ -147,9 +188,16 @@ class DeviceHistogram { */ common::Span GetNodeHistogram(int nidx) { CHECK(this->HistogramExists(nidx)); - auto ptr = data_.data().get() + nidx_map_.at(nidx); - return common::Span( - reinterpret_cast(ptr), n_bins_); + + if (nidx_map_.find(nidx) != nidx_map_.cend()) { + // Fetch from normal cache + auto ptr = data_.data().get() + nidx_map_.at(nidx); + return common::Span(reinterpret_cast(ptr), n_bins_); + } else { + // Fetch from overflow + auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx); + return common::Span(reinterpret_cast(ptr), n_bins_); + } } }; @@ -166,7 +214,7 @@ struct GPUHistMakerDevice { BatchParam batch_param; std::unique_ptr row_partitioner; - DeviceHistogram hist{}; + DeviceHistogramStorage hist{}; dh::caching_device_vector d_gpair; // storage for gpair; common::Span gpair; @@ -189,8 +237,6 @@ struct GPUHistMakerDevice { std::unique_ptr sampler; std::unique_ptr feature_groups; - // Storing split categories for last node. - dh::caching_device_vector node_categories; GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page, common::Span _feature_types, bst_uint _n_rows, @@ -319,7 +365,6 @@ struct GPUHistMakerDevice { } void BuildHist(int nidx) { - hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); BuildGradientHistogram(page->GetDeviceAccessor(device_id), @@ -327,8 +372,12 @@ struct GPUHistMakerDevice { d_ridx, d_node_hist, histogram_rounding); } - void SubtractionTrick(int nidx_parent, int nidx_histogram, - int nidx_subtraction) { + // Attempt to the subtraction trick + // return true if succeeded + bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { + if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) { + return false; + } auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent); auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram); auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction); @@ -337,22 +386,18 @@ struct GPUHistMakerDevice { d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); + return true; } - bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { - // Make sure histograms are already allocated - hist.AllocateHistogram(nidx_subtraction); - return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent); - } - - void UpdatePosition(int nidx, RegTree* p_tree) { - RegTree::Node split_node = (*p_tree)[nidx]; - auto split_type = p_tree->NodeSplitType(nidx); + void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) { + RegTree::Node split_node = (*p_tree)[e.nid]; + auto split_type = p_tree->NodeSplitType(e.nid); auto d_matrix = page->GetDeviceAccessor(device_id); - auto node_cats = dh::ToSpan(node_categories); + auto node_cats = e.split.split_cats.Bits(); + row_partitioner->UpdatePosition( - nidx, split_node.LeftChild(), split_node.RightChild(), + e.nid, split_node.LeftChild(), split_node.RightChild(), [=] __device__(bst_uint ridx) { // given a row index, returns the node id it belongs to bst_float cut_value = @@ -483,13 +528,15 @@ struct GPUHistMakerDevice { row_partitioner.reset(); } - void AllReduceHist(int nidx, dh::AllReducer* reducer) { + // num histograms is the number of contiguous histograms in memory to reduce over + void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) { monitor.Start("AllReduce"); auto d_node_hist = 
hist.GetNodeHistogram(nidx).data(); - reducer->AllReduceSum( - reinterpret_cast(d_node_hist), - reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT))); + reducer->AllReduceSum(reinterpret_cast(d_node_hist), + reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * + (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) * + num_histograms); monitor.Stop("AllReduce"); } @@ -497,33 +544,49 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left, - int nidx_right, dh::AllReducer* reducer) { - auto build_hist_nidx = nidx_left; - auto subtraction_trick_nidx = nidx_right; - - // Decide whether to build the left histogram or right histogram - // Use sum of Hessian as a heuristic to select node with fewest training instances - bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess(); - if (fewer_right) { - std::swap(build_hist_nidx, subtraction_trick_nidx); + void BuildHistLeftRight(std::vectorconst &candidates, dh::AllReducer* reducer, const RegTree& tree) { + if(candidates.empty()) return; + // Some nodes we will manually compute histograms + // others we will do by subtraction + std::vector hist_nidx; + std::vector subtraction_nidx; + for (auto& e : candidates) { + // Decide whether to build the left histogram or right histogram + // Use sum of Hessian as a heuristic to select node with fewest training instances + bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess(); + if (fewer_right) { + hist_nidx.emplace_back(tree[e.nid].RightChild()); + subtraction_nidx.emplace_back(tree[e.nid].LeftChild()); + } else { + hist_nidx.emplace_back(tree[e.nid].LeftChild()); + subtraction_nidx.emplace_back(tree[e.nid].RightChild()); + } + } + std::vector all_new = hist_nidx; + all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end()); + // Allocate the histograms + // Guaranteed contiguous memory + hist.AllocateHistograms(all_new); + + for(auto nidx:hist_nidx){ + this->BuildHist(nidx); } - this->BuildHist(build_hist_nidx); - this->AllReduceHist(build_hist_nidx, reducer); + // Reduce all in one go + // This gives much better latency in a distributed setting + // when processing a large batch + this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size()); - // Check whether we can use the subtraction trick to calculate the other - bool do_subtraction_trick = this->CanDoSubtractionTrick( - candidate.nid, build_hist_nidx, subtraction_trick_nidx); + for (int i = 0; i < subtraction_nidx.size(); i++) { + auto build_hist_nidx = hist_nidx.at(i); + auto subtraction_trick_nidx = subtraction_nidx.at(i); + auto parent_nidx = candidates.at(i).nid; - if (do_subtraction_trick) { - // Calculate other histogram using subtraction trick - this->SubtractionTrick(candidate.nid, build_hist_nidx, - subtraction_trick_nidx); - } else { - // Calculate other histogram manually - this->BuildHist(subtraction_trick_nidx); - this->AllReduceHist(subtraction_trick_nidx, reducer); + if(!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)){ + // Calculate other histogram manually + this->BuildHist(subtraction_trick_nidx); + this->AllReduceHist(subtraction_trick_nidx, reducer, 1); + } } } @@ -546,27 +609,11 @@ struct GPUHistMakerDevice { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value 
too large."; std::vector split_cats; - if (candidate.split.split_cats.Bits().empty()) { - if (common::InvalidCat(candidate.split.fvalue)) { - common::InvalidCategory(); - } - auto cat = common::AsCat(candidate.split.fvalue); - split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0); - common::CatBitField cats_bits(split_cats); - cats_bits.Set(cat); - dh::CopyToD(split_cats, &node_categories); - } else { - auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); - auto max_cat = candidate.split.MaxCat(); - split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); - CHECK_LE(split_cats.size(), h_cats.size()); - std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); - - node_categories.resize(candidate.split.split_cats.Bits().size()); - dh::safe_cuda(cudaMemcpyAsync( - node_categories.data().get(), candidate.split.split_cats.Data(), - candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice)); - } + auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); + auto max_cat = candidate.split.MaxCat(); + split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); + CHECK_LE(split_cats.size(), h_cats.size()); + std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); tree.ExpandCategorical( candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir, @@ -598,8 +645,9 @@ struct GPUHistMakerDevice { GradientPairPrecise{}, thrust::plus{}); rabit::Allreduce(reinterpret_cast(&root_sum), 2); + hist.AllocateHistograms({kRootNIdx}); this->BuildHist(kRootNIdx); - this->AllReduceHist(kRootNIdx, reducer); + this->AllReduceHist(kRootNIdx, reducer, 1); // Remember root stats node_sum_gradients[kRootNIdx] = root_sum; @@ -638,6 +686,7 @@ struct GPUHistMakerDevice { std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set), [&](const auto& e) { return driver.IsChildValid(e); }); + auto new_candidates = pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); @@ -646,22 +695,16 @@ struct GPUHistMakerDevice { // Update position is only run when child is valid, instead of right after apply // split (as in approx tree method). Hense we have the finalise position call // in GPU Hist. 
- this->UpdatePosition(e.nid, p_tree); + this->UpdatePosition(e, p_tree); monitor.Stop("UpdatePosition"); } - for (auto i = 0ull; i < filtered_expand_set.size(); i++) { - auto candidate = expand_set.at(i); - int left_child_nidx = tree[candidate.nid].LeftChild(); - int right_child_nidx = tree[candidate.nid].RightChild(); - - monitor.Start("BuildHist"); - this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); - monitor.Stop("BuildHist"); - } + monitor.Start("BuildHist"); + this->BuildHistLeftRight(filtered_expand_set, reducer, tree); + monitor.Stop("BuildHist"); for (auto i = 0ull; i < filtered_expand_set.size(); i++) { - auto candidate = expand_set.at(i); + auto candidate = filtered_expand_set.at(i); int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 3b543a48d7cc..75d97b681a61 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -95,7 +95,6 @@ TEST(Histogram, GPUDeterministic) { std::vector shm_sizes{48 * 1024, 64 * 1024, 160 * 1024}; for (bool is_dense : is_dense_array) { for (int shm_size : shm_sizes) { - TestDeterministicHistogram(is_dense, shm_size); TestDeterministicHistogram(is_dense, shm_size); } } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 883537863307..bdabbbcb38c2 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -29,29 +29,38 @@ TEST(GpuHist, DeviceHistogram) { constexpr size_t kNBins = 128; constexpr size_t kNNodes = 4; constexpr size_t kStopGrowing = kNNodes * kNBins * 2u; - DeviceHistogram histogram; + DeviceHistogramStorage histogram; histogram.Init(0, kNBins); - for (size_t i = 0; i < kNNodes; ++i) { - histogram.AllocateHistogram(i); + for (int i = 0; i < kNNodes; ++i) { + histogram.AllocateHistograms({i}); } histogram.Reset(); ASSERT_EQ(histogram.Data().size(), kStopGrowing); // Use allocated memory but do not erase nidx_map. - for (size_t i = 0; i < kNNodes; ++i) { - histogram.AllocateHistogram(i); + for (int i = 0; i < kNNodes; ++i) { + histogram.AllocateHistograms({i}); } - for (size_t i = 0; i < kNNodes; ++i) { + for (int i = 0; i < kNNodes; ++i) { ASSERT_TRUE(histogram.HistogramExists(i)); } - // Erase existing nidx_map. 
- for (size_t i = kNNodes; i < kNNodes * 2; ++i) { - histogram.AllocateHistogram(i); - } - for (size_t i = 0; i < kNNodes; ++i) { - ASSERT_FALSE(histogram.HistogramExists(i)); + // Add two new nodes + histogram.AllocateHistograms({kNNodes}); + histogram.AllocateHistograms({kNNodes+1}); + + // Old cached nodes should still exist + for (int i = 0; i < kNNodes; ++i) { + ASSERT_TRUE(histogram.HistogramExists(i)); } + + // Should be deleted + ASSERT_FALSE(histogram.HistogramExists({kNNodes})); + // Most recent node should exist + ASSERT_TRUE(histogram.HistogramExists({kNNodes + 1})); + + // Add same node again - should fail + EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes+1});); } std::vector GetHostHistGpair() { @@ -95,9 +104,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector()); maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); - maker.hist.AllocateHistogram(0); + maker.hist.AllocateHistograms({0}); maker.gpair = gpair.DeviceSpan(); - maker.histogram_rounding = CreateRoundingFactor(maker.gpair);; + maker.histogram_rounding = CreateRoundingFactor(maker.gpair); BuildGradientHistogram( page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), @@ -105,7 +114,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { maker.hist.GetNodeHistogram(0), maker.histogram_rounding, !use_shared_memory_histograms); - DeviceHistogram& d_hist = maker.hist; + DeviceHistogramStorage& d_hist = maker.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair @@ -128,12 +137,10 @@ void TestBuildHist(bool use_shared_memory_histograms) { TEST(GpuHist, BuildHistGlobalMem) { TestBuildHist(false); - TestBuildHist(false); } TEST(GpuHist, BuildHistSharedMem) { TestBuildHist(true); - TestBuildHist(true); } TEST(GpuHist, ApplySplit) { @@ -173,8 +180,6 @@ TEST(GpuHist, ApplySplit) { ASSERT_EQ(tree.GetSplitCategories().size(), 1); uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0 ASSERT_EQ(tree.GetSplitCategories().back(), bits); - - ASSERT_EQ(updater.node_categories.size(), 1); } } @@ -238,7 +243,7 @@ TEST(GpuHist, EvaluateRootSplit) { // Initialize GPUHistMakerDevice::hist maker.hist.Init(0, (max_bins - 1) * kNCols); - maker.hist.AllocateHistogram(0); + maker.hist.AllocateHistograms({0}); // Each row of hist_gpair represents gpairs for one feature. // Each entry represents a bin. std::vector hist_gpair = GetHostHistGpair(); From e1fb7024fdea6224349a3fbd863ba839b3a78748 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sun, 1 May 2022 09:50:09 -0700 Subject: [PATCH 04/14] Refactor categoricals --- src/tree/gpu_hist/evaluate_splits.cu | 39 ++++++++------- src/tree/gpu_hist/evaluate_splits.cuh | 6 +-- src/tree/gpu_hist/evaluator.cu | 72 +++++++++++++-------------- src/tree/updater_gpu_hist.cu | 43 +++++----------- tests/cpp/tree/test_gpu_hist.cu | 2 - 5 files changed, 72 insertions(+), 90 deletions(-) diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index ce8b13d0def2..7fba1902b881 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -273,12 +273,19 @@ __device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a, * \brief Set the bits for categorical splits based on the split threshold. 
*/ template -__device__ void SortBasedSplit(EvaluateSplitInputs const &input, +__device__ void SetCategoricalSplit(EvaluateSplitInputs const &input, common::Span d_sorted_idx, bst_feature_t fidx, bool is_left, common::Span out, - DeviceSplitCandidate *p_out_split) { + DeviceSplitCandidate *p_out_split, ObjInfo task) { auto &out_split = *p_out_split; out_split.split_cats = common::CatBitField{out}; + + // Simple case for one hot split + if (common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + out_split.split_cats.Set(common::AsCat(out_split.fvalue)); + return; + } + auto node_sorted_idx = is_left ? d_sorted_idx.subspan(0, input.feature_values.size()) : d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size()); @@ -313,7 +320,7 @@ void GPUHistEvaluator::EvaluateSplits( EvaluateSplitInputs left, EvaluateSplitInputs right, ObjInfo task, TreeEvaluator::SplitEvaluator evaluator, common::Span out_splits) { - if (!split_cats_.empty()) { + if (need_sort_histogram_) { this->SortHistogram(left, right, evaluator); } @@ -354,14 +361,12 @@ void GPUHistEvaluator::EvaluateSplits( template void GPUHistEvaluator::CopyToHost(EvaluateSplitInputs const &input, common::Span cats_out) { - if (has_sort_) { - dh::CUDAEvent event; - event.Record(dh::DefaultStream()); - auto h_cats = this->HostCatStorage(input.nidx); - copy_stream_.View().Wait(event); - dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(), - cudaMemcpyDeviceToHost, copy_stream_.View())); - } + dh::CUDAEvent event; + event.Record(dh::DefaultStream()); + auto h_cats = this->HostCatStorage(input.nidx); + copy_stream_.View().Wait(event); + dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(), + cudaMemcpyDeviceToHost, copy_stream_.View())); } template @@ -378,17 +383,16 @@ void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, Ob auto d_sorted_idx = this->SortedIdx(left); auto d_entries = out_entries; auto cats_out = this->DeviceCatStorage(left.nidx); - // turn candidate into entry, along with hanlding sort based split. + // turn candidate into entry, along with handling sort based split. dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) { auto const &input = i == 0 ? left : right; auto &split = out_splits[i]; auto fidx = out_splits[i].findex; - if (split.is_cat && - !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + if (split.is_cat) { bool is_left = i == 0; auto out = is_left ? 
cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2); - SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]); + SetCategoricalSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i], task); } float base_weight = @@ -420,9 +424,8 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit( auto &split = out_split[i]; auto fidx = out_split[i].findex; - if (split.is_cat && - !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { - SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]); + if (split.is_cat) { + SetCategoricalSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i], task); } float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum}); diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index b03fd7b41b51..f28aac97b417 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -58,9 +58,9 @@ class GPUHistEvaluator { dh::device_vector feature_idx_; // Training param used for evaluation TrainParam param_; - // whether the input data requires sort based split, which is more complicated so we try - // to avoid it if possible. - bool has_sort_{false}; + // Do we have any categorical features that require sorting histograms? + // use this to skip the expensive sort step + bool need_sort_histogram_ = false; // Copy the categories from device to host asynchronously. void CopyToHost(EvaluateSplitInputs const &input, common::Span cats_out); diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index bc2027489131..6c081e1ba6df 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -30,46 +30,44 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, // This condition avoids sort-based split function calls if the users want // onehot-encoding-based splits. // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x. - has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) { - auto idx = i - 1; - if (common::IsCat(ft, idx)) { - auto n_bins = ptrs[i] - ptrs[idx]; - bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); - return use_sort; - } - return false; - }); + need_sort_histogram_ = + thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) { + auto idx = i - 1; + if (common::IsCat(ft, idx)) { + auto n_bins = ptrs[i] - ptrs[idx]; + bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); + return use_sort; + } + return false; + }); - if (has_sort_) { - auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); - CHECK_NE(bit_storage_size, 0); - // We need to allocate for all nodes since the updater can grow the tree layer by - // layer, all nodes in the same layer must be preserved until that layer is - // finished. We can allocate one layer at a time, but the best case is reducing the - // size of the bitset by about a half, at the cost of invoking CUDA malloc many more - // times than necessary. 
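Editor's note: a quick sanity check on the storage arithmetic around ComputeStorageSize(cuts.MaxCategory() + 1) and param.MaxNodes() in this hunk, assuming the bit field packs one bit per category into 32-bit words (the word type is the CatST template parameter in the real code):

#include <cstdint>
#include <iostream>

// Words needed to hold one bit per category id in [0, max_cat].
std::uint32_t bits_to_words(std::uint32_t n_bits) {
  return (n_bits + 31) / 32;  // round up to whole 32-bit words
}

int main() {
  std::uint32_t max_cat = 1000;                         // largest category id in the cuts
  std::uint32_t per_node = bits_to_words(max_cat + 1);  // 32 words = 128 bytes per node
  std::uint32_t max_nodes = 2047;                       // e.g. a full binary tree of depth 10
  std::cout << "per node: " << per_node << " words, "
            << "all nodes: " << per_node * max_nodes << " words\n";
}

Pre-allocating for MaxNodes() keeps every layer's bitsets alive at once; the later per-node growth in this series trades that for on-demand resizes.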
- split_cats_.resize(param.MaxNodes() * bit_storage_size); - h_split_cats_.resize(split_cats_.size()); - dh::safe_cuda( - cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); + auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); + CHECK_NE(bit_storage_size, 0); + // We need to allocate for all nodes since the updater can grow the tree layer by + // layer, all nodes in the same layer must be preserved until that layer is + // finished. We can allocate one layer at a time, but the best case is reducing the + // size of the bitset by about a half, at the cost of invoking CUDA malloc many more + // times than necessary. + split_cats_.resize(param.MaxNodes() * bit_storage_size); + h_split_cats_.resize(split_cats_.size()); + dh::safe_cuda( + cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); - cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. - sort_input_.resize(cat_sorted_idx_.size()); + cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. + sort_input_.resize(cat_sorted_idx_.size()); - /** - * cache feature index binary search result - */ - feature_idx_.resize(cat_sorted_idx_.size()); - auto d_fidxes = dh::ToSpan(feature_idx_); - auto it = thrust::make_counting_iterator(0ul); - auto values = cuts.cut_values_.ConstDeviceSpan(); - auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); - thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), - feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) { - auto fidx = dh::SegmentId(ptrs, i); - return fidx; - }); - } + /** + * cache feature index binary search result + */ + feature_idx_.resize(cat_sorted_idx_.size()); + auto d_fidxes = dh::ToSpan(feature_idx_); + auto it = thrust::make_counting_iterator(0ul); + auto values = cuts.cut_values_.ConstDeviceSpan(); + thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(), + [=] XGBOOST_DEVICE(size_t i) { + auto fidx = dh::SegmentId(ptrs, i); + return fidx; + }); } } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 569188fd5374..861b6e15b264 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -197,8 +197,6 @@ struct GPUHistMakerDevice { std::unique_ptr sampler; std::unique_ptr feature_groups; - // Storing split categories for last node. 
- dh::caching_device_vector node_categories; GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, common::Span _feature_types, bst_uint _n_rows, @@ -354,14 +352,14 @@ struct GPUHistMakerDevice { return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent); } - void UpdatePosition(int nidx, RegTree* p_tree) { - RegTree::Node split_node = (*p_tree)[nidx]; - auto split_type = p_tree->NodeSplitType(nidx); + void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) { + RegTree::Node split_node = (*p_tree)[e.nid]; + auto split_type = p_tree->NodeSplitType(e.nid); auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); - auto node_cats = dh::ToSpan(node_categories); + auto node_cats = e.split.split_cats.Bits(); row_partitioner->UpdatePosition( - nidx, split_node.LeftChild(), split_node.RightChild(), + e.nid, split_node.LeftChild(), split_node.RightChild(), [=] __device__(bst_uint ridx) { // given a row index, returns the node id it belongs to bst_float cut_value = @@ -567,28 +565,13 @@ struct GPUHistMakerDevice { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value too large."; std::vector split_cats; - if (candidate.split.split_cats.Bits().empty()) { - if (common::InvalidCat(candidate.split.fvalue)) { - common::InvalidCategory(); - } - auto cat = common::AsCat(candidate.split.fvalue); - split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0); - common::CatBitField cats_bits(split_cats); - cats_bits.Set(cat); - dh::CopyToD(split_cats, &node_categories); - } else { - auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); - auto max_cat = candidate.split.MaxCat(); - split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); - CHECK_LE(split_cats.size(), h_cats.size()); - std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); - - node_categories.resize(candidate.split.split_cats.Bits().size()); - dh::safe_cuda(cudaMemcpyAsync( - node_categories.data().get(), candidate.split.split_cats.Data(), - candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice)); - } - + CHECK_GT(candidate.split.split_cats.Bits().size(), 0); + auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); + auto max_cat = candidate.split.MaxCat(); + split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); + CHECK_LE(split_cats.size(), h_cats.size()); + std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); + tree.ExpandCategorical( candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(), @@ -674,7 +657,7 @@ struct GPUHistMakerDevice { // Update position is only run when child is valid, instead of right after apply // split (as in approx tree method). Hense we have the finalise position call // in GPU Hist. 
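Editor's note: the UpdatePosition refactor in this patch reads the split's category bitfield straight off the expand entry (e.split.split_cats.Bits()) instead of a per-device node_categories buffer that only held the last applied node. The decision it feeds the row partitioner is essentially a bit test; a host-side sketch, with the direction convention and missing-value handling simplified:

#include <cstdint>
#include <vector>

// One bit per category id; categories whose bit is set are routed to one child.
bool category_in_set(const std::vector<std::uint32_t>& bits, std::uint32_t cat) {
  std::uint32_t word = cat / 32, offset = cat % 32;
  if (word >= bits.size()) return false;  // unseen category
  return (bits[word] >> offset) & 1u;
}

// Simplified row routing for a categorical split.
bool goes_left(float fvalue, const std::vector<std::uint32_t>& split_cats) {
  auto cat = static_cast<std::uint32_t>(fvalue);  // AsCat() in the real code
  return category_in_set(split_cats, cat);
}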
- this->UpdatePosition(candidate.nid, p_tree); + this->UpdatePosition(candidate, p_tree); monitor.Stop("UpdatePosition"); monitor.Start("BuildHist"); diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 3c93c283917a..ea5556b38fca 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -174,8 +174,6 @@ TEST(GpuHist, ApplySplit) { ASSERT_EQ(tree.GetSplitCategories().size(), 1); uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0 ASSERT_EQ(tree.GetSplitCategories().back(), bits); - - ASSERT_EQ(updater.node_categories.size(), 1); } } From dc100cfbf5bb10230875680424611a8136cb7996 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 2 May 2022 06:00:24 -0700 Subject: [PATCH 05/14] Refactor categoricals 2 --- src/common/categorical.h | 4 +- src/tree/gpu_hist/evaluate_splits.cu | 4 +- src/tree/gpu_hist/evaluate_splits.cuh | 28 ++++++++----- src/tree/gpu_hist/evaluator.cu | 15 +++---- src/tree/hist/evaluate_splits.h | 2 +- src/tree/updater_gpu_hist.cu | 4 +- .../cpp/tree/gpu_hist/test_evaluate_splits.cu | 31 ++++++++------ tests/cpp/tree/test_gpu_hist.cu | 40 ------------------- 8 files changed, 52 insertions(+), 76 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index 5eff62264cf2..341a887f48a9 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -82,8 +82,8 @@ inline void InvalidCategory() { /*! * \brief Whether should we use onehot encoding for categorical data. */ -XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) { - bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot(); +XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot) { + bool use_one_hot = n_cats < max_cat_to_onehot; return use_one_hot; } diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 7fba1902b881..2966b84e75af 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -241,7 +241,7 @@ __global__ void EvaluateSplitsKernel( if (common::IsCat(inputs.feature_types, fidx)) { auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx]; - if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) { + if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot)) { EvaluateFeature(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage); } else { @@ -281,7 +281,7 @@ __device__ void SetCategoricalSplit(EvaluateSplitInputs const &inp out_split.split_cats = common::CatBitField{out}; // Simple case for one hot split - if (common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + if (common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) { out_split.split_cats.Set(common::AsCat(out_split.fvalue)); return; } diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index f28aac97b417..67e56426217a 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -61,6 +61,9 @@ class GPUHistEvaluator { // Do we have any categorical features that require sorting histograms? // use this to skip the expensive sort step bool need_sort_histogram_ = false; + // Number of elements of categorical storage type + // needed to hold categoricals for a single mode + std::size_t node_categorical_storage_size_ = 0; // Copy the categories from device to host asynchronously. 
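Editor's note: the categorical.h change at the top of this patch reduces UseOneHot to a pure cardinality test — features with fewer categories than max_cat_to_onehot take the one-hot style split, everything else goes through the sort-based partition path. An illustrative host-side use of that rule, with made-up bin counts:

#include <cstdint>
#include <iostream>
#include <vector>

bool use_one_hot(std::uint32_t n_cats, std::uint32_t max_cat_to_onehot) {
  return n_cats < max_cat_to_onehot;  // mirrors the simplified check
}

int main() {
  std::uint32_t max_cat_to_onehot = 4;                          // training parameter
  std::vector<std::uint32_t> bins_per_feature{2, 3, 16, 300};   // hypothetical features
  for (auto n : bins_per_feature) {
    std::cout << n << " categories -> "
              << (use_one_hot(n, max_cat_to_onehot) ? "one-hot" : "sorted partition")
              << '\n';
  }
}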
void CopyToHost(EvaluateSplitInputs const &input, common::Span cats_out); @@ -69,12 +72,17 @@ class GPUHistEvaluator { * \brief Get host category storage of nidx for internal calculation. */ auto HostCatStorage(bst_node_t nidx) { - auto cat_bits = h_split_cats_.size() / param_.MaxNodes(); + + std::size_t min_size=(nidx+2)*node_categorical_storage_size_; + if(h_split_cats_.size(){h_split_cats_}.subspan(nidx * cat_bits, cat_bits); + auto cats_out = common::Span{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_); return cats_out; } - auto cats_out = common::Span{h_split_cats_}.subspan(nidx * cat_bits, cat_bits * 2); + auto cats_out = common::Span{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_ * 2); return cats_out; } @@ -82,12 +90,15 @@ class GPUHistEvaluator { * \brief Get device category storage of nidx for internal calculation. */ auto DeviceCatStorage(bst_node_t nidx) { - auto cat_bits = split_cats_.size() / param_.MaxNodes(); + std::size_t min_size=(nidx+2)*node_categorical_storage_size_; + if(split_cats_.size() ft, ObjInfo task, + void Reset(common::HistogramCuts const &cuts, common::Span ft, bst_feature_t n_features, TrainParam const ¶m, int32_t device); /** @@ -123,8 +134,7 @@ class GPUHistEvaluator { */ common::Span GetHostNodeCats(bst_node_t nidx) const { copy_stream_.View().Sync(); - auto cat_bits = h_split_cats_.size() / param_.MaxNodes(); - auto cats_out = common::Span{h_split_cats_}.subspan(nidx * cat_bits, cat_bits); + auto cats_out = common::Span{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_); return cats_out; } /** diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index 6c081e1ba6df..777b017be24e 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -16,12 +16,12 @@ namespace xgboost { namespace tree { template void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, - common::Span ft, ObjInfo task, + common::Span ft, bst_feature_t n_features, TrainParam const ¶m, int32_t device) { param_ = param; tree_evaluator_ = TreeEvaluator{param, n_features, device}; - if (cuts.HasCategorical() && !task.UseOneHot()) { + if (cuts.HasCategorical()) { dh::XGBCachingDeviceAllocator alloc; auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); auto beg = thrust::make_counting_iterator(1ul); @@ -35,21 +35,22 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, auto idx = i - 1; if (common::IsCat(ft, idx)) { auto n_bins = ptrs[i] - ptrs[idx]; - bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); + bool use_sort = !common::UseOneHot(n_bins, to_onehot); return use_sort; } return false; }); - auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); - CHECK_NE(bit_storage_size, 0); + node_categorical_storage_size_ = + common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); + CHECK_NE(node_categorical_storage_size_, 0); // We need to allocate for all nodes since the updater can grow the tree layer by // layer, all nodes in the same layer must be preserved until that layer is // finished. We can allocate one layer at a time, but the best case is reducing the // size of the bitset by about a half, at the cost of invoking CUDA malloc many more // times than necessary. 
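Editor's note: the HostCatStorage/DeviceCatStorage changes earlier in this patch grow the category buffers on demand rather than up front. Before touching node nidx they guarantee slots nidx and nidx+1 exist (a node and its sibling are evaluated together), then hand back a slice starting at nidx * node_categorical_storage_size_. A self-contained sketch of that growth pattern, with std::vector standing in for the device vector:

#include <cstddef>
#include <cstdint>
#include <vector>

struct NodeCatStorage {
  std::size_t words_per_node;       // plays the role of node_categorical_storage_size_
  std::vector<std::uint32_t> bits;  // flat storage, node-major

  // Guarantee room for nodes [0, nidx + 1] and return a pointer to node nidx's slot.
  std::uint32_t* slice_for(std::size_t nidx) {
    std::size_t min_size = (nidx + 2) * words_per_node;
    if (bits.size() < min_size) bits.resize(min_size, 0);
    return bits.data() + nidx * words_per_node;
  }
};

A call like slice_for(5) covers nodes 5 and 6 in one contiguous region, matching how EvaluateSplits writes the left and right results side by side.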
- split_cats_.resize(param.MaxNodes() * bit_storage_size); - h_split_cats_.resize(split_cats_.size()); + split_cats_.resize(node_categorical_storage_size_); + h_split_cats_.resize(node_categorical_storage_size_); dh::safe_cuda( cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 4e445a0680e5..8a61ea809c04 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -244,7 +244,7 @@ template class HistEvaluator { } if (is_cat) { auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx]; - if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) { + if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) { EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); } else { diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 861b6e15b264..8ee6f43f78f5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -241,7 +241,7 @@ struct GPUHistMakerDevice { param.colsample_bytree); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); - this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param, + this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id); this->interaction_constraints.Reset(); @@ -571,7 +571,7 @@ struct GPUHistMakerDevice { split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); CHECK_LE(split_cats.size(), h_cats.size()); std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); - + tree.ExpandCategorical( candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(), diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index 0cbfc9f2a6cf..2243cb4dda90 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -24,14 +24,16 @@ void TestEvaluateSingleSplit(bool is_categorical) { TrainParam tparam = ZeroParam(); GPUTrainingParam param{tparam}; + common::HistogramCuts cuts; + cuts.cut_values_.HostVector() = std::vector{1.0, 2.0, 11.0, 12.0}; + cuts.cut_ptrs_.HostVector() = std::vector{0, 2, 4}; + cuts.min_vals_.HostVector() = std::vector{0.0, 0.0}; + cuts.cut_ptrs_.SetDevice(0); + cuts.cut_values_.SetDevice(0); + cuts.min_vals_.SetDevice(0); thrust::device_vector feature_set = std::vector{0, 1}; - thrust::device_vector feature_segments = - std::vector{0, 2, 4}; - thrust::device_vector feature_values = - std::vector{1.0, 2.0, 11.0, 12.0}; - thrust::device_vector feature_min_values = - std::vector{0.0, 0.0}; + // Setup gradients so that second feature gets higher gain thrust::device_vector feature_histogram = std::vector{ @@ -42,21 +44,25 @@ void TestEvaluateSingleSplit(bool is_categorical) { FeatureType::kCategorical); common::Span d_feature_types; if (is_categorical) { + auto max_cat = *std::max_element(cuts.cut_values_.HostVector().begin(), + cuts.cut_values_.HostVector().end()); + cuts.SetCategorical(true, max_cat); d_feature_types = dh::ToSpan(feature_types); } + EvaluateSplitInputs input{1, parent_sum, param, dh::ToSpan(feature_set), d_feature_types, - dh::ToSpan(feature_segments), - dh::ToSpan(feature_values), - dh::ToSpan(feature_min_values), + 
cuts.cut_ptrs_.ConstDeviceSpan(), + cuts.cut_values_.ConstDeviceSpan(), + cuts.min_vals_.ConstDeviceSpan(), dh::ToSpan(feature_histogram)}; GPUHistEvaluator evaluator{ - tparam, static_cast(feature_min_values.size()), 0}; - dh::device_vector out_cats; + tparam, static_cast(feature_set.size()), 0}; + evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0); DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; @@ -264,8 +270,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) { cuts_.cut_values_.SetDevice(0); cuts_.min_vals_.SetDevice(0); - ObjInfo task{ObjInfo::kRegression}; - evaluator.Reset(cuts_, dh::ToSpan(ft), task, info_.num_col_, param_, 0); + evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, 0); dh::device_vector d_hist(hist_[0].size()); auto node_hist = hist_[0]; diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index ea5556b38fca..2f3cc9c7d950 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -137,46 +137,6 @@ TEST(GpuHist, BuildHistSharedMem) { TestBuildHist(true); } -TEST(GpuHist, ApplySplit) { - RegTree tree; - GPUExpandEntry candidate; - candidate.nid = 0; - candidate.left_weight = 1.0f; - candidate.right_weight = 2.0f; - candidate.base_weight = 3.0f; - candidate.split.is_cat = true; - candidate.split.fvalue = 1.0f; // at cat 1 - - size_t n_rows = 10; - size_t n_cols = 10; - - auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true); - GenericParameter p; - p.InitAllowUnknown(Args{}); - - TrainParam tparam; - tparam.InitAllowUnknown(Args{}); - BatchParam bparam; - bparam.gpu_id = 0; - bparam.max_bin = 3; - Context ctx{CreateEmptyGenericParam(0)}; - - for (auto& ellpack : m->GetBatches(bparam)){ - auto impl = ellpack.Impl(); - HostDeviceVector feature_types(10, FeatureType::kCategorical); - feature_types.SetDevice(bparam.gpu_id); - tree::GPUHistMakerDevice updater( - &ctx, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, bparam); - updater.ApplySplit(candidate, &tree); - - ASSERT_EQ(tree.GetSplitTypes().size(), 3); - ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical); - ASSERT_EQ(tree.GetSplitCategories().size(), 1); - uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0 - ASSERT_EQ(tree.GetSplitCategories().back(), bits); - } -} - HistogramCutsWrapper GetHostCutMatrix () { HistogramCutsWrapper cmat; cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24}); From bc744585f7832381fd525090718286b7f93b6d09 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 2 May 2022 06:21:51 -0700 Subject: [PATCH 06/14] Skip copy if no categoricals --- src/tree/gpu_hist/evaluate_splits.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 2966b84e75af..5326b103d2d7 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -361,6 +361,7 @@ void GPUHistEvaluator::EvaluateSplits( template void GPUHistEvaluator::CopyToHost(EvaluateSplitInputs const &input, common::Span cats_out) { + if (cats_out.empty()) return; dh::CUDAEvent event; event.Record(dh::DefaultStream()); auto h_cats = this->HostCatStorage(input.nidx); From c4f8eac8996262d8447e73ca24b88569e34fc5c2 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 5 May 2022 04:35:32 -0700 Subject: [PATCH 07/14] Review comment --- .gitignore | 5 ++++- src/tree/gpu_hist/evaluator.cu | 5 ----- 2 files changed, 4 insertions(+), 6 deletions(-) diff 
--git a/.gitignore b/.gitignore index e847342b19bd..20b92c057e1a 100644 --- a/.gitignore +++ b/.gitignore @@ -130,4 +130,7 @@ credentials.csv # Visual Studio code + extensions .vscode .metals -.bloop \ No newline at end of file +.bloop + +# hypothesis python tests +.hypothesis \ No newline at end of file diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index 777b017be24e..aaf35243b2f5 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -44,11 +44,6 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, node_categorical_storage_size_ = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); CHECK_NE(node_categorical_storage_size_, 0); - // We need to allocate for all nodes since the updater can grow the tree layer by - // layer, all nodes in the same layer must be preserved until that layer is - // finished. We can allocate one layer at a time, but the best case is reducing the - // size of the bitset by about a half, at the cost of invoking CUDA malloc many more - // times than necessary. split_cats_.resize(node_categorical_storage_size_); h_split_cats_.resize(node_categorical_storage_size_); dh::safe_cuda( From a1cddaabbf93bb0be86bfc293dea5a84e233d719 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 5 May 2022 07:30:55 -0700 Subject: [PATCH 08/14] Revert "Categoricals broken" This reverts commit 80a3e78f9e1dcbf2a78f6572897453c61afd60b0. --- src/tree/gpu_hist/evaluate_splits.cuh | 2 +- src/tree/gpu_hist/evaluator.cu | 73 ++++--- src/tree/gpu_hist/histogram.cu | 9 + src/tree/updater_gpu_hist.cu | 223 +++++++++------------- tests/cpp/tree/gpu_hist/test_histogram.cu | 1 + tests/cpp/tree/test_gpu_hist.cu | 45 ++--- 6 files changed, 167 insertions(+), 186 deletions(-) diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index 7d792051e5be..b03fd7b41b51 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -92,7 +92,7 @@ class GPUHistEvaluator { } /** - * \brief Get sorted index storage based on the left node of inputs. + * \brief Get sorted index storage based on the left node of inputs . */ auto SortedIdx(EvaluateSplitInputs left) { if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) { diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index 381ef8fbb349..bc2027489131 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -21,36 +21,55 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, int32_t device) { param_ = param; tree_evaluator_ = TreeEvaluator{param, n_features, device}; - if (cuts.HasCategorical()) { + if (cuts.HasCategorical() && !task.UseOneHot()) { dh::XGBCachingDeviceAllocator alloc; - auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); - CHECK_NE(bit_storage_size, 0); - // We need to allocate for all nodes since the updater can grow the tree layer by - // layer, all nodes in the same layer must be preserved until that layer is - // finished. We can allocate one layer at a time, but the best case is reducing the - // size of the bitset by about a half, at the cost of invoking CUDA malloc many more - // times than necessary. 
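Editor's note: the Reset() restored by this revert only pays for the sort-based machinery when some categorical feature actually needs it, detected with a thrust::any_of scan over the cut pointers. The same predicate on the host, with hypothetical cut pointers and feature flags, looks like this:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
  // ptrs[f]..ptrs[f+1] delimit the bins of feature f (hypothetical values).
  std::vector<std::uint32_t> ptrs{0, 4, 6, 106};
  std::vector<bool> is_categorical{false, true, true};
  std::uint32_t max_cat_to_onehot = 16;

  std::vector<std::size_t> idx(is_categorical.size());
  for (std::size_t i = 0; i < idx.size(); ++i) idx[i] = i;

  bool need_sort = std::any_of(idx.begin(), idx.end(), [&](std::size_t f) {
    if (!is_categorical[f]) return false;
    std::uint32_t n_bins = ptrs[f + 1] - ptrs[f];
    return n_bins >= max_cat_to_onehot;  // too many categories for one-hot
  });
  return need_sort ? 0 : 1;  // feature 2 has 100 bins here, so sorting is needed
}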
- split_cats_.resize(param.MaxNodes() * bit_storage_size); - h_split_cats_.resize(split_cats_.size()); - dh::safe_cuda( - cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); + auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); + auto beg = thrust::make_counting_iterator(1ul); + auto end = thrust::make_counting_iterator(ptrs.size()); + auto to_onehot = param.max_cat_to_onehot; + // This condition avoids sort-based split function calls if the users want + // onehot-encoding-based splits. + // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x. + has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) { + auto idx = i - 1; + if (common::IsCat(ft, idx)) { + auto n_bins = ptrs[i] - ptrs[idx]; + bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); + return use_sort; + } + return false; + }); - cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. - sort_input_.resize(cat_sorted_idx_.size()); + if (has_sort_) { + auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); + CHECK_NE(bit_storage_size, 0); + // We need to allocate for all nodes since the updater can grow the tree layer by + // layer, all nodes in the same layer must be preserved until that layer is + // finished. We can allocate one layer at a time, but the best case is reducing the + // size of the bitset by about a half, at the cost of invoking CUDA malloc many more + // times than necessary. + split_cats_.resize(param.MaxNodes() * bit_storage_size); + h_split_cats_.resize(split_cats_.size()); + dh::safe_cuda( + cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); - /** - * cache feature index binary search result - */ - feature_idx_.resize(cat_sorted_idx_.size()); - auto d_fidxes = dh::ToSpan(feature_idx_); - auto it = thrust::make_counting_iterator(0ul); - auto values = cuts.cut_values_.ConstDeviceSpan(); - auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); - thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), feature_idx_.begin(), - [=] XGBOOST_DEVICE(size_t i) { - auto fidx = dh::SegmentId(ptrs, i); - return fidx; - }); + cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. 
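Editor's note: the feature_idx_ cache filled just below precomputes, for every global bin index, which feature owns it, so the per-bin sort comparators never have to search. The lookup itself is an upper_bound over the cut pointers; a host analogue of dh::SegmentId (names and values hypothetical):

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Return the feature whose bin range [ptrs[f], ptrs[f+1]) contains bin i.
std::size_t segment_id(const std::vector<std::uint32_t>& ptrs, std::uint32_t i) {
  auto it = std::upper_bound(ptrs.begin(), ptrs.end(), i);
  return static_cast<std::size_t>(it - ptrs.begin()) - 1;
}

int main() {
  std::vector<std::uint32_t> ptrs{0, 4, 6, 106};  // three features
  return segment_id(ptrs, 5) == 1 ? 0 : 1;        // bin 5 falls in [4, 6) -> feature 1
}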
+ sort_input_.resize(cat_sorted_idx_.size()); + + /** + * cache feature index binary search result + */ + feature_idx_.resize(cat_sorted_idx_.size()); + auto d_fidxes = dh::ToSpan(feature_idx_); + auto it = thrust::make_counting_iterator(0ul); + auto values = cuts.cut_values_.ConstDeviceSpan(); + auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); + thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(), + feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) { + auto fidx = dh::SegmentId(ptrs, i); + return fidx; + }); + } } } diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index efb08d5e44e2..791363a05cdd 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -247,6 +247,15 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, dh::safe_cuda(cudaGetLastError()); } +template void BuildGradientHistogram( + EllpackDeviceAccessor const& matrix, + FeatureGroupsAccessor const& feature_groups, + common::Span gpair, + common::Span ridx, + common::Span histogram, + HistRounding rounding, + bool force_global_memory); + template void BuildGradientHistogram( EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 2cd9d4babeb1..2340687983a8 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -57,7 +57,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); #endif // !defined(GTEST_TEST) /** - * \struct DeviceHistogramStorage + * \struct DeviceHistogram * * \summary Data storage for node histograms on device. Automatically expands. * @@ -67,18 +67,12 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); * \author Rory * \date 28/07/2018 */ -template -class DeviceHistogramStorage { +template +class DeviceHistogram { private: /*! \brief Map nidx to starting index of its histogram. */ std::map nidx_map_; - // Large buffer of zeroed memory, caches histograms dh::device_vector data_; - // If we run out of storage allocate one histogram at a time - // in overflow. 
Not cached, overwritten when a new histogram - // is requested - dh::device_vector overflow_; - std::map overflow_nidx_map_; int n_bins_; int device_id_; static constexpr size_t kNumItemsInGradientSum = @@ -87,8 +81,6 @@ class DeviceHistogramStorage { "Number of items in gradient type should be 2."); public: - // Start with about 16mb - DeviceHistogramStorage() { data_.reserve(1 << 22); } void Init(int device_id, int n_bins) { this->n_bins_ = n_bins; this->device_id_ = device_id; @@ -99,53 +91,21 @@ class DeviceHistogramStorage { dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); nidx_map_.clear(); - overflow_nidx_map_.clear(); } bool HistogramExists(int nidx) const { - return nidx_map_.find(nidx) != nidx_map_.cend() || overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); + return nidx_map_.find(nidx) != nidx_map_.cend(); } int Bins() const { return n_bins_; } - size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } - dh::device_vector& Data() { return data_; } - - void AllocateHistograms(const std::vector& new_nidxs) { - for (int nidx : new_nidxs) { - CHECK(!HistogramExists(nidx)); - } - // Number of items currently used in data - const size_t used_size = nidx_map_.size() * HistogramSize(); - const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size(); - if (used_size >= kStopGrowingSize) { - // Use overflow - // Delete previous entries - overflow_nidx_map_.clear(); - overflow_.resize(HistogramSize() * new_nidxs.size()); - // Zero memory - auto d_data = overflow_.data().get(); - dh::LaunchN(overflow_.size(), - [=] __device__(size_t idx) { d_data[idx] = 0.0; }); - // Append new histograms - for (int nidx : new_nidxs) { - overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize(); - } - } else { - CHECK_GE(data_.size(), used_size); - // Expand if necessary - if (data_.size() < new_used_size) { - data_.resize(std::max(data_.size() * 2, new_used_size)); - } - // Append new histograms - for (int nidx : new_nidxs) { - nidx_map_[nidx] = nidx_map_.size() * HistogramSize(); - } - } + size_t HistogramSize() const { + return n_bins_ * kNumItemsInGradientSum; + } - CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize()); + dh::device_vector& Data() { + return data_; } - /* void AllocateHistogram(int nidx) { if (HistogramExists(nidx)) return; // Number of items currently used in data @@ -179,7 +139,6 @@ class DeviceHistogramStorage { CHECK_GE(data_.size(), nidx_map_.size() * HistogramSize()); } - */ /** * \summary Return pointer to histogram memory for a given node. 
@@ -188,16 +147,9 @@ class DeviceHistogramStorage { */ common::Span GetNodeHistogram(int nidx) { CHECK(this->HistogramExists(nidx)); - - if (nidx_map_.find(nidx) != nidx_map_.cend()) { - // Fetch from normal cache - auto ptr = data_.data().get() + nidx_map_.at(nidx); - return common::Span(reinterpret_cast(ptr), n_bins_); - } else { - // Fetch from overflow - auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx); - return common::Span(reinterpret_cast(ptr), n_bins_); - } + auto ptr = data_.data().get() + nidx_map_.at(nidx); + return common::Span( + reinterpret_cast(ptr), n_bins_); } }; @@ -214,7 +166,7 @@ struct GPUHistMakerDevice { BatchParam batch_param; std::unique_ptr row_partitioner; - DeviceHistogramStorage hist{}; + DeviceHistogram hist{}; dh::caching_device_vector d_gpair; // storage for gpair; common::Span gpair; @@ -237,6 +189,8 @@ struct GPUHistMakerDevice { std::unique_ptr sampler; std::unique_ptr feature_groups; + // Storing split categories for last node. + dh::caching_device_vector node_categories; GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page, common::Span _feature_types, bst_uint _n_rows, @@ -365,6 +319,7 @@ struct GPUHistMakerDevice { } void BuildHist(int nidx) { + hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); BuildGradientHistogram(page->GetDeviceAccessor(device_id), @@ -372,12 +327,8 @@ struct GPUHistMakerDevice { d_ridx, d_node_hist, histogram_rounding); } - // Attempt to the subtraction trick - // return true if succeeded - bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { - if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) { - return false; - } + void SubtractionTrick(int nidx_parent, int nidx_histogram, + int nidx_subtraction) { auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent); auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram); auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction); @@ -386,18 +337,22 @@ struct GPUHistMakerDevice { d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); - return true; } - void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) { - RegTree::Node split_node = (*p_tree)[e.nid]; - auto split_type = p_tree->NodeSplitType(e.nid); + bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { + // Make sure histograms are already allocated + hist.AllocateHistogram(nidx_subtraction); + return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent); + } + + void UpdatePosition(int nidx, RegTree* p_tree) { + RegTree::Node split_node = (*p_tree)[nidx]; + auto split_type = p_tree->NodeSplitType(nidx); auto d_matrix = page->GetDeviceAccessor(device_id); - auto node_cats = e.split.split_cats.Bits(); - + auto node_cats = dh::ToSpan(node_categories); row_partitioner->UpdatePosition( - e.nid, split_node.LeftChild(), split_node.RightChild(), + nidx, split_node.LeftChild(), split_node.RightChild(), [=] __device__(bst_uint ridx) { // given a row index, returns the node id it belongs to bst_float cut_value = @@ -528,15 +483,13 @@ struct GPUHistMakerDevice { row_partitioner.reset(); } - // num histograms is the number of contiguous histograms in memory to reduce over - void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) { + void AllReduceHist(int nidx, dh::AllReducer* reducer) { monitor.Start("AllReduce"); auto d_node_hist = 
hist.GetNodeHistogram(nidx).data(); - reducer->AllReduceSum(reinterpret_cast(d_node_hist), - reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * - (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) * - num_histograms); + reducer->AllReduceSum( + reinterpret_cast(d_node_hist), + reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT))); monitor.Stop("AllReduce"); } @@ -544,49 +497,33 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(std::vectorconst &candidates, dh::AllReducer* reducer, const RegTree& tree) { - if(candidates.empty()) return; - // Some nodes we will manually compute histograms - // others we will do by subtraction - std::vector hist_nidx; - std::vector subtraction_nidx; - for (auto& e : candidates) { - // Decide whether to build the left histogram or right histogram - // Use sum of Hessian as a heuristic to select node with fewest training instances - bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess(); - if (fewer_right) { - hist_nidx.emplace_back(tree[e.nid].RightChild()); - subtraction_nidx.emplace_back(tree[e.nid].LeftChild()); - } else { - hist_nidx.emplace_back(tree[e.nid].LeftChild()); - subtraction_nidx.emplace_back(tree[e.nid].RightChild()); - } - } - std::vector all_new = hist_nidx; - all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end()); - // Allocate the histograms - // Guaranteed contiguous memory - hist.AllocateHistograms(all_new); - - for(auto nidx:hist_nidx){ - this->BuildHist(nidx); + void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left, + int nidx_right, dh::AllReducer* reducer) { + auto build_hist_nidx = nidx_left; + auto subtraction_trick_nidx = nidx_right; + + // Decide whether to build the left histogram or right histogram + // Use sum of Hessian as a heuristic to select node with fewest training instances + bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess(); + if (fewer_right) { + std::swap(build_hist_nidx, subtraction_trick_nidx); } - // Reduce all in one go - // This gives much better latency in a distributed setting - // when processing a large batch - this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size()); + this->BuildHist(build_hist_nidx); + this->AllReduceHist(build_hist_nidx, reducer); - for (int i = 0; i < subtraction_nidx.size(); i++) { - auto build_hist_nidx = hist_nidx.at(i); - auto subtraction_trick_nidx = subtraction_nidx.at(i); - auto parent_nidx = candidates.at(i).nid; + // Check whether we can use the subtraction trick to calculate the other + bool do_subtraction_trick = this->CanDoSubtractionTrick( + candidate.nid, build_hist_nidx, subtraction_trick_nidx); - if(!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)){ - // Calculate other histogram manually - this->BuildHist(subtraction_trick_nidx); - this->AllReduceHist(subtraction_trick_nidx, reducer, 1); - } + if (do_subtraction_trick) { + // Calculate other histogram using subtraction trick + this->SubtractionTrick(candidate.nid, build_hist_nidx, + subtraction_trick_nidx); + } else { + // Calculate other histogram manually + this->BuildHist(subtraction_trick_nidx); + this->AllReduceHist(subtraction_trick_nidx, reducer); } } @@ -609,11 +546,27 @@ struct GPUHistMakerDevice { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value 
too large."; std::vector split_cats; - auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); - auto max_cat = candidate.split.MaxCat(); - split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); - CHECK_LE(split_cats.size(), h_cats.size()); - std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); + if (candidate.split.split_cats.Bits().empty()) { + if (common::InvalidCat(candidate.split.fvalue)) { + common::InvalidCategory(); + } + auto cat = common::AsCat(candidate.split.fvalue); + split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0); + common::CatBitField cats_bits(split_cats); + cats_bits.Set(cat); + dh::CopyToD(split_cats, &node_categories); + } else { + auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); + auto max_cat = candidate.split.MaxCat(); + split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); + CHECK_LE(split_cats.size(), h_cats.size()); + std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data()); + + node_categories.resize(candidate.split.split_cats.Bits().size()); + dh::safe_cuda(cudaMemcpyAsync( + node_categories.data().get(), candidate.split.split_cats.Data(), + candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice)); + } tree.ExpandCategorical( candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir, @@ -645,9 +598,8 @@ struct GPUHistMakerDevice { GradientPairPrecise{}, thrust::plus{}); rabit::Allreduce(reinterpret_cast(&root_sum), 2); - hist.AllocateHistograms({kRootNIdx}); this->BuildHist(kRootNIdx); - this->AllReduceHist(kRootNIdx, reducer, 1); + this->AllReduceHist(kRootNIdx, reducer); // Remember root stats node_sum_gradients[kRootNIdx] = root_sum; @@ -686,7 +638,6 @@ struct GPUHistMakerDevice { std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set), [&](const auto& e) { return driver.IsChildValid(e); }); - auto new_candidates = pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); @@ -695,16 +646,22 @@ struct GPUHistMakerDevice { // Update position is only run when child is valid, instead of right after apply // split (as in approx tree method). Hense we have the finalise position call // in GPU Hist. 
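Editor's note: the subtraction trick used by BuildHistLeftRight in both the old and refactored code paths avoids building one of the two child histograms: every row of the parent lands in exactly one child, so the missing child is just the elementwise difference. A host-side sketch with a toy gradient-pair type:

#include <cstddef>
#include <vector>

struct GradPair { double grad{0.0}, hess{0.0}; };

// sibling = parent - built_child, bin by bin.
std::vector<GradPair> subtract_hist(const std::vector<GradPair>& parent,
                                    const std::vector<GradPair>& built_child) {
  std::vector<GradPair> sibling(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    sibling[i].grad = parent[i].grad - built_child[i].grad;
    sibling[i].hess = parent[i].hess - built_child[i].hess;
  }
  return sibling;
}

Building the child with fewer rows (judged by the smaller Hessian sum) and deriving its sibling this way is why the code compares right_sum.GetHess() against left_sum.GetHess() before choosing which histogram to build.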
- this->UpdatePosition(e, p_tree); + this->UpdatePosition(e.nid, p_tree); monitor.Stop("UpdatePosition"); } - monitor.Start("BuildHist"); - this->BuildHistLeftRight(filtered_expand_set, reducer, tree); - monitor.Stop("BuildHist"); + for (auto i = 0ull; i < filtered_expand_set.size(); i++) { + auto candidate = expand_set.at(i); + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + + monitor.Start("BuildHist"); + this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); + monitor.Stop("BuildHist"); + } for (auto i = 0ull; i < filtered_expand_set.size(); i++) { - auto candidate = filtered_expand_set.at(i); + auto candidate = expand_set.at(i); int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 75d97b681a61..3b543a48d7cc 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -95,6 +95,7 @@ TEST(Histogram, GPUDeterministic) { std::vector shm_sizes{48 * 1024, 64 * 1024, 160 * 1024}; for (bool is_dense : is_dense_array) { for (int shm_size : shm_sizes) { + TestDeterministicHistogram(is_dense, shm_size); TestDeterministicHistogram(is_dense, shm_size); } } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index bdabbbcb38c2..883537863307 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -29,38 +29,29 @@ TEST(GpuHist, DeviceHistogram) { constexpr size_t kNBins = 128; constexpr size_t kNNodes = 4; constexpr size_t kStopGrowing = kNNodes * kNBins * 2u; - DeviceHistogramStorage histogram; + DeviceHistogram histogram; histogram.Init(0, kNBins); - for (int i = 0; i < kNNodes; ++i) { - histogram.AllocateHistograms({i}); + for (size_t i = 0; i < kNNodes; ++i) { + histogram.AllocateHistogram(i); } histogram.Reset(); ASSERT_EQ(histogram.Data().size(), kStopGrowing); // Use allocated memory but do not erase nidx_map. - for (int i = 0; i < kNNodes; ++i) { - histogram.AllocateHistograms({i}); + for (size_t i = 0; i < kNNodes; ++i) { + histogram.AllocateHistogram(i); } - for (int i = 0; i < kNNodes; ++i) { + for (size_t i = 0; i < kNNodes; ++i) { ASSERT_TRUE(histogram.HistogramExists(i)); } - // Add two new nodes - histogram.AllocateHistograms({kNNodes}); - histogram.AllocateHistograms({kNNodes+1}); - - // Old cached nodes should still exist - for (int i = 0; i < kNNodes; ++i) { - ASSERT_TRUE(histogram.HistogramExists(i)); + // Erase existing nidx_map. 
+ for (size_t i = kNNodes; i < kNNodes * 2; ++i) { + histogram.AllocateHistogram(i); + } + for (size_t i = 0; i < kNNodes; ++i) { + ASSERT_FALSE(histogram.HistogramExists(i)); } - - // Should be deleted - ASSERT_FALSE(histogram.HistogramExists({kNNodes})); - // Most recent node should exist - ASSERT_TRUE(histogram.HistogramExists({kNNodes + 1})); - - // Add same node again - should fail - EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes+1});); } std::vector GetHostHistGpair() { @@ -104,9 +95,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector()); maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); - maker.hist.AllocateHistograms({0}); + maker.hist.AllocateHistogram(0); maker.gpair = gpair.DeviceSpan(); - maker.histogram_rounding = CreateRoundingFactor(maker.gpair); + maker.histogram_rounding = CreateRoundingFactor(maker.gpair);; BuildGradientHistogram( page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), @@ -114,7 +105,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { maker.hist.GetNodeHistogram(0), maker.histogram_rounding, !use_shared_memory_histograms); - DeviceHistogramStorage& d_hist = maker.hist; + DeviceHistogram& d_hist = maker.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair @@ -137,10 +128,12 @@ void TestBuildHist(bool use_shared_memory_histograms) { TEST(GpuHist, BuildHistGlobalMem) { TestBuildHist(false); + TestBuildHist(false); } TEST(GpuHist, BuildHistSharedMem) { TestBuildHist(true); + TestBuildHist(true); } TEST(GpuHist, ApplySplit) { @@ -180,6 +173,8 @@ TEST(GpuHist, ApplySplit) { ASSERT_EQ(tree.GetSplitCategories().size(), 1); uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0 ASSERT_EQ(tree.GetSplitCategories().back(), bits); + + ASSERT_EQ(updater.node_categories.size(), 1); } } @@ -243,7 +238,7 @@ TEST(GpuHist, EvaluateRootSplit) { // Initialize GPUHistMakerDevice::hist maker.hist.Init(0, (max_bins - 1) * kNCols); - maker.hist.AllocateHistograms({0}); + maker.hist.AllocateHistogram(0); // Each row of hist_gpair represents gpairs for one feature. // Each entry represents a bin. std::vector hist_gpair = GetHostHistGpair(); From fd0e25e0bd2cf05f33766c7b1deb1471126f9447 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 5 May 2022 08:49:40 -0700 Subject: [PATCH 09/14] Lint --- src/tree/driver.h | 4 ++-- src/tree/updater_approx.cc | 8 +------- src/tree/updater_gpu_hist.cu | 4 ++-- src/tree/updater_quantile_hist.cc | 7 +------ 4 files changed, 6 insertions(+), 17 deletions(-) diff --git a/src/tree/driver.h b/src/tree/driver.h index 1e40cc32622f..e61255e043c7 100644 --- a/src/tree/driver.h +++ b/src/tree/driver.h @@ -57,7 +57,7 @@ class Driver { // Can a child of this entry still be expanded? 
// can be used to avoid extra work - bool IsChildValid(ExpandEntryT const& parent_entry){ + bool IsChildValid(ExpandEntryT const& parent_entry) { if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false; if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false; return true; @@ -100,7 +100,7 @@ class Driver { private: TrainParam param_; - std::size_t num_leaves_=1; + std::size_t num_leaves_ = 1; ExpandQueue queue_; }; } // namespace tree diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index fc05aed0a3ee..99e7cf738200 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -184,7 +184,6 @@ class GloablApproxBuilder { Driver driver(param_); auto &tree = *p_tree; driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)}); - bst_node_t num_leaves{1}; auto expand_set = driver.Pop(); /** @@ -203,14 +202,9 @@ class GloablApproxBuilder { // candidates that can be applied. std::vector applied; for (auto const &candidate : expand_set) { - if (!candidate.IsValid(param_, num_leaves)) { - continue; - } evaluator_.ApplyTreeSplit(candidate, p_tree); applied.push_back(candidate); - num_leaves++; - int left_child_nidx = tree[candidate.nid].LeftChild(); - if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) { + if (driver.IsChildValid(candidate)) { valid_candidates.emplace_back(candidate); } } diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 07f1499e213f..634f2969a090 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -637,7 +637,7 @@ struct GPUHistMakerDevice { // The set of leaves that can be expanded asynchronously auto expand_set = driver.Pop(); while (!expand_set.empty()) { - for(auto & candidate: expand_set){ + for (auto& candidate : expand_set) { this->ApplySplit(candidate, p_tree); } // Get the candidates we are allowed to expand further @@ -649,7 +649,7 @@ struct GPUHistMakerDevice { auto new_candidates = pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); - for(const auto &e:filtered_expand_set){ + for (const auto& e : filtered_expand_set) { monitor.Start("UpdatePosition"); // Update position is only run when child is valid, instead of right after apply // split (as in approx tree method). Hense we have the finalise position call diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index c69f8c8dba0b..ed3dff67295a 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -196,7 +196,6 @@ void QuantileHistMaker::Builder::ExpandTree( Driver driver(param_); driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h)); auto const &tree = *p_tree; - bst_node_t num_leaves{1}; auto expand_set = driver.Pop(); while (!expand_set.empty()) { @@ -206,13 +205,9 @@ void QuantileHistMaker::Builder::ExpandTree( std::vector applied; int32_t depth = expand_set.front().depth + 1; for (auto const& candidate : expand_set) { - if (!candidate.IsValid(param_, num_leaves)) { - continue; - } evaluator_->ApplyTreeSplit(candidate, p_tree); applied.push_back(candidate); - num_leaves++; - if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) { + if (driver.IsChildValid(candidate)) { valid_candidates.emplace_back(candidate); } } From 56785f3168c26a248572e6edd0f6c8b8c2885bde Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Fri, 6 May 2022 05:05:15 -0700 Subject: [PATCH 10/14] Revert "Revert "Categoricals broken"" This reverts commit a1cddaabbf93bb0be86bfc293dea5a84e233d719. 
--- src/tree/gpu_hist/evaluate_splits.cuh | 2 +- src/tree/gpu_hist/histogram.cu | 9 - src/tree/updater_gpu_hist.cu | 202 ++++++++++++---------- tests/cpp/tree/gpu_hist/test_histogram.cu | 1 - tests/cpp/tree/test_gpu_hist.cu | 43 +++-- tests/python-gpu/test_gpu_updaters.py | 2 +- 6 files changed, 140 insertions(+), 119 deletions(-) diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index 8d5cc809a280..08b0270ee4d7 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -103,7 +103,7 @@ class GPUHistEvaluator { } /** - * \brief Get sorted index storage based on the left node of inputs . + * \brief Get sorted index storage based on the left node of inputs. */ auto SortedIdx(EvaluateSplitInputs left) { if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) { diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 791363a05cdd..efb08d5e44e2 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -247,15 +247,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, dh::safe_cuda(cudaGetLastError()); } -template void BuildGradientHistogram( - EllpackDeviceAccessor const& matrix, - FeatureGroupsAccessor const& feature_groups, - common::Span gpair, - common::Span ridx, - common::Span histogram, - HistRounding rounding, - bool force_global_memory); - template void BuildGradientHistogram( EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 634f2969a090..964a486baf16 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -62,7 +62,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); #endif // !defined(GTEST_TEST) /** - * \struct DeviceHistogram + * \struct DeviceHistogramStorage * * \summary Data storage for node histograms on device. Automatically expands. * @@ -72,12 +72,18 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam); * \author Rory * \date 28/07/2018 */ -template -class DeviceHistogram { +template +class DeviceHistogramStorage { private: /*! \brief Map nidx to starting index of its histogram. */ std::map nidx_map_; + // Large buffer of zeroed memory, caches histograms dh::device_vector data_; + // If we run out of storage allocate one histogram at a time + // in overflow. 
Not cached, overwritten when a new histogram + // is requested + dh::device_vector overflow_; + std::map overflow_nidx_map_; int n_bins_; int device_id_; static constexpr size_t kNumItemsInGradientSum = @@ -86,6 +92,8 @@ class DeviceHistogram { "Number of items in gradient type should be 2."); public: + // Start with about 16mb + DeviceHistogramStorage() { data_.reserve(1 << 22); } void Init(int device_id, int n_bins) { this->n_bins_ = n_bins; this->device_id_ = device_id; @@ -93,52 +101,48 @@ class DeviceHistogram { void Reset() { auto d_data = data_.data().get(); - dh::LaunchN(data_.size(), - [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); + dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); nidx_map_.clear(); + overflow_nidx_map_.clear(); } bool HistogramExists(int nidx) const { - return nidx_map_.find(nidx) != nidx_map_.cend(); + return nidx_map_.find(nidx) != nidx_map_.cend() || overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); } int Bins() const { return n_bins_; } - size_t HistogramSize() const { - return n_bins_ * kNumItemsInGradientSum; - } - - dh::device_vector& Data() { - return data_; - } + size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } + dh::device_vector& Data() { return data_; } - void AllocateHistogram(int nidx) { - if (HistogramExists(nidx)) return; + void AllocateHistograms(const std::vector& new_nidxs) { + for (int nidx : new_nidxs) { + CHECK(!HistogramExists(nidx)); + } // Number of items currently used in data const size_t used_size = nidx_map_.size() * HistogramSize(); - const size_t new_used_size = used_size + HistogramSize(); - if (data_.size() >= kStopGrowingSize) { - // Recycle histogram memory - if (new_used_size <= data_.size()) { - // no need to remove old node, just insert the new one. 
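Editor's note: with the overflow design re-applied by this patch, a node's histogram lives in one of two places — the zero-initialised cache that persists across layers, or the small overflow buffer that is rewritten each time a new batch is allocated after the cache reaches kStopGrowingSize. The lookup GetNodeHistogram performs further down in this hunk amounts to checking two maps; a compact standalone model:

#include <cstddef>
#include <map>
#include <vector>

struct HistStorage {
  std::size_t hist_size{0};                            // floats per node histogram
  std::vector<float> cache, overflow;                  // persistent vs transient storage
  std::map<int, std::size_t> cache_map, overflow_map;  // node id -> offset into the buffer

  float* node_histogram(int nidx) {
    auto it = cache_map.find(nidx);
    if (it != cache_map.end()) return cache.data() + it->second;  // cached node
    return overflow.data() + overflow_map.at(nidx);  // throws if never allocated
  }
};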
- nidx_map_[nidx] = used_size; - // memset histogram size in bytes - } else { - std::pair old_entry = *nidx_map_.begin(); - nidx_map_.erase(old_entry.first); - nidx_map_[nidx] = old_entry.second; + const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size(); + if (used_size >= kStopGrowingSize) { + // Use overflow + // Delete previous entries + overflow_nidx_map_.clear(); + overflow_.resize(HistogramSize() * new_nidxs.size()); + // Zero memory + auto d_data = overflow_.data().get(); + dh::LaunchN(overflow_.size(), + [=] __device__(size_t idx) { d_data[idx] = 0.0; }); + // Append new histograms + for (int nidx : new_nidxs) { + overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize(); } - // Zero recycled memory - auto d_data = data_.data().get() + nidx_map_[nidx]; - dh::LaunchN(n_bins_ * 2, - [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); } else { - // Append new node histogram - nidx_map_[nidx] = used_size; - // Check there is enough memory for another histogram node - if (data_.size() < new_used_size + HistogramSize()) { - size_t new_required_memory = - std::max(data_.size() * 2, HistogramSize()); - data_.resize(new_required_memory); + CHECK_GE(data_.size(), used_size); + // Expand if necessary + if (data_.size() < new_used_size) { + data_.resize(std::max(data_.size() * 2, new_used_size)); + } + // Append new histograms + for (int nidx : new_nidxs) { + nidx_map_[nidx] = nidx_map_.size() * HistogramSize(); } } @@ -152,9 +156,16 @@ class DeviceHistogram { */ common::Span GetNodeHistogram(int nidx) { CHECK(this->HistogramExists(nidx)); - auto ptr = data_.data().get() + nidx_map_.at(nidx); - return common::Span( - reinterpret_cast(ptr), n_bins_); + + if (nidx_map_.find(nidx) != nidx_map_.cend()) { + // Fetch from normal cache + auto ptr = data_.data().get() + nidx_map_.at(nidx); + return common::Span(reinterpret_cast(ptr), n_bins_); + } else { + // Fetch from overflow + auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx); + return common::Span(reinterpret_cast(ptr), n_bins_); + } } }; @@ -171,7 +182,7 @@ struct GPUHistMakerDevice { BatchParam batch_param; std::unique_ptr row_partitioner; - DeviceHistogram hist{}; + DeviceHistogramStorage hist{}; dh::caching_device_vector d_gpair; // storage for gpair; common::Span gpair; @@ -322,7 +333,6 @@ struct GPUHistMakerDevice { } void BuildHist(int nidx) { - hist.AllocateHistogram(nidx); auto d_node_hist = hist.GetNodeHistogram(nidx); auto d_ridx = row_partitioner->GetRows(nidx); BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id), @@ -330,8 +340,12 @@ struct GPUHistMakerDevice { d_ridx, d_node_hist, histogram_rounding); } - void SubtractionTrick(int nidx_parent, int nidx_histogram, - int nidx_subtraction) { + // Attempt to do subtraction trick + // return true if succeeded + bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { + if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) { + return false; + } auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent); auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram); auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction); @@ -340,12 +354,7 @@ struct GPUHistMakerDevice { d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); - } - - bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { - // Make sure histograms are already allocated - hist.AllocateHistogram(nidx_subtraction); - return 
hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent); + return true; } void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) { @@ -505,13 +514,15 @@ struct GPUHistMakerDevice { row_partitioner.reset(); } - void AllReduceHist(int nidx, dh::AllReducer* reducer) { + // num histograms is the number of contiguous histograms in memory to reduce over + void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) { monitor.Start("AllReduce"); auto d_node_hist = hist.GetNodeHistogram(nidx).data(); - reducer->AllReduceSum( - reinterpret_cast(d_node_hist), - reinterpret_cast(d_node_hist), - page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT))); + reducer->AllReduceSum(reinterpret_cast(d_node_hist), + reinterpret_cast(d_node_hist), + page->Cuts().TotalBins() * + (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) * + num_histograms); monitor.Stop("AllReduce"); } @@ -519,33 +530,50 @@ struct GPUHistMakerDevice { /** * \brief Build GPU local histograms for the left and right child of some parent node */ - void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left, - int nidx_right, dh::AllReducer* reducer) { - auto build_hist_nidx = nidx_left; - auto subtraction_trick_nidx = nidx_right; - - // Decide whether to build the left histogram or right histogram - // Use sum of Hessian as a heuristic to select node with fewest training instances - bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess(); - if (fewer_right) { - std::swap(build_hist_nidx, subtraction_trick_nidx); + void BuildHistLeftRight(std::vector const& candidates, dh::AllReducer* reducer, + const RegTree& tree) { + if (candidates.empty()) return; + // Some nodes we will manually compute histograms + // others we will do by subtraction + std::vector hist_nidx; + std::vector subtraction_nidx; + for (auto& e : candidates) { + // Decide whether to build the left histogram or right histogram + // Use sum of Hessian as a heuristic to select node with fewest training instances + bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess(); + if (fewer_right) { + hist_nidx.emplace_back(tree[e.nid].RightChild()); + subtraction_nidx.emplace_back(tree[e.nid].LeftChild()); + } else { + hist_nidx.emplace_back(tree[e.nid].LeftChild()); + subtraction_nidx.emplace_back(tree[e.nid].RightChild()); + } + } + std::vector all_new = hist_nidx; + all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end()); + // Allocate the histograms + // Guaranteed contiguous memory + hist.AllocateHistograms(all_new); + + for (auto nidx : hist_nidx) { + this->BuildHist(nidx); } - this->BuildHist(build_hist_nidx); - this->AllReduceHist(build_hist_nidx, reducer); + // Reduce all in one go + // This gives much better latency in a distributed setting + // when processing a large batch + this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size()); - // Check whether we can use the subtraction trick to calculate the other - bool do_subtraction_trick = this->CanDoSubtractionTrick( - candidate.nid, build_hist_nidx, subtraction_trick_nidx); + for (int i = 0; i < subtraction_nidx.size(); i++) { + auto build_hist_nidx = hist_nidx.at(i); + auto subtraction_trick_nidx = subtraction_nidx.at(i); + auto parent_nidx = candidates.at(i).nid; - if (do_subtraction_trick) { - // Calculate other histogram using subtraction trick - this->SubtractionTrick(candidate.nid, build_hist_nidx, - subtraction_trick_nidx); - } else { - // 
Calculate other histogram manually - this->BuildHist(subtraction_trick_nidx); - this->AllReduceHist(subtraction_trick_nidx, reducer); + if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) { + // Calculate other histogram manually + this->BuildHist(subtraction_trick_nidx); + this->AllReduceHist(subtraction_trick_nidx, reducer, 1); + } } } @@ -605,8 +633,9 @@ struct GPUHistMakerDevice { GradientPairPrecise{}, thrust::plus{}); rabit::Allreduce(reinterpret_cast(&root_sum), 2); + hist.AllocateHistograms({kRootNIdx}); this->BuildHist(kRootNIdx); - this->AllReduceHist(kRootNIdx, reducer); + this->AllReduceHist(kRootNIdx, reducer, 1); // Remember root stats node_sum_gradients[kRootNIdx] = root_sum; @@ -646,6 +675,7 @@ struct GPUHistMakerDevice { std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set), [&](const auto& e) { return driver.IsChildValid(e); }); + auto new_candidates = pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); @@ -658,18 +688,12 @@ struct GPUHistMakerDevice { monitor.Stop("UpdatePosition"); } - for (auto i = 0ull; i < filtered_expand_set.size(); i++) { - auto candidate = expand_set.at(i); - int left_child_nidx = tree[candidate.nid].LeftChild(); - int right_child_nidx = tree[candidate.nid].RightChild(); - - monitor.Start("BuildHist"); - this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer); - monitor.Stop("BuildHist"); - } + monitor.Start("BuildHist"); + this->BuildHistLeftRight(filtered_expand_set, reducer, tree); + monitor.Stop("BuildHist"); for (auto i = 0ull; i < filtered_expand_set.size(); i++) { - auto candidate = expand_set.at(i); + auto candidate = filtered_expand_set.at(i); int left_child_nidx = tree[candidate.nid].LeftChild(); int right_child_nidx = tree[candidate.nid].RightChild(); diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 3b543a48d7cc..75d97b681a61 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -95,7 +95,6 @@ TEST(Histogram, GPUDeterministic) { std::vector shm_sizes{48 * 1024, 64 * 1024, 160 * 1024}; for (bool is_dense : is_dense_array) { for (int shm_size : shm_sizes) { - TestDeterministicHistogram(is_dense, shm_size); TestDeterministicHistogram(is_dense, shm_size); } } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index b3c08736c996..be51d3cc5e31 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -29,29 +29,38 @@ TEST(GpuHist, DeviceHistogram) { constexpr size_t kNBins = 128; constexpr size_t kNNodes = 4; constexpr size_t kStopGrowing = kNNodes * kNBins * 2u; - DeviceHistogram histogram; + DeviceHistogramStorage histogram; histogram.Init(0, kNBins); - for (size_t i = 0; i < kNNodes; ++i) { - histogram.AllocateHistogram(i); + for (int i = 0; i < kNNodes; ++i) { + histogram.AllocateHistograms({i}); } histogram.Reset(); ASSERT_EQ(histogram.Data().size(), kStopGrowing); // Use allocated memory but do not erase nidx_map. - for (size_t i = 0; i < kNNodes; ++i) { - histogram.AllocateHistogram(i); + for (int i = 0; i < kNNodes; ++i) { + histogram.AllocateHistograms({i}); } - for (size_t i = 0; i < kNNodes; ++i) { + for (int i = 0; i < kNNodes; ++i) { ASSERT_TRUE(histogram.HistogramExists(i)); } - // Erase existing nidx_map. 
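The batched BuildHistLeftRight above replaces the old one-candidate-at-a-time flow: for each candidate it builds the histogram of the child with the smaller Hessian sum (a proxy for fewer training rows) and derives the sibling by subtracting that histogram from the parent's, falling back to a direct build only if the parent or the freshly built histogram is no longer resident. A minimal host-side sketch of the selection and of the subtraction itself; the Candidate struct and the function names are illustrative stand-ins for GPUExpandEntry and the device kernels.

#include <cstddef>
#include <utility>
#include <vector>

struct Candidate {
  int parent, left, right;
  double left_hess, right_hess;
};

// Returns {nodes to build directly, nodes to derive by subtraction}, index-aligned
// so that subtract[i] is the sibling of build[i] under candidate i's parent.
std::pair<std::vector<int>, std::vector<int>> PlanBatch(const std::vector<Candidate>& cs) {
  std::vector<int> build, subtract;
  for (const auto& c : cs) {
    bool fewer_right = c.right_hess < c.left_hess;  // fewer instances, cheaper to build
    build.push_back(fewer_right ? c.right : c.left);
    subtract.push_back(fewer_right ? c.left : c.right);
  }
  return {build, subtract};
}

// The subtraction trick: sibling = parent - built, bin by bin.
void SubtractSibling(const std::vector<double>& parent, const std::vector<double>& built,
                     std::vector<double>* sibling) {
  sibling->resize(parent.size());
  for (std::size_t i = 0; i < parent.size(); ++i) {
    (*sibling)[i] = parent[i] - built[i];
  }
}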
- for (size_t i = kNNodes; i < kNNodes * 2; ++i) { - histogram.AllocateHistogram(i); - } - for (size_t i = 0; i < kNNodes; ++i) { - ASSERT_FALSE(histogram.HistogramExists(i)); + // Add two new nodes + histogram.AllocateHistograms({kNNodes}); + histogram.AllocateHistograms({kNNodes+1}); + + // Old cached nodes should still exist + for (int i = 0; i < kNNodes; ++i) { + ASSERT_TRUE(histogram.HistogramExists(i)); } + + // Should be deleted + ASSERT_FALSE(histogram.HistogramExists({kNNodes})); + // Most recent node should exist + ASSERT_TRUE(histogram.HistogramExists({kNNodes + 1})); + + // Add same node again - should fail + EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes+1});); } std::vector GetHostHistGpair() { @@ -96,9 +105,9 @@ void TestBuildHist(bool use_shared_memory_histograms) { thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector()); maker.row_partitioner.reset(new RowPartitioner(0, kNRows)); - maker.hist.AllocateHistogram(0); + maker.hist.AllocateHistograms({0}); maker.gpair = gpair.DeviceSpan(); - maker.histogram_rounding = CreateRoundingFactor(maker.gpair);; + maker.histogram_rounding = CreateRoundingFactor(maker.gpair); BuildGradientHistogram( page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0), @@ -106,7 +115,7 @@ void TestBuildHist(bool use_shared_memory_histograms) { maker.hist.GetNodeHistogram(0), maker.histogram_rounding, !use_shared_memory_histograms); - DeviceHistogram& d_hist = maker.hist; + DeviceHistogramStorage& d_hist = maker.hist; auto node_histogram = d_hist.GetNodeHistogram(0); // d_hist.data stored in float, not gradient pair @@ -129,12 +138,10 @@ void TestBuildHist(bool use_shared_memory_histograms) { TEST(GpuHist, BuildHistGlobalMem) { TestBuildHist(false); - TestBuildHist(false); } TEST(GpuHist, BuildHistSharedMem) { TestBuildHist(true); - TestBuildHist(true); } HistogramCutsWrapper GetHostCutMatrix () { @@ -198,7 +205,7 @@ TEST(GpuHist, EvaluateRootSplit) { // Initialize GPUHistMakerDevice::hist maker.hist.Init(0, (max_bins - 1) * kNCols); - maker.hist.AllocateHistogram(0); + maker.hist.AllocateHistograms({0}); // Each row of hist_gpair represents gpairs for one feature. // Each entry represents a bin. std::vector hist_gpair = GetHostHistGpair(); diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index 257085b0c8f9..8748ddcbdf91 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -3,7 +3,7 @@ import gc import pytest import xgboost as xgb -from hypothesis import given, strategies, assume, settings, note, reproduce_failure +from hypothesis import given, strategies, assume, settings, note sys.path.append("tests/python") import testing as tm From 1dd1a6cc1c74a45dcb986546f0fab753d359c70b Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 10 May 2022 05:07:31 -0700 Subject: [PATCH 11/14] Limit concurrent nodes --- src/tree/driver.h | 10 ++++++---- src/tree/updater_gpu_hist.cu | 5 ++++- tests/cpp/tree/gpu_hist/test_driver.cu | 18 +++++++++++++----- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/tree/driver.h b/src/tree/driver.h index e61255e043c7..0aef93ccf9cd 100644 --- a/src/tree/driver.h +++ b/src/tree/driver.h @@ -33,10 +33,11 @@ class Driver { std::function>; public: - explicit Driver(TrainParam param) + explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256) : param_(param), - queue_(param.grow_policy == TrainParam::kDepthWise ? 
DepthWise : - LossGuide) {} + max_node_batch_size_(max_node_batch_size), + queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise + : LossGuide) {} template void Push(EntryIterT begin, EntryIterT end) { for (auto it = begin; it != end; ++it) { @@ -84,7 +85,7 @@ class Driver { std::vector result; ExpandEntryT e = queue_.top(); int level = e.depth; - while (e.depth == level && !queue_.empty()) { + while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) { queue_.pop(); if (e.IsValid(param_, num_leaves_)) { num_leaves_++; @@ -101,6 +102,7 @@ class Driver { private: TrainParam param_; std::size_t num_leaves_ = 1; + std::size_t max_node_batch_size_; ExpandQueue queue_; }; } // namespace tree diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 964a486baf16..eb10b42fc2fa 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -90,6 +90,7 @@ class DeviceHistogramStorage { sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT); static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2."); + public: // Start with about 16mb @@ -206,6 +207,7 @@ struct GPUHistMakerDevice { std::unique_ptr feature_groups; + GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page, common::Span _feature_types, bst_uint _n_rows, TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, @@ -653,7 +655,8 @@ struct GPUHistMakerDevice { RegTree* p_tree, dh::AllReducer* reducer, HostDeviceVector* p_out_position) { auto& tree = *p_tree; - Driver driver(param); + // Process maximum 32 nodes at a time + Driver driver(param, 32); monitor.Start("Reset"); this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); diff --git a/tests/cpp/tree/gpu_hist/test_driver.cu b/tests/cpp/tree/gpu_hist/test_driver.cu index d7f8cc63869e..8e7164e37bec 100644 --- a/tests/cpp/tree/gpu_hist/test_driver.cu +++ b/tests/cpp/tree/gpu_hist/test_driver.cu @@ -8,8 +8,8 @@ namespace tree { TEST(GpuHist, DriverDepthWise) { TrainParam p; p.InitAllowUnknown(Args{}); - p.grow_policy=TrainParam::kDepthWise; - Driver driver(p); + p.grow_policy = TrainParam::kDepthWise; + Driver driver(p, 2); EXPECT_TRUE(driver.Pop().empty()); DeviceSplitCandidate split; split.loss_chg = 1.0f; @@ -20,15 +20,23 @@ TEST(GpuHist, DriverDepthWise) { EXPECT_EQ(driver.Pop().front().nid, 0); driver.Push({GPUExpandEntry{1, 1, split, 2.0f, 1.0f, 1.0f}}); driver.Push({GPUExpandEntry{2, 1, split, 2.0f, 1.0f, 1.0f}}); - driver.Push({GPUExpandEntry{3, 2, split, 2.0f, 1.0f, 1.0f}}); - // Should return entries from level 1 + driver.Push({GPUExpandEntry{3, 1, split, 2.0f, 1.0f, 1.0f}}); + driver.Push({GPUExpandEntry{4, 2, split, 2.0f, 1.0f, 1.0f}}); + // Should return 2 entries from level 1 + // as we limited the driver to pop maximum 2 nodes auto res = driver.Pop(); EXPECT_EQ(res.size(), 2); for (auto &e : res) { EXPECT_EQ(e.depth, 1); } + + // Should now return 1 entry from level 1 + res = driver.Pop(); + EXPECT_EQ(res.size(), 1); + EXPECT_EQ(res.at(0).depth, 1); + res = driver.Pop(); - EXPECT_EQ(res[0].depth, 2); + EXPECT_EQ(res.at(0).depth, 2); EXPECT_TRUE(driver.Pop().empty()); } From 8751d14956d3a85ef0aaef40f223cfe485539973 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 11 May 2022 04:17:36 -0700 Subject: [PATCH 12/14] Lint --- src/tree/updater_gpu_hist.cu | 11 ++++------- tests/cpp/tree/test_gpu_hist.cu | 4 ++-- 2 files changed, 6 insertions(+), 9 deletions(-) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 
eb10b42fc2fa..88978142ee2e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -88,9 +88,7 @@ class DeviceHistogramStorage { int device_id_; static constexpr size_t kNumItemsInGradientSum = sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT); - static_assert(kNumItemsInGradientSum == 2, - "Number of items in gradient type should be 2."); - + static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2."); public: // Start with about 16mb @@ -107,11 +105,10 @@ class DeviceHistogramStorage { overflow_nidx_map_.clear(); } bool HistogramExists(int nidx) const { - return nidx_map_.find(nidx) != nidx_map_.cend() || overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); - } - int Bins() const { - return n_bins_; + return nidx_map_.find(nidx) != nidx_map_.cend() || + overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend(); } + int Bins() const { return n_bins_; } size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; } dh::device_vector& Data() { return data_; } diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index be51d3cc5e31..7d06d1731c5a 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -27,7 +27,7 @@ TEST(GpuHist, DeviceHistogram) { // Ensures that node allocates correctly after reaching `kStopGrowingSize`. dh::safe_cuda(cudaSetDevice(0)); constexpr size_t kNBins = 128; - constexpr size_t kNNodes = 4; + constexpr int kNNodes = 4; constexpr size_t kStopGrowing = kNNodes * kNBins * 2u; DeviceHistogramStorage histogram; histogram.Init(0, kNBins); @@ -47,7 +47,7 @@ TEST(GpuHist, DeviceHistogram) { // Add two new nodes histogram.AllocateHistograms({kNNodes}); - histogram.AllocateHistograms({kNNodes+1}); + histogram.AllocateHistograms({kNNodes + 1}); // Old cached nodes should still exist for (int i = 0; i < kNNodes; ++i) { From fd839b40688496c6ff9626a120c9a7b4db9effbb Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 May 2022 02:20:22 -0700 Subject: [PATCH 13/14] Lint --- tests/cpp/tree/test_gpu_hist.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 7d06d1731c5a..190d9aa687cd 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -55,12 +55,12 @@ TEST(GpuHist, DeviceHistogram) { } // Should be deleted - ASSERT_FALSE(histogram.HistogramExists({kNNodes})); + ASSERT_FALSE(histogram.HistogramExists(kNNodes)); // Most recent node should exist - ASSERT_TRUE(histogram.HistogramExists({kNNodes + 1})); + ASSERT_TRUE(histogram.HistogramExists(kNNodes + 1)); // Add same node again - should fail - EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes+1});); + EXPECT_ANY_THROW(histogram.AllocateHistograms(kNNodes + 1);); } std::vector GetHostHistGpair() { From 5ecb3d8ef00b6940c75cd7209b3ffa6c6216499a Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 17 May 2022 02:48:51 -0700 Subject: [PATCH 14/14] Lint --- tests/cpp/tree/test_gpu_hist.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 190d9aa687cd..e6069cdfdd4d 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -60,7 +60,7 @@ TEST(GpuHist, DeviceHistogram) { ASSERT_TRUE(histogram.HistogramExists(kNNodes + 1)); // Add same node again - should fail - EXPECT_ANY_THROW(histogram.AllocateHistograms(kNNodes + 1);); + 
EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes + 1});); } std::vector GetHostHistGpair() {
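Taken together, the Driver(param, 32) cap from patch 11 and the batched histogram code bound the per-iteration work: each Pop() hands back at most max_node_batch_size same-depth candidates, so a batch allocates at most twice that many contiguous child histograms and the directly built half is reduced in a single AllReduceSum call. A toy sketch of the capped Pop(), with a simplified Entry type and a plain depth ordering standing in for GPUExpandEntry and the real queue policy.

#include <cstddef>
#include <queue>
#include <vector>

struct Entry { int nid; int depth; float loss_chg; };

class ToyDriver {
  // Depth-wise policy: shallower nodes come out first, mirroring kDepthWise.
  struct ByDepth {
    bool operator()(const Entry& a, const Entry& b) const { return a.depth > b.depth; }
  };
  std::size_t max_batch_;
  std::priority_queue<Entry, std::vector<Entry>, ByDepth> queue_;

 public:
  explicit ToyDriver(std::size_t max_batch) : max_batch_(max_batch) {}
  void Push(const Entry& e) { queue_.push(e); }

  // Returns at most max_batch_ entries, all from the same depth; anything left on
  // that depth comes out of the next Pop(), as the updated driver test expects.
  std::vector<Entry> Pop() {
    std::vector<Entry> batch;
    if (queue_.empty()) return batch;
    int level = queue_.top().depth;
    while (!queue_.empty() && queue_.top().depth == level && batch.size() < max_batch_) {
      batch.push_back(queue_.top());
      queue_.pop();
    }
    return batch;
  }
};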