From 2b4cf67cd12672c0c8bca04549f0c65400ec6d0d Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 21 Apr 2022 03:10:55 -0700 Subject: [PATCH 1/4] Remove single_precision_histogram --- doc/gpu/index.rst | 4 +- doc/parameter.rst | 2 +- include/xgboost/tree_updater.h | 6 +- src/tree/tree_updater.cc | 3 +- src/tree/updater_approx.cc | 9 +- src/tree/updater_basemaker-inl.h | 7 +- src/tree/updater_colmaker.cc | 7 +- src/tree/updater_gpu_hist.cu | 162 +++++++--------------- src/tree/updater_histmaker.cc | 20 ++- src/tree/updater_prune.cc | 10 +- src/tree/updater_quantile_hist.cc | 4 +- src/tree/updater_quantile_hist.h | 3 +- src/tree/updater_refresh.cc | 13 +- src/tree/updater_sync.cc | 9 +- tests/cpp/tree/test_gpu_hist.cu | 16 +-- tests/python-gpu/test_gpu_basic_models.py | 4 +- tests/python-gpu/test_gpu_updaters.py | 3 +- 17 files changed, 107 insertions(+), 175 deletions(-) diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst index e36fc72a1746..049cf311dff2 100644 --- a/doc/gpu/index.rst +++ b/doc/gpu/index.rst @@ -59,13 +59,11 @@ Supported parameters +--------------------------------+--------------+ | ``interaction_constraints`` | |tick| | +--------------------------------+--------------+ -| ``single_precision_histogram`` | |tick| | +| ``single_precision_histogram`` | |cross| | +--------------------------------+--------------+ GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``. -The experimental parameter ``single_precision_histogram`` can be set to True to enable building histograms using single precision. This may improve speed, in particular on older architectures. - The device ordinal (which GPU to use if you have many of them) can be selected using the ``gpu_id`` parameter, which defaults to 0 (the first device reported by CUDA runtime). diff --git a/doc/parameter.rst b/doc/parameter.rst index 781150490082..4392b5bf7680 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -240,7 +240,7 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method * ``single_precision_histogram``, [default= ``false``] - - Use single precision to build histograms instead of double precision. + - Use single precision to build histograms instead of double precision. Currently disabled for ``gpu_hist``. * ``max_cat_to_onehot`` diff --git a/include/xgboost/tree_updater.h b/include/xgboost/tree_updater.h index 6189221dc0bf..6248a65e270d 100644 --- a/include/xgboost/tree_updater.h +++ b/include/xgboost/tree_updater.h @@ -35,6 +35,7 @@ class TreeUpdater : public Configurable { GenericParameter const* ctx_ = nullptr; public: + explicit TreeUpdater(const GenericParameter* ctx) : ctx_(ctx) {} /*! \brief virtual destructor */ ~TreeUpdater() override = default; /*! @@ -91,8 +92,9 @@ class TreeUpdater : public Configurable { * \brief Registry entry for tree updater. */ struct TreeUpdaterReg - : public dmlc::FunctionRegEntryBase > {}; + : public dmlc::FunctionRegEntryBase< + TreeUpdaterReg, + std::function > {}; /*! * \brief Macro to register tree updater. diff --git a/src/tree/tree_updater.cc b/src/tree/tree_updater.cc index 05f6c4bb5fd6..ee5659636305 100644 --- a/src/tree/tree_updater.cc +++ b/src/tree/tree_updater.cc @@ -20,8 +20,7 @@ TreeUpdater* TreeUpdater::Create(const std::string& name, GenericParameter const if (e == nullptr) { LOG(FATAL) << "Unknown tree updater " << name; } - auto p_updater = (e->body)(task); - p_updater->ctx_ = tparam; + auto p_updater = (e->body)(tparam, task); return p_updater; } diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 3bad6f7da4cc..a06f195374b6 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -252,7 +252,10 @@ class GlobalApproxUpdater : public TreeUpdater { ObjInfo task_; public: - explicit GlobalApproxUpdater(ObjInfo task) : task_{task} { monitor_.Init(__func__); } + explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task) + : task_{task}, TreeUpdater(ctx) { + monitor_.Init(__func__); + } void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); @@ -343,6 +346,8 @@ XGBOOST_REGISTER_TREE_UPDATER(GlobalHistMaker, "grow_histmaker") .describe( "Tree constructor that uses approximate histogram construction " "for each node.") - .set_body([](ObjInfo task) { return new GlobalApproxUpdater(task); }); + .set_body([](GenericParameter const *ctx, ObjInfo task) { + return new GlobalApproxUpdater(ctx, task); + }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_basemaker-inl.h b/src/tree/updater_basemaker-inl.h index da239b2090c7..7fc44a6d15fb 100644 --- a/src/tree/updater_basemaker-inl.h +++ b/src/tree/updater_basemaker-inl.h @@ -33,11 +33,10 @@ namespace tree { * \brief base tree maker class that defines common operation * needed in tree making */ -class BaseMaker: public TreeUpdater { +class BaseMaker : public TreeUpdater { public: - void Configure(const Args& args) override { - param_.UpdateAllowUnknown(args); - } + explicit BaseMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); } void LoadConfig(Json const& in) override { auto const& config = get(in); diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index e3d716f2cba8..f4279a0a1c3b 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -57,7 +57,8 @@ DMLC_REGISTER_PARAMETER(ColMakerTrainParam); /*! \brief column-wise update to construct a tree */ class ColMaker: public TreeUpdater { public: - void Configure(const Args& args) override { + explicit ColMaker(GenericParameter const *ctx) : TreeUpdater(ctx) {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); colmaker_param_.UpdateAllowUnknown(args); } @@ -614,8 +615,8 @@ class ColMaker: public TreeUpdater { XGBOOST_REGISTER_TREE_UPDATER(ColMaker, "grow_colmaker") .describe("Grow tree with parallelization over columns.") -.set_body([](ObjInfo) { - return new ColMaker(); +.set_body([](GenericParameter const* ctx, ObjInfo) { + return new ColMaker(ctx); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index cb7dd9b7e8e4..2cac6b6c4f4a 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -45,12 +45,9 @@ DMLC_REGISTRY_FILE_TAG(updater_gpu_hist); // training parameters specific to this algorithm struct GPUHistMakerTrainParam : public XGBoostParameter { - bool single_precision_histogram; bool debug_synchronize; // declare parameters DMLC_DECLARE_PARAMETER(GPUHistMakerTrainParam) { - DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( - "Use single precision to build histograms."); DMLC_DECLARE_FIELD(debug_synchronize).set_default(false).describe( "Check if all distributed tree are identical after tree construction."); } @@ -532,6 +529,13 @@ struct GPUHistMakerDevice { void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) { RegTree& tree = *p_tree; + + // Sanity check - have we created a leaf with no training instances? + if (!rabit::IsDistributed()) { + CHECK(row_partitioner->GetRows(candidate.nid).size() > 0) + << "No training instances in this leaf!"; + } + auto parent_sum = candidate.split.left_sum + candidate.split.right_sum; auto base_weight = candidate.base_weight; auto left_weight = candidate.left_weight * param.learning_rate; @@ -676,20 +680,35 @@ struct GPUHistMakerDevice { } }; -template -class GPUHistMakerSpecialised { +class GPUHistMaker : public TreeUpdater { + using GradientSumT = GradientPairPrecise; + public: - explicit GPUHistMakerSpecialised(ObjInfo task) : task_{task} {}; - void Configure(const Args& args, GenericParameter const* generic_param) { + explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task) + : TreeUpdater(ctx), task_{task} {}; + void Configure(const Args& args) { + // Used in test to count how many configurations are performed + LOG(DEBUG) << "[GPU Hist]: Configure"; param_.UpdateAllowUnknown(args); - generic_param_ = generic_param; hist_maker_param_.UpdateAllowUnknown(args); dh::CheckComputeCapability(); monitor_.Init("updater_gpu_hist"); } - ~GPUHistMakerSpecialised() { // NOLINT + void LoadConfig(Json const& in) override { + auto const& config = get(in); + FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_); + initialised_ = false; + FromJson(config.at("train_param"), ¶m_); + } + void SaveConfig(Json* p_out) const override { + auto& out = *p_out; + out["gpu_hist_train_param"] = ToJson(hist_maker_param_); + out["train_param"] = ToJson(param_); + } + + ~GPUHistMaker() { // NOLINT dh::GlobalMemoryLogger().Log(); } @@ -719,30 +738,24 @@ class GPUHistMakerSpecialised { } void InitDataOnce(DMatrix* dmat) { - device_ = generic_param_->gpu_id; - CHECK_GE(device_, 0) << "Must have at least one device"; + CHECK_GE(ctx_->gpu_id, 0) << "Must have at least one device"; info_ = &dmat->Info(); - reducer_.Init({device_}); // NOLINT + reducer_.Init({ctx_->gpu_id}); // NOLINT // Synchronise the column sampling seed uint32_t column_sampling_seed = common::GlobalRandom()(); rabit::Broadcast(&column_sampling_seed, sizeof(column_sampling_seed), 0); BatchParam batch_param{ - device_, - param_.max_bin, + ctx_->gpu_id, + param_.max_bin, }; auto page = (*dmat->GetBatches(batch_param).begin()).Impl(); - dh::safe_cuda(cudaSetDevice(device_)); - info_->feature_types.SetDevice(device_); - maker.reset(new GPUHistMakerDevice(device_, - page, - info_->feature_types.ConstDeviceSpan(), - info_->num_row_, - param_, - column_sampling_seed, - info_->num_col_, - batch_param)); + dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); + info_->feature_types.SetDevice(ctx_->gpu_id); + maker.reset(new GPUHistMakerDevice( + ctx_->gpu_id, page, info_->feature_types.ConstDeviceSpan(), info_->num_row_, param_, + column_sampling_seed, info_->num_col_, batch_param)); p_last_fmat_ = dmat; initialised_ = true; @@ -766,7 +779,7 @@ class GPUHistMakerSpecialised { } fs.Seek(0); rabit::Broadcast(&s_model, 0); - RegTree reference_tree {}; // rank 0 tree + RegTree reference_tree{}; // rank 0 tree reference_tree.Load(&fs); CHECK(*local_tree == reference_tree); } @@ -775,13 +788,11 @@ class GPUHistMakerSpecialised { monitor_.Start("InitData"); this->InitData(p_fmat); monitor_.Stop("InitData"); - - gpair->SetDevice(device_); + gpair->SetDevice(ctx_->gpu_id); maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_); } - bool UpdatePredictionCache(const DMatrix *data, - linalg::VectorView p_out_preds) { + bool UpdatePredictionCache(const DMatrix* data, linalg::VectorView p_out_preds) { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { return false; } @@ -791,107 +802,32 @@ class GPUHistMakerSpecialised { return true; } - TrainParam param_; // NOLINT - MetaInfo* info_{}; // NOLINT + TrainParam param_; // NOLINT + MetaInfo* info_{}; // NOLINT std::unique_ptr> maker; // NOLINT + char const* Name() const override { return "grow_gpu_hist"; } + private: - bool initialised_ { false }; + bool initialised_{false}; GPUHistMakerTrainParam hist_maker_param_; - GenericParameter const* generic_param_; dh::AllReducer reducer_; - DMatrix* p_last_fmat_ { nullptr }; - int device_{-1}; + DMatrix* p_last_fmat_{nullptr}; ObjInfo task_; common::Monitor monitor_; }; -class GPUHistMaker : public TreeUpdater { - public: - explicit GPUHistMaker(ObjInfo task) : task_{task} {} - void Configure(const Args& args) override { - // Used in test to count how many configurations are performed - LOG(DEBUG) << "[GPU Hist]: Configure"; - hist_maker_param_.UpdateAllowUnknown(args); - // The passed in args can be empty, if we simply purge the old maker without - // preserving parameters then we can't do Update on it. - TrainParam param; - if (float_maker_) { - param = float_maker_->param_; - } else if (double_maker_) { - param = double_maker_->param_; - } - if (hist_maker_param_.single_precision_histogram) { - float_maker_.reset(new GPUHistMakerSpecialised(task_)); - float_maker_->param_ = param; - float_maker_->Configure(args, ctx_); - } else { - double_maker_.reset(new GPUHistMakerSpecialised(task_)); - double_maker_->param_ = param; - double_maker_->Configure(args, ctx_); - } - } - - void LoadConfig(Json const& in) override { - auto const& config = get(in); - FromJson(config.at("gpu_hist_train_param"), &this->hist_maker_param_); - if (hist_maker_param_.single_precision_histogram) { - float_maker_.reset(new GPUHistMakerSpecialised(task_)); - FromJson(config.at("train_param"), &float_maker_->param_); - } else { - double_maker_.reset(new GPUHistMakerSpecialised(task_)); - FromJson(config.at("train_param"), &double_maker_->param_); - } - } - void SaveConfig(Json* p_out) const override { - auto& out = *p_out; - out["gpu_hist_train_param"] = ToJson(hist_maker_param_); - if (hist_maker_param_.single_precision_histogram) { - out["train_param"] = ToJson(float_maker_->param_); - } else { - out["train_param"] = ToJson(double_maker_->param_); - } - } - - void Update(HostDeviceVector* gpair, DMatrix* dmat, - const std::vector& trees) override { - if (hist_maker_param_.single_precision_histogram) { - float_maker_->Update(gpair, dmat, trees); - } else { - double_maker_->Update(gpair, dmat, trees); - } - } - - bool - UpdatePredictionCache(const DMatrix *data, - linalg::VectorView p_out_preds) override { - if (hist_maker_param_.single_precision_histogram) { - return float_maker_->UpdatePredictionCache(data, p_out_preds); - } else { - return double_maker_->UpdatePredictionCache(data, p_out_preds); - } - } - - char const* Name() const override { - return "grow_gpu_hist"; - } - - private: - GPUHistMakerTrainParam hist_maker_param_; - ObjInfo task_; - std::unique_ptr> float_maker_; - std::unique_ptr> double_maker_; -}; - #if !defined(GTEST_TEST) XGBOOST_REGISTER_TREE_UPDATER(GPUHistMaker, "grow_gpu_hist") .describe("Grow tree with GPU.") - .set_body([](ObjInfo task) { return new GPUHistMaker(task); }); + .set_body([](GenericParameter const* tparam, ObjInfo task) { + return new GPUHistMaker(tparam, task); + }); #endif // !defined(GTEST_TEST) } // namespace tree diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc index 0a85d2d73832..9d36e4d16c0a 100644 --- a/src/tree/updater_histmaker.cc +++ b/src/tree/updater_histmaker.cc @@ -24,9 +24,9 @@ DMLC_REGISTRY_FILE_TAG(updater_histmaker); class HistMaker: public BaseMaker { public: - void Update(HostDeviceVector *gpair, - DMatrix *p_fmat, - const std::vector &trees) override { + explicit HistMaker(GenericParameter const *ctx) : BaseMaker(ctx) {} + void Update(HostDeviceVector *gpair, DMatrix *p_fmat, + const std::vector &trees) override { interaction_constraints_.Configure(param_, p_fmat->Info().num_col_); // rescale learning rate according to size of trees float lr = param_.learning_rate; @@ -262,12 +262,10 @@ class HistMaker: public BaseMaker { } }; -class CQHistMaker: public HistMaker { +class CQHistMaker : public HistMaker { public: - CQHistMaker() = default; - char const* Name() const override { - return "grow_local_histmaker"; - } + explicit CQHistMaker(GenericParameter const *ctx) : HistMaker(ctx) {} + char const *Name() const override { return "grow_local_histmaker"; } protected: struct HistEntry { @@ -624,9 +622,7 @@ class CQHistMaker: public HistMaker { }; XGBOOST_REGISTER_TREE_UPDATER(LocalHistMaker, "grow_local_histmaker") -.describe("Tree constructor that uses approximate histogram construction.") -.set_body([](ObjInfo) { - return new CQHistMaker(); - }); + .describe("Tree constructor that uses approximate histogram construction.") + .set_body([](GenericParameter const *ctx, ObjInfo) { return new CQHistMaker(ctx); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_prune.cc b/src/tree/updater_prune.cc index f71f1c698cb9..9e6fad883040 100644 --- a/src/tree/updater_prune.cc +++ b/src/tree/updater_prune.cc @@ -21,9 +21,9 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_prune); /*! \brief pruner that prunes a tree after growing finishes */ -class TreePruner: public TreeUpdater { +class TreePruner : public TreeUpdater { public: - explicit TreePruner(ObjInfo task) { + explicit TreePruner(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx) { syncher_.reset(TreeUpdater::Create("sync", ctx_, task)); pruner_monitor_.Init("TreePruner"); } @@ -112,9 +112,7 @@ class TreePruner: public TreeUpdater { }; XGBOOST_REGISTER_TREE_UPDATER(TreePruner, "prune") -.describe("Pruner that prune the tree according to statistics.") -.set_body([](ObjInfo task) { - return new TreePruner(task); - }); + .describe("Pruner that prune the tree according to statistics.") + .set_body([](GenericParameter const* ctx, ObjInfo task) { return new TreePruner(ctx, task); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 0e1b6db47691..dcbb3dbfba3e 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -390,6 +390,8 @@ template struct QuantileHistMaker::Builder; XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") - .set_body([](ObjInfo task) { return new QuantileHistMaker(task); }); + .set_body([](GenericParameter const *ctx, ObjInfo task) { + return new QuantileHistMaker(ctx, task); + }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 3c03a371ebfb..463c7a54ab39 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -225,7 +225,8 @@ inline BatchParam HistBatch(TrainParam const& param) { /*! \brief construct a tree using quantized feature values */ class QuantileHistMaker: public TreeUpdater { public: - explicit QuantileHistMaker(ObjInfo task) : task_{task} {} + explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task) + : task_{task}, TreeUpdater(ctx) {} void Configure(const Args& args) override; void Update(HostDeviceVector* gpair, diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index d17c1e1444f7..6110e964f891 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -22,11 +22,10 @@ namespace tree { DMLC_REGISTRY_FILE_TAG(updater_refresh); /*! \brief pruner that prunes a tree after growing finishs */ -class TreeRefresher: public TreeUpdater { +class TreeRefresher : public TreeUpdater { public: - void Configure(const Args& args) override { - param_.UpdateAllowUnknown(args); - } + explicit TreeRefresher(GenericParameter const *ctx) : TreeUpdater(ctx) {} + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); } void LoadConfig(Json const& in) override { auto const& config = get(in); FromJson(config.at("train_param"), &this->param_); @@ -160,9 +159,7 @@ class TreeRefresher: public TreeUpdater { }; XGBOOST_REGISTER_TREE_UPDATER(TreeRefresher, "refresh") -.describe("Refresher that refreshes the weight and statistics according to data.") -.set_body([](ObjInfo) { - return new TreeRefresher(); - }); + .describe("Refresher that refreshes the weight and statistics according to data.") + .set_body([](GenericParameter const *ctx, ObjInfo) { return new TreeRefresher(ctx); }); } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_sync.cc b/src/tree/updater_sync.cc index 4f7c7a1a85a6..5a22675965dc 100644 --- a/src/tree/updater_sync.cc +++ b/src/tree/updater_sync.cc @@ -20,8 +20,9 @@ DMLC_REGISTRY_FILE_TAG(updater_sync); * \brief syncher that synchronize the tree in all distributed nodes * can implement various strategies, so far it is always set to node 0's tree */ -class TreeSyncher: public TreeUpdater { +class TreeSyncher : public TreeUpdater { public: + explicit TreeSyncher(GenericParameter const* tparam) : TreeUpdater(tparam) {} void Configure(const Args&) override {} void LoadConfig(Json const&) override {} @@ -52,9 +53,7 @@ class TreeSyncher: public TreeUpdater { }; XGBOOST_REGISTER_TREE_UPDATER(TreeSyncher, "sync") -.describe("Syncher that synchronize the tree in all distributed nodes.") -.set_body([](ObjInfo) { - return new TreeSyncher(); - }); + .describe("Syncher that synchronize the tree in all distributed nodes.") + .set_body([](GenericParameter const* tparam, ObjInfo) { return new TreeSyncher(tparam); }); } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 82f40465deb2..883537863307 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -275,8 +275,10 @@ void TestHistogramIndexImpl() { int constexpr kNRows = 1000, kNCols = 10; // Build 2 matrices and build a histogram maker with that - tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}}, - hist_maker_ext{ObjInfo{ObjInfo::kRegression}}; + + GenericParameter generic_param(CreateEmptyGenericParam(0)); + tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}}, + hist_maker_ext{&generic_param,ObjInfo{ObjInfo::kRegression}}; std::unique_ptr hist_maker_dmat( CreateSparsePageDMatrixWithRC(kNRows, kNCols, 0, true)); @@ -289,10 +291,9 @@ void TestHistogramIndexImpl() { {"max_leaves", "0"} }; - GenericParameter generic_param(CreateEmptyGenericParam(0)); - hist_maker.Configure(training_params, &generic_param); + hist_maker.Configure(training_params); hist_maker.InitDataOnce(hist_maker_dmat.get()); - hist_maker_ext.Configure(training_params, &generic_param); + hist_maker_ext.Configure(training_params); hist_maker_ext.InitDataOnce(hist_maker_ext_dmat.get()); // Extract the device maker from the histogram makers and from that its compressed @@ -344,10 +345,9 @@ void UpdateTree(HostDeviceVector* gpair, DMatrix* dmat, {"sampling_method", sampling_method}, }; - tree::GPUHistMakerSpecialised hist_maker{ObjInfo{ObjInfo::kRegression}}; GenericParameter generic_param(CreateEmptyGenericParam(0)); - hist_maker.Configure(args, &generic_param); - + tree::GPUHistMaker hist_maker{&generic_param,ObjInfo{ObjInfo::kRegression}}; + hist_maker.Configure(args); hist_maker.Update(gpair, dmat, {tree}); auto cache = linalg::VectorView{preds->DeviceSpan(), {preds->Size()}, 0}; hist_maker.UpdatePredictionCache(dmat, cache); diff --git a/tests/python-gpu/test_gpu_basic_models.py b/tests/python-gpu/test_gpu_basic_models.py index 06e63bdd56d9..9e955eac2931 100644 --- a/tests/python-gpu/test_gpu_basic_models.py +++ b/tests/python-gpu/test_gpu_basic_models.py @@ -16,11 +16,11 @@ class TestGPUBasicModels: cpu_test_bm = test_bm.TestModels() def run_cls(self, X, y): - cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True) + cls = xgb.XGBClassifier(tree_method='gpu_hist') cls.fit(X, y) cls.get_booster().save_model('test_deterministic_gpu_hist-0.json') - cls = xgb.XGBClassifier(tree_method='gpu_hist', single_precision_histogram=True) + cls = xgb.XGBClassifier(tree_method='gpu_hist') cls.fit(X, y) cls.get_booster().save_model('test_deterministic_gpu_hist-1.json') diff --git a/tests/python-gpu/test_gpu_updaters.py b/tests/python-gpu/test_gpu_updaters.py index a3427b566360..8f3cbcaac61f 100644 --- a/tests/python-gpu/test_gpu_updaters.py +++ b/tests/python-gpu/test_gpu_updaters.py @@ -3,7 +3,7 @@ import gc import pytest import xgboost as xgb -from hypothesis import given, strategies, assume, settings, note +from hypothesis import given, strategies, assume, settings, note, reproduce_failure sys.path.append("tests/python") import testing as tm @@ -15,7 +15,6 @@ 'max_leaves': strategies.integers(0, 256), 'max_bin': strategies.integers(2, 1024), 'grow_policy': strategies.sampled_from(['lossguide', 'depthwise']), - 'single_precision_histogram': strategies.booleans(), 'min_child_weight': strategies.floats(0.5, 2.0), 'seed': strategies.integers(0, 10), # We cannot enable subsampling as the training loss can increase From 047ce1ed1ebe1ee1835e2d246accbc7c1ff9d77e Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 25 Apr 2022 06:27:43 -0700 Subject: [PATCH 2/4] Lint --- src/tree/updater_gpu_hist.cu | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 2cac6b6c4f4a..08b6b079f1b4 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -531,7 +531,7 @@ struct GPUHistMakerDevice { RegTree& tree = *p_tree; // Sanity check - have we created a leaf with no training instances? - if (!rabit::IsDistributed()) { + if (!rabit::IsDistributed() && row_partitioner) { CHECK(row_partitioner->GetRows(candidate.nid).size() > 0) << "No training instances in this leaf!"; } @@ -686,7 +686,7 @@ class GPUHistMaker : public TreeUpdater { public: explicit GPUHistMaker(GenericParameter const* ctx, ObjInfo task) : TreeUpdater(ctx), task_{task} {}; - void Configure(const Args& args) { + void Configure(const Args& args) override { // Used in test to count how many configurations are performed LOG(DEBUG) << "[GPU Hist]: Configure"; param_.UpdateAllowUnknown(args); @@ -713,7 +713,7 @@ class GPUHistMaker : public TreeUpdater { } void Update(HostDeviceVector* gpair, DMatrix* dmat, - const std::vector& trees) { + const std::vector& trees) override { monitor_.Start("Update"); // rescale learning rate according to size of trees @@ -792,7 +792,8 @@ class GPUHistMaker : public TreeUpdater { maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_); } - bool UpdatePredictionCache(const DMatrix* data, linalg::VectorView p_out_preds) { + bool UpdatePredictionCache(const DMatrix* data, + linalg::VectorView p_out_preds) override { if (maker == nullptr || p_last_fmat_ == nullptr || p_last_fmat_ != data) { return false; } From 0486996a03a9f4612a8a5a8d35e546b2a01ca663 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 26 Apr 2022 03:16:07 -0700 Subject: [PATCH 3/4] Reset on configure. --- src/tree/updater_gpu_hist.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 08b6b079f1b4..082e447571e5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -692,6 +692,7 @@ class GPUHistMaker : public TreeUpdater { param_.UpdateAllowUnknown(args); hist_maker_param_.UpdateAllowUnknown(args); dh::CheckComputeCapability(); + initialised_ = false; monitor_.Init("updater_gpu_hist"); } From b3a1ae33c9b6b1ff6a5870413de644119d1dbd21 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 2 May 2022 06:27:30 -0700 Subject: [PATCH 4/4] Lint --- src/tree/updater_gpu_hist.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index c21a1476eff6..dbc9ea100ca5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -741,7 +741,7 @@ class GPUHistMaker : public TreeUpdater { void Update(HostDeviceVector* gpair, DMatrix* dmat, common::Span> out_position, - const std::vector& trees) { + const std::vector& trees) override { monitor_.Start("Update"); // rescale learning rate according to size of trees