From 70ca1218ab18f89450b8a8a10a13d37f3b46b218 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Thu, 12 May 2022 16:53:52 +0800 Subject: [PATCH 1/6] Drop single precision hist. --- doc/gpu/index.rst | 28 ----- doc/parameter.rst | 4 - doc/tutorials/saving_model.rst | 2 - src/common/hist_util.cc | 131 ++++++-------------- src/common/hist_util.h | 59 +++------ src/tree/hist/evaluate_splits.h | 15 ++- src/tree/hist/histogram.h | 24 ++-- src/tree/hist/param.cc | 10 -- src/tree/hist/param.h | 23 ---- src/tree/updater_approx.cc | 43 ++----- src/tree/updater_approx.h | 1 - src/tree/updater_quantile_hist.cc | 84 +++++-------- src/tree/updater_quantile_hist.h | 16 +-- tests/cpp/common/test_hist_util.cc | 37 ++---- tests/cpp/tree/hist/test_evaluate_splits.cc | 24 ++-- tests/cpp/tree/hist/test_histogram.cc | 63 ++++------ tests/cpp/tree/test_evaluate_splits.h | 4 +- 17 files changed, 166 insertions(+), 402 deletions(-) delete mode 100644 src/tree/hist/param.cc delete mode 100644 src/tree/hist/param.h diff --git a/doc/gpu/index.rst b/doc/gpu/index.rst index 049cf311dff2..82309523f4cf 100644 --- a/doc/gpu/index.rst +++ b/doc/gpu/index.rst @@ -34,34 +34,6 @@ Supported parameters .. |tick| unicode:: U+2714 .. |cross| unicode:: U+2718 -+--------------------------------+--------------+ -| parameter | ``gpu_hist`` | -+================================+==============+ -| ``subsample`` | |tick| | -+--------------------------------+--------------+ -| ``sampling_method`` | |tick| | -+--------------------------------+--------------+ -| ``colsample_bytree`` | |tick| | -+--------------------------------+--------------+ -| ``colsample_bylevel`` | |tick| | -+--------------------------------+--------------+ -| ``max_bin`` | |tick| | -+--------------------------------+--------------+ -| ``gamma`` | |tick| | -+--------------------------------+--------------+ -| ``gpu_id`` | |tick| | -+--------------------------------+--------------+ -| ``predictor`` | |tick| | -+--------------------------------+--------------+ -| ``grow_policy`` | |tick| | -+--------------------------------+--------------+ -| ``monotone_constraints`` | |tick| | -+--------------------------------+--------------+ -| ``interaction_constraints`` | |tick| | -+--------------------------------+--------------+ -| ``single_precision_histogram`` | |cross| | -+--------------------------------+--------------+ - GPU accelerated prediction is enabled by default for the above mentioned ``tree_method`` parameters but can be switched to CPU prediction by setting ``predictor`` to ``cpu_predictor``. This could be useful if you want to conserve GPU memory. Likewise when using CPU algorithms, GPU accelerated prediction can be enabled by setting ``predictor`` to ``gpu_predictor``. The device ordinal (which GPU to use if you have many of them) can be selected using the diff --git a/doc/parameter.rst b/doc/parameter.rst index deb2f635c1b2..eca5d46eebe9 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -238,10 +238,6 @@ Parameters for Tree Booster Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method =========================================================================== -* ``single_precision_histogram``, [default= ``false``] - - - Use single precision to build histograms instead of double precision. Currently disabled for ``gpu_hist``. - * ``max_cat_to_onehot`` .. versionadded:: 1.6 diff --git a/doc/tutorials/saving_model.rst b/doc/tutorials/saving_model.rst index ab60cfc1a1d0..723cde431bc4 100644 --- a/doc/tutorials/saving_model.rst +++ b/doc/tutorials/saving_model.rst @@ -171,8 +171,6 @@ Will print out something similar to (not actual output as it's too long for demo "grow_gpu_hist": { "gpu_hist_train_param": { "debug_synchronize": "0", - "gpu_batch_nrows": "0", - "single_precision_histogram": "0" }, "train_param": { "alpha": "0", diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index c14da59a7f60..64ede3102811 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -36,78 +36,51 @@ HistogramCuts::HistogramCuts() { /*! * \brief fill a histogram by zeros in range [begin, end) */ -template -void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { +void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end) { #if defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 - std::fill(hist.begin() + begin, hist.begin() + end, - xgboost::detail::GradientPairInternal()); + std::fill(hist.begin() + begin, hist.begin() + end, xgboost::GradientPairPrecise()); #else // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 - memset(hist.data() + begin, '\0', (end-begin)* - sizeof(xgboost::detail::GradientPairInternal)); + memset(hist.data() + begin, '\0', (end - begin) * sizeof(xgboost::GradientPairPrecise)); #endif // defined(XGBOOST_STRICT_R_MODE) && XGBOOST_STRICT_R_MODE == 1 } -template void InitilizeHistByZeroes(GHistRow hist, size_t begin, - size_t end); -template void InitilizeHistByZeroes(GHistRow hist, size_t begin, - size_t end); /*! * \brief Increment hist as dst += add in range [begin, end) */ -template -void IncrementHist(GHistRow dst, const GHistRow add, - size_t begin, size_t end) { - GradientSumT* pdst = reinterpret_cast(dst.data()); - const GradientSumT* padd = reinterpret_cast(add.data()); +void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end) { + double* pdst = reinterpret_cast(dst.data()); + const double *padd = reinterpret_cast(add.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { pdst[i] += padd[i]; } } -template void IncrementHist(GHistRow dst, const GHistRow add, - size_t begin, size_t end); -template void IncrementHist(GHistRow dst, const GHistRow add, - size_t begin, size_t end); /*! * \brief Copy hist from src to dst in range [begin, end) */ -template -void CopyHist(GHistRow dst, const GHistRow src, - size_t begin, size_t end) { - GradientSumT* pdst = reinterpret_cast(dst.data()); - const GradientSumT* psrc = reinterpret_cast(src.data()); +void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end) { + double *pdst = reinterpret_cast(dst.data()); + const double *psrc = reinterpret_cast(src.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { pdst[i] = psrc[i]; } } -template void CopyHist(GHistRow dst, const GHistRow src, - size_t begin, size_t end); -template void CopyHist(GHistRow dst, const GHistRow src, - size_t begin, size_t end); /*! * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end) */ -template -void SubtractionHist(GHistRow dst, const GHistRow src1, - const GHistRow src2, - size_t begin, size_t end) { - GradientSumT* pdst = reinterpret_cast(dst.data()); - const GradientSumT* psrc1 = reinterpret_cast(src1.data()); - const GradientSumT* psrc2 = reinterpret_cast(src2.data()); +void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, size_t begin, + size_t end) { + double* pdst = reinterpret_cast(dst.data()); + const double* psrc1 = reinterpret_cast(src1.data()); + const double* psrc2 = reinterpret_cast(src2.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { pdst[i] = psrc1[i] - psrc2[i]; } } -template void SubtractionHist(GHistRow dst, const GHistRow src1, - const GHistRow src2, - size_t begin, size_t end); -template void SubtractionHist(GHistRow dst, const GHistRow src1, - const GHistRow src2, - size_t begin, size_t end); struct Prefetch { public: @@ -132,11 +105,10 @@ struct Prefetch { constexpr size_t Prefetch::kNoPrefetchSize; -template +template void BuildHistKernel(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, GHistRow hist) { + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRow hist) { const size_t size = row_indices.Size(); const size_t *rid = row_indices.begin; auto const *pgh = reinterpret_cast(gpair.data()); @@ -154,7 +126,7 @@ void BuildHistKernel(const std::vector &gpair, const size_t n_features = get_row_ptr(row_indices.begin[0] + 1) - get_row_ptr(row_indices.begin[0]); - auto hist_data = reinterpret_cast(hist.data()); + auto hist_data = reinterpret_cast(hist.data()); const uint32_t two{2}; // Each element from 'gpair' and 'hist' contains // 2 FP values: gradient and hessian. // So we need to multiply each row-index/bin-index by 2 @@ -195,24 +167,21 @@ void BuildHistKernel(const std::vector &gpair, } } -template +template void BuildHistDispatch(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, GHistRow hist) { + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRow hist) { auto first_page = gmat.base_rowid == 0; if (first_page) { switch (gmat.index.GetBinTypeSize()) { case kUint8BinsTypeSize: - BuildHistKernel( - gpair, row_indices, gmat, hist); + BuildHistKernel(gpair, row_indices, gmat, hist); break; case kUint16BinsTypeSize: - BuildHistKernel( - gpair, row_indices, gmat, hist); + BuildHistKernel(gpair, row_indices, gmat, hist); break; case kUint32BinsTypeSize: - BuildHistKernel( - gpair, row_indices, gmat, hist); + BuildHistKernel(gpair, row_indices, gmat, hist); break; default: CHECK(false); // no default behavior @@ -220,16 +189,13 @@ void BuildHistDispatch(const std::vector &gpair, } else { switch (gmat.index.GetBinTypeSize()) { case kUint8BinsTypeSize: - BuildHistKernel( - gpair, row_indices, gmat, hist); + BuildHistKernel(gpair, row_indices, gmat, hist); break; case kUint16BinsTypeSize: - BuildHistKernel( - gpair, row_indices, gmat, hist); + BuildHistKernel(gpair, row_indices, gmat, hist); break; case kUint32BinsTypeSize: - BuildHistKernel( - gpair, row_indices, gmat, hist); + BuildHistKernel(gpair, row_indices, gmat, hist); break; default: CHECK(false); // no default behavior @@ -237,12 +203,10 @@ void BuildHistDispatch(const std::vector &gpair, } } -template template -void GHistBuilder::BuildHist( - const std::vector &gpair, - const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRowT hist) const { +void GHistBuilder::BuildHist(const std::vector &gpair, + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRow hist) const { const size_t nrows = row_indices.Size(); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); @@ -252,7 +216,7 @@ void GHistBuilder::BuildHist( if (contiguousBlock) { // contiguous memory access, built-in HW prefetching is enough - BuildHistDispatch(gpair, row_indices, + BuildHistDispatch(gpair, row_indices, gmat, hist); } else { const RowSetCollection::Elem span1(row_indices.begin, @@ -260,33 +224,18 @@ void GHistBuilder::BuildHist( const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, row_indices.end); - BuildHistDispatch(gpair, span1, gmat, - hist); + BuildHistDispatch(gpair, span1, gmat, hist); // no prefetching to avoid loading extra memory - BuildHistDispatch(gpair, span2, gmat, - hist); + BuildHistDispatch(gpair, span2, gmat, hist); } } -template void -GHistBuilder::BuildHist(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRow hist) const; -template void -GHistBuilder::BuildHist(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRow hist) const; -template void -GHistBuilder::BuildHist(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRow hist) const; -template void -GHistBuilder::BuildHist(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, - GHistRow hist) const; +template void GHistBuilder::BuildHist(const std::vector &gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix &gmat, GHistRow hist) const; + +template void GHistBuilder::BuildHist(const std::vector &gpair, + const RowSetCollection::Elem row_indices, + const GHistIndexMatrix &gmat, GHistRow hist) const; } // namespace common } // namespace xgboost diff --git a/src/common/hist_util.h b/src/common/hist_util.h index fad082c2c596..c203a0eb4357 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -322,56 +322,44 @@ bst_bin_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end, return -1; } -template -using GHistRow = Span >; +using GHistRow = Span; /*! * \brief fill a histogram by zeros */ -template -void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end); +void InitilizeHistByZeroes(GHistRow hist, size_t begin, size_t end); /*! * \brief Increment hist as dst += add in range [begin, end) */ -template -void IncrementHist(GHistRow dst, const GHistRow add, - size_t begin, size_t end); +void IncrementHist(GHistRow dst, const GHistRow add, size_t begin, size_t end); /*! * \brief Copy hist from src to dst in range [begin, end) */ -template -void CopyHist(GHistRow dst, const GHistRow src, - size_t begin, size_t end); +void CopyHist(GHistRow dst, const GHistRow src, size_t begin, size_t end); /*! * \brief Compute Subtraction: dst = src1 - src2 in range [begin, end) */ -template -void SubtractionHist(GHistRow dst, const GHistRow src1, - const GHistRow src2, - size_t begin, size_t end); +void SubtractionHist(GHistRow dst, const GHistRow src1, const GHistRow src2, size_t begin, + size_t end); /*! * \brief histogram of gradient statistics for multiple nodes */ -template class HistCollection { public: - using GHistRowT = GHistRow; - using GradientPairT = xgboost::detail::GradientPairInternal; - // access histogram for i-th node - GHistRowT operator[](bst_uint nid) const { + GHistRow operator[](bst_uint nid) const { constexpr uint32_t kMax = std::numeric_limits::max(); const size_t id = row_ptr_.at(nid); CHECK_NE(id, kMax); - GradientPairT* ptr = nullptr; + GradientPairPrecise* ptr = nullptr; if (contiguous_allocation_) { - ptr = const_cast(data_[0].data() + nbins_*id); + ptr = const_cast(data_[0].data() + nbins_*id); } else { - ptr = const_cast(data_[id].data()); + ptr = const_cast(data_[id].data()); } return {ptr, nbins_}; } @@ -431,7 +419,7 @@ class HistCollection { /*! \brief flag to identify contiguous memory allocation */ bool contiguous_allocation_ = false; - std::vector> data_; + std::vector> data_; /*! \brief row_ptr_[nid] locates bin for histogram of node nid */ std::vector row_ptr_; @@ -442,11 +430,8 @@ class HistCollection { * Supports processing multiple tree-nodes for nested parallelism * Able to reduce histograms across threads in efficient way */ -template class ParallelGHistBuilder { public: - using GHistRowT = GHistRow; - void Init(size_t nbins) { if (nbins != nbins_) { hist_buffer_.Init(nbins); @@ -457,7 +442,7 @@ class ParallelGHistBuilder { // Add new elements if needed, mark all hists as unused // targeted_hists - already allocated hists which should contain final results after Reduce() call void Reset(size_t nthreads, size_t nodes, const BlockedSpace2d& space, - const std::vector& targeted_hists) { + const std::vector& targeted_hists) { hist_buffer_.Init(nbins_); tid_nid_to_hist_.clear(); threads_to_nids_map_.clear(); @@ -478,7 +463,7 @@ class ParallelGHistBuilder { } // Get specified hist, initialize hist by zeros if it wasn't used before - GHistRowT GetInitializedHist(size_t tid, size_t nid) { + GHistRow GetInitializedHist(size_t tid, size_t nid) { CHECK_LT(nid, nodes_); CHECK_LT(tid, nthreads_); @@ -486,7 +471,7 @@ class ParallelGHistBuilder { if (idx >= 0) { hist_buffer_.AllocateData(idx); } - GHistRowT hist = idx == -1 ? targeted_hists_[nid] : hist_buffer_[idx]; + GHistRow hist = idx == -1 ? targeted_hists_[nid] : hist_buffer_[idx]; if (!hist_was_used_[tid * nodes_ + nid]) { InitilizeHistByZeroes(hist, 0, hist.size()); @@ -501,7 +486,7 @@ class ParallelGHistBuilder { CHECK_GT(end, begin); CHECK_LT(nid, nodes_); - GHistRowT dst = targeted_hists_[nid]; + GHistRow dst = targeted_hists_[nid]; bool is_updated = false; for (size_t tid = 0; tid < nthreads_; ++tid) { @@ -509,7 +494,7 @@ class ParallelGHistBuilder { is_updated = true; int idx = tid_nid_to_hist_.at({tid, nid}); - GHistRowT src = idx == -1 ? targeted_hists_[nid] : hist_buffer_[idx]; + GHistRow src = idx == -1 ? targeted_hists_[nid] : hist_buffer_[idx]; if (dst.data() != src.data()) { IncrementHist(dst, src, begin, end); @@ -595,7 +580,7 @@ class ParallelGHistBuilder { /*! \brief number of nodes which will be processed in parallel */ size_t nodes_ = 0; /*! \brief Buffer for additional histograms for Parallel processing */ - HistCollection hist_buffer_; + HistCollection hist_buffer_; /*! * \brief Marks which hists were used, it means that they should be merged. * Contains only {true or false} values @@ -606,7 +591,7 @@ class ParallelGHistBuilder { /*! \brief Buffer for additional histograms for Parallel processing */ std::vector threads_to_nids_map_; /*! \brief Contains histograms for final results */ - std::vector targeted_hists_; + std::vector targeted_hists_; /*! * \brief map pair {tid, nid} to index of allocated histogram from hist_buffer_ and targeted_hists_, * -1 is reserved for targeted_hists_ @@ -617,19 +602,15 @@ class ParallelGHistBuilder { /*! * \brief builder for histograms of gradient statistics */ -template class GHistBuilder { public: - using GHistRowT = GHistRow; - GHistBuilder() = default; explicit GHistBuilder(uint32_t nbins): nbins_{nbins} {} // construct a histogram via histogram aggregation template - void BuildHist(const std::vector &gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix &gmat, GHistRowT hist) const; + void BuildHist(const std::vector& gpair, const RowSetCollection::Elem row_indices, + const GHistIndexMatrix& gmat, GHistRow hist) const; uint32_t GetNumBins() const { return nbins_; } diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index c7acdd383515..894edfe325d3 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -22,7 +22,8 @@ namespace xgboost { namespace tree { -template class HistEvaluator { +template +class HistEvaluator { private: struct NodeEntry { /*! \brief statics for node entry */ @@ -57,7 +58,7 @@ template class HistEvaluator { // a non-missing value for the particular feature fid. template GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span sorted_idx, - const common::GHistRow &hist, bst_feature_t fidx, + const common::GHistRow &hist, bst_feature_t fidx, bst_node_t nidx, TreeEvaluator::SplitEvaluator const &evaluator, SplitEntry *p_best) const { @@ -197,10 +198,8 @@ template class HistEvaluator { } public: - void EvaluateSplits(const common::HistCollection &hist, - common::HistogramCuts const &cut, - common::Span feature_types, - const RegTree &tree, + void EvaluateSplits(const common::HistCollection &hist, common::HistogramCuts const &cut, + common::Span feature_types, const RegTree &tree, std::vector *p_entries) { auto& entries = *p_entries; // All nodes are on the same level, so we can store the shared ptr. @@ -377,10 +376,10 @@ template class HistEvaluator { * * \param p_last_tree The last tree being updated by tree updater */ -template +template void UpdatePredictionCacheImpl(GenericParameter const *ctx, RegTree const *p_last_tree, std::vector const &partitioner, - HistEvaluator const &hist_evaluator, + HistEvaluator const &hist_evaluator, TrainParam const ¶m, linalg::VectorView out_preds) { CHECK_GT(out_preds.Size(), 0U); diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 6020de28d529..11c7b385a172 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -16,17 +16,15 @@ namespace xgboost { namespace tree { -template class HistogramBuilder { - using GradientPairT = xgboost::detail::GradientPairInternal; - using GHistRowT = common::GHistRow; - +template +class HistogramBuilder { /*! \brief culmulative histogram of gradients. */ - common::HistCollection hist_; + common::HistCollection hist_; /*! \brief culmulative local parent histogram of gradients. */ - common::HistCollection hist_local_worker_; - common::GHistBuilder builder_; - common::ParallelGHistBuilder buffer_; - rabit::Reducer reducer_; + common::HistCollection hist_local_worker_; + common::GHistBuilder builder_; + common::ParallelGHistBuilder buffer_; + rabit::Reducer reducer_; BatchParam param_; int32_t n_threads_{-1}; size_t n_batches_{0}; @@ -51,7 +49,7 @@ template class HistogramBuilder { hist_.Init(total_bins); hist_local_worker_.Init(total_bins); buffer_.Init(total_bins); - builder_ = common::GHistBuilder(total_bins); + builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; } @@ -64,7 +62,7 @@ template class HistogramBuilder { const size_t n_nodes = nodes_for_explicit_hist_build.size(); CHECK_GT(n_nodes, 0); - std::vector target_hists(n_nodes); + std::vector target_hists(n_nodes); for (size_t i = 0; i < n_nodes; ++i) { const int32_t nid = nodes_for_explicit_hist_build[i].nid; target_hists[i] = hist_[nid]; @@ -243,9 +241,7 @@ template class HistogramBuilder { public: /* Getters for tests. */ - common::HistCollection const& Histogram() { - return hist_; - } + common::HistCollection const &Histogram() { return hist_; } auto& Buffer() { return buffer_; } private: diff --git a/src/tree/hist/param.cc b/src/tree/hist/param.cc deleted file mode 100644 index 05f1a24adc81..000000000000 --- a/src/tree/hist/param.cc +++ /dev/null @@ -1,10 +0,0 @@ -/*! - * Copyright 2022 XGBoost contributors - */ -#include "param.h" - -namespace xgboost { -namespace tree { -DMLC_REGISTER_PARAMETER(CPUHistMakerTrainParam); -} // namespace tree -} // namespace xgboost diff --git a/src/tree/hist/param.h b/src/tree/hist/param.h deleted file mode 100644 index 2fbee28c423b..000000000000 --- a/src/tree/hist/param.h +++ /dev/null @@ -1,23 +0,0 @@ -/*! - * Copyright 2021 XGBoost contributors - */ -#ifndef XGBOOST_TREE_HIST_PARAM_H_ -#define XGBOOST_TREE_HIST_PARAM_H_ -#include "xgboost/parameter.h" - -namespace xgboost { -namespace tree { -// training parameters specific to this algorithm -struct CPUHistMakerTrainParam - : public XGBoostParameter { - bool single_precision_histogram; - // declare parameters - DMLC_DECLARE_PARAMETER(CPUHistMakerTrainParam) { - DMLC_DECLARE_FIELD(single_precision_histogram).set_default(false).describe( - "Use single precision to build histograms."); - } -}; -} // namespace tree -} // namespace xgboost - -#endif // XGBOOST_TREE_HIST_PARAM_H_ diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 39c22b507e71..b507b5220f2a 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -15,7 +15,6 @@ #include "driver.h" #include "hist/evaluate_splits.h" #include "hist/histogram.h" -#include "hist/param.h" #include "param.h" #include "xgboost/base.h" #include "xgboost/json.h" @@ -38,13 +37,12 @@ auto BatchSpec(TrainParam const &p, common::Span hess) { } } // anonymous namespace -template class GloablApproxBuilder { protected: TrainParam param_; std::shared_ptr col_sampler_; - HistEvaluator evaluator_; - HistogramBuilder histogram_builder_; + HistEvaluator evaluator_; + HistogramBuilder histogram_builder_; Context const *ctx_; ObjInfo const task_; @@ -166,7 +164,7 @@ class GloablApproxBuilder { } public: - explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, GenericParameter const *ctx, + explicit GloablApproxBuilder(TrainParam param, MetaInfo const &info, Context const *ctx, std::shared_ptr column_sampler, ObjInfo task, common::Monitor *monitor) : param_{std::move(param)}, @@ -256,10 +254,8 @@ class GloablApproxBuilder { class GlobalApproxUpdater : public TreeUpdater { TrainParam param_; common::Monitor monitor_; - CPUHistMakerTrainParam hist_param_; // specializations for different histogram precision. - std::unique_ptr> f32_impl_; - std::unique_ptr> f64_impl_; + std::unique_ptr pimpl_; // pointer to the last DMatrix, used for update prediction cache. DMatrix *cached_{nullptr}; std::shared_ptr column_sampler_ = @@ -272,19 +268,14 @@ class GlobalApproxUpdater : public TreeUpdater { monitor_.Init(__func__); } - void Configure(const Args &args) override { - param_.UpdateAllowUnknown(args); - hist_param_.UpdateAllowUnknown(args); - } + void Configure(const Args &args) override { param_.UpdateAllowUnknown(args); } void LoadConfig(Json const &in) override { auto const &config = get(in); FromJson(config.at("train_param"), &this->param_); - FromJson(config.at("hist_param"), &this->hist_param_); } void SaveConfig(Json *p_out) const override { auto &out = *p_out; out["train_param"] = ToJson(param_); - out["hist_param"] = ToJson(hist_param_); } void InitData(TrainParam const ¶m, HostDeviceVector const *gpair, @@ -316,13 +307,8 @@ class GlobalApproxUpdater : public TreeUpdater { float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); - if (hist_param_.single_precision_histogram) { - f32_impl_ = std::make_unique>(param_, m->Info(), ctx_, - column_sampler_, task_, &monitor_); - } else { - f64_impl_ = std::make_unique>(param_, m->Info(), ctx_, - column_sampler_, task_, &monitor_); - } + pimpl_ = std::make_unique(param_, m->Info(), ctx_, column_sampler_, task_, + &monitor_); std::vector h_gpair; InitData(param_, gpair, &h_gpair); @@ -335,26 +321,17 @@ class GlobalApproxUpdater : public TreeUpdater { size_t t_idx = 0; for (auto p_tree : trees) { - if (hist_param_.single_precision_histogram) { - this->f32_impl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); - } else { - this->f64_impl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); - } + this->pimpl_->UpdateTree(m, h_gpair, hess, p_tree, &out_position[t_idx]); ++t_idx; } param_.learning_rate = lr; } bool UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) override { - if (data != cached_ || (!this->f32_impl_ && !this->f64_impl_)) { + if (data != cached_ || !pimpl_) { return false; } - - if (hist_param_.single_precision_histogram) { - this->f32_impl_->UpdatePredictionCache(data, out_preds); - } else { - this->f64_impl_->UpdatePredictionCache(data, out_preds); - } + this->pimpl_->UpdatePredictionCache(data, out_preds); return true; } diff --git a/src/tree/updater_approx.h b/src/tree/updater_approx.h index bb37f99ec61d..73f60eedd408 100644 --- a/src/tree/updater_approx.h +++ b/src/tree/updater_approx.h @@ -16,7 +16,6 @@ #include "driver.h" #include "hist/evaluate_splits.h" #include "hist/expand_entry.h" -#include "hist/param.h" #include "param.h" #include "xgboost/generic_parameters.h" #include "xgboost/json.h" diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 968b86c1796d..3f44e33b4b02 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -32,7 +32,6 @@ DMLC_REGISTRY_FILE_TAG(updater_quantile_hist); void QuantileHistMaker::Configure(const Args &args) { param_.UpdateAllowUnknown(args); - hist_maker_param_.UpdateAllowUnknown(args); } void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *dmat, @@ -44,24 +43,14 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *d // build tree const size_t n_trees = trees.size(); - if (hist_maker_param_.single_precision_histogram) { - if (!float_builder_) { - float_builder_.reset(new Builder(n_trees, param_, dmat, task_, ctx_)); - } - } else { - if (!double_builder_) { - double_builder_.reset(new Builder(n_trees, param_, dmat, task_, ctx_)); - } + if (!pimpl_) { + pimpl_.reset(new Builder(n_trees, param_, dmat, task_, ctx_)); } size_t t_idx{0}; for (auto p_tree : trees) { auto &t_row_position = out_position[t_idx]; - if (hist_maker_param_.single_precision_histogram) { - this->float_builder_->UpdateTree(gpair, dmat, p_tree, &t_row_position); - } else { - this->double_builder_->UpdateTree(gpair, dmat, p_tree, &t_row_position); - } + this->pimpl_->UpdateTree(gpair, dmat, p_tree, &t_row_position); ++t_idx; } @@ -70,17 +59,14 @@ void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *d bool QuantileHistMaker::UpdatePredictionCache(const DMatrix *data, linalg::VectorView out_preds) { - if (hist_maker_param_.single_precision_histogram && float_builder_) { - return float_builder_->UpdatePredictionCache(data, out_preds); - } else if (double_builder_) { - return double_builder_->UpdatePredictionCache(data, out_preds); + if (pimpl_) { + return pimpl_->UpdatePredictionCache(data, out_preds); } else { return false; } } -template -CPUExpandEntry QuantileHistMaker::Builder::InitRoot( +CPUExpandEntry QuantileHistMaker::Builder::InitRoot( DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); @@ -117,8 +103,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot( for (auto const &grad : gpair_h) { grad_stat.Add(grad.GetGrad(), grad.GetHess()); } - rabit::Allreduce(reinterpret_cast(&grad_stat), - 2); + rabit::Allreduce(reinterpret_cast(&grad_stat), 2); } auto weight = evaluator_->InitRoot(GradStats{grad_stat}); @@ -140,10 +125,9 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot( return node; } -template -void QuantileHistMaker::Builder::BuildHistogram( - DMatrix *p_fmat, RegTree *p_tree, std::vector const &valid_candidates, - std::vector const &gpair) { +void QuantileHistMaker::Builder::BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, + std::vector const &valid_candidates, + std::vector const &gpair) { std::vector nodes_to_build(valid_candidates.size()); std::vector nodes_to_sub(valid_candidates.size()); @@ -173,10 +157,9 @@ void QuantileHistMaker::Builder::BuildHistogram( } } -template -void QuantileHistMaker::Builder::LeafPartition( - RegTree const &tree, common::Span gpair, - std::vector *p_out_position) { +void QuantileHistMaker::Builder::LeafPartition(RegTree const &tree, + common::Span gpair, + std::vector *p_out_position) { monitor_->Start(__func__); if (!task_.UpdateTreeLeaf()) { return; @@ -187,10 +170,9 @@ void QuantileHistMaker::Builder::LeafPartition( monitor_->Stop(__func__); } -template -void QuantileHistMaker::Builder::ExpandTree( - DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h, - HostDeviceVector *p_out_position) { +void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree, + const std::vector &gpair_h, + HostDeviceVector *p_out_position) { monitor_->Start(__func__); Driver driver(static_cast(param_.grow_policy)); @@ -252,10 +234,9 @@ void QuantileHistMaker::Builder::ExpandTree( monitor_->Stop(__func__); } -template -void QuantileHistMaker::Builder::UpdateTree( - HostDeviceVector *gpair, DMatrix *p_fmat, RegTree *p_tree, - HostDeviceVector *p_out_position) { +void QuantileHistMaker::Builder::UpdateTree(HostDeviceVector *gpair, DMatrix *p_fmat, + RegTree *p_tree, + HostDeviceVector *p_out_position) { monitor_->Start(__func__); std::vector *gpair_ptr = &(gpair->HostVector()); @@ -272,9 +253,8 @@ void QuantileHistMaker::Builder::UpdateTree( monitor_->Stop(__func__); } -template -bool QuantileHistMaker::Builder::UpdatePredictionCache( - DMatrix const *data, linalg::VectorView out_preds) const { +bool QuantileHistMaker::Builder::UpdatePredictionCache(DMatrix const *data, + linalg::VectorView out_preds) const { // p_last_fmat_ is a valid pointer as long as UpdatePredictionCache() is called in // conjunction with Update(). if (!p_last_fmat_ || !p_last_tree_ || data != p_last_fmat_) { @@ -287,9 +267,8 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( return true; } -template -void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, - std::vector *gpair) { +void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, + std::vector *gpair) { monitor_->Start(__func__); const auto &info = fmat.Info(); auto& rnd = common::GlobalRandom(); @@ -325,14 +304,10 @@ void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, #endif // XGBOOST_CUSTOMIZE_GLOBAL_PRNG monitor_->Stop(__func__); } -template -size_t QuantileHistMaker::Builder::GetNumberOfTrees() { - return n_trees_; -} +size_t QuantileHistMaker::Builder::GetNumberOfTrees() { return n_trees_; } -template -void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, - std::vector *gpair) { +void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const RegTree &tree, + std::vector *gpair) { monitor_->Start(__func__); const auto& info = fmat->Info(); @@ -362,8 +337,8 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const Reg // store a pointer to the tree p_last_tree_ = &tree; - evaluator_.reset(new HistEvaluator{ - param_, info, this->ctx_->Threads(), column_sampler_}); + evaluator_.reset( + new HistEvaluator{param_, info, this->ctx_->Threads(), column_sampler_}); monitor_->Stop(__func__); } @@ -406,9 +381,6 @@ void HistRowPartitioner::AddSplitsToRowSet(const std::vector &no } } -template struct QuantileHistMaker::Builder; -template struct QuantileHistMaker::Builder; - XGBOOST_REGISTER_TREE_UPDATER(QuantileHistMaker, "grow_quantile_histmaker") .describe("Grow tree using quantized histogram.") .set_body([](GenericParameter const *ctx, ObjInfo task) { diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 18b4d66824a1..c1811764f217 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -24,7 +24,6 @@ #include "hist/evaluate_splits.h" #include "hist/histogram.h" #include "hist/expand_entry.h" -#include "hist/param.h" #include "constraints.h" #include "./param.h" @@ -249,12 +248,10 @@ class QuantileHistMaker: public TreeUpdater { void LoadConfig(Json const& in) override { auto const& config = get(in); FromJson(config.at("train_param"), &this->param_); - FromJson(config.at("cpu_hist_train_param"), &this->hist_maker_param_); } void SaveConfig(Json* p_out) const override { auto& out = *p_out; out["train_param"] = ToJson(param_); - out["cpu_hist_train_param"] = ToJson(hist_maker_param_); } char const* Name() const override { @@ -264,22 +261,20 @@ class QuantileHistMaker: public TreeUpdater { bool HasNodePosition() const override { return true; } protected: - CPUHistMakerTrainParam hist_maker_param_; // training parameter TrainParam param_; // actual builder that runs the algorithm - template struct Builder { public: - using GradientPairT = xgboost::detail::GradientPairInternal; + using GradientPairT = xgboost::GradientPairPrecise; // constructor explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat, ObjInfo task, GenericParameter const* ctx) : n_trees_(n_trees), param_(param), p_last_fmat_(fmat), - histogram_builder_{new HistogramBuilder}, + histogram_builder_{new HistogramBuilder}, task_{task}, ctx_{ctx}, monitor_{std::make_unique()} { @@ -320,14 +315,14 @@ class QuantileHistMaker: public TreeUpdater { std::vector gpair_local_; - std::unique_ptr> evaluator_; + std::unique_ptr> evaluator_; std::vector partitioner_; // back pointers to tree and data matrix const RegTree* p_last_tree_{nullptr}; DMatrix const* const p_last_fmat_; - std::unique_ptr> histogram_builder_; + std::unique_ptr> histogram_builder_; ObjInfo task_; // Context for number of threads GenericParameter const* ctx_; @@ -336,8 +331,7 @@ class QuantileHistMaker: public TreeUpdater { }; protected: - std::unique_ptr> float_builder_; - std::unique_ptr> double_builder_; + std::unique_ptr pimpl_; ObjInfo task_; }; } // namespace tree diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 9c48096bf251..3dd33e03a316 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -16,7 +16,6 @@ namespace common { size_t GetNThreads() { return common::OmpGetNumThreads(0); } -template void ParallelGHistBuilderReset() { constexpr size_t kBins = 10; constexpr size_t kNodes = 5; @@ -25,16 +24,16 @@ void ParallelGHistBuilderReset() { constexpr double kValue = 1.0; const size_t nthreads = GetNThreads(); - HistCollection collection; + HistCollection collection; collection.Init(kBins); for(size_t inode = 0; inode < kNodesExtended; inode++) { collection.AddHistRow(inode); } collection.AllocateAllData(); - ParallelGHistBuilder hist_builder; + ParallelGHistBuilder hist_builder; hist_builder.Init(kBins); - std::vector> target_hist(kNodes); + std::vector target_hist(kNodes); for(size_t i = 0; i < target_hist.size(); ++i) { target_hist[i] = collection[i]; } @@ -45,7 +44,7 @@ void ParallelGHistBuilderReset() { common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) { const size_t tid = omp_get_thread_num(); - GHistRow hist = hist_builder.GetInitializedHist(tid, inode); + GHistRow hist = hist_builder.GetInitializedHist(tid, inode); // fill hist by some non-null values for(size_t j = 0; j < kBins; ++j) { hist[j].Add(kValue, kValue); @@ -63,7 +62,7 @@ void ParallelGHistBuilderReset() { common::ParallelFor2d(space2, nthreads, [&](size_t inode, common::Range1d r) { const size_t tid = omp_get_thread_num(); - GHistRow hist = hist_builder.GetInitializedHist(tid, inode); + GHistRow hist = hist_builder.GetInitializedHist(tid, inode); // fill hist by some non-null values for(size_t j = 0; j < kBins; ++j) { ASSERT_EQ(0.0, hist[j].GetGrad()); @@ -72,8 +71,6 @@ void ParallelGHistBuilderReset() { }); } - -template void ParallelGHistBuilderReduceHist(){ constexpr size_t kBins = 10; constexpr size_t kNodes = 5; @@ -81,16 +78,16 @@ void ParallelGHistBuilderReduceHist(){ constexpr double kValue = 1.0; const size_t nthreads = GetNThreads(); - HistCollection collection; + HistCollection collection; collection.Init(kBins); for(size_t inode = 0; inode < kNodes; inode++) { collection.AddHistRow(inode); } collection.AllocateAllData(); - ParallelGHistBuilder hist_builder; + ParallelGHistBuilder hist_builder; hist_builder.Init(kBins); - std::vector> target_hist(kNodes); + std::vector target_hist(kNodes); for(size_t i = 0; i < target_hist.size(); ++i) { target_hist[i] = collection[i]; } @@ -102,7 +99,7 @@ void ParallelGHistBuilderReduceHist(){ common::ParallelFor2d(space, nthreads, [&](size_t inode, common::Range1d r) { const size_t tid = omp_get_thread_num(); - GHistRow hist = hist_builder.GetInitializedHist(tid, inode); + GHistRow hist = hist_builder.GetInitializedHist(tid, inode); for(size_t i = 0; i < kBins; ++i) { hist[i].Add(kValue, kValue); } @@ -120,21 +117,9 @@ void ParallelGHistBuilderReduceHist(){ } } -TEST(ParallelGHistBuilder, ResetDouble) { - ParallelGHistBuilderReset(); -} - -TEST(ParallelGHistBuilder, ResetFloat) { - ParallelGHistBuilderReset(); -} +TEST(ParallelGHistBuilder, Reset) { ParallelGHistBuilderReset(); } -TEST(ParallelGHistBuilder, ReduceHistDouble) { - ParallelGHistBuilderReduceHist(); -} - -TEST(ParallelGHistBuilder, ReduceHistFloat) { - ParallelGHistBuilderReduceHist(); -} +TEST(ParallelGHistBuilder, ReduceHist) { ParallelGHistBuilderReduceHist(); } TEST(CutsBuilder, SearchGroupInd) { size_t constexpr kNumGroups = 4; diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 8de84b2a1076..81b5812fd27a 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -12,7 +12,7 @@ namespace xgboost { namespace tree { -template void TestEvaluateSplits() { +void TestEvaluateSplits() { int static constexpr kRows = 8, kCols = 16; auto orig = omp_get_max_threads(); int32_t n_threads = std::min(omp_get_max_threads(), 4); @@ -24,9 +24,8 @@ template void TestEvaluateSplits() { auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix(); - auto evaluator = - HistEvaluator{param, dmat->Info(), n_threads, sampler}; - common::HistCollection hist; + auto evaluator = HistEvaluator{param, dmat->Info(), n_threads, sampler}; + common::HistCollection hist; std::vector row_gpairs = { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f}}; @@ -40,7 +39,7 @@ template void TestEvaluateSplits() { std::iota(row_indices.begin(), row_indices.end(), 0); row_set_collection.Init(); - auto hist_builder = common::GHistBuilder(gmat.cut.Ptrs().back()); + auto hist_builder = common::GHistBuilder(gmat.cut.Ptrs().back()); hist.Init(gmat.cut.Ptrs().back()); hist.AddHistRow(0); hist.AllocateAllData(); @@ -85,10 +84,7 @@ template void TestEvaluateSplits() { omp_set_num_threads(orig); } -TEST(HistEvaluator, Evaluate) { - TestEvaluateSplits(); - TestEvaluateSplits(); -} +TEST(HistEvaluator, Evaluate) { TestEvaluateSplits(); } TEST(HistEvaluator, Apply) { RegTree tree; @@ -97,7 +93,7 @@ TEST(HistEvaluator, Apply) { param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}}); auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix(); auto sampler = std::make_shared(); - auto evaluator_ = HistEvaluator{param, dmat->Info(), 4, sampler}; + auto evaluator_ = HistEvaluator{param, dmat->Info(), 4, sampler}; CPUExpandEntry entry{0, 0, 10.0f}; entry.split.left_sum = GradStats{0.4, 0.6f}; @@ -123,8 +119,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) { // check the evaluator is returning the optimal split std::vector ft{FeatureType::kCategorical}; auto sampler = std::make_shared(); - HistEvaluator evaluator{param_, info_, common::OmpGetNumThreads(0), - sampler}; + HistEvaluator evaluator{param_, info_, common::OmpGetNumThreads(0), sampler}; evaluator.InitRoot(GradStats{total_gpair_}); RegTree tree; std::vector entries(1); @@ -155,12 +150,11 @@ auto CompareOneHotAndPartition(bool onehot) { int32_t n_threads = 16; auto sampler = std::make_shared(); - auto evaluator = - HistEvaluator{param, dmat->Info(), n_threads, sampler}; + auto evaluator = HistEvaluator{param, dmat->Info(), n_threads, sampler}; std::vector entries(1); for (auto const &gmat : dmat->GetBatches({32, param.sparse_threshold})) { - common::HistCollection hist; + common::HistCollection hist; entries.front().nid = 0; entries.front().depth = 0; diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 06147afa3700..1669caaaa0fa 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -23,7 +23,6 @@ void InitRowPartitionForTest(common::RowSetCollection *row_set, size_t n_samples } } // anonymous namespace -template void TestAddHistRows(bool is_distributed) { std::vector nodes_for_explicit_hist_build_; std::vector nodes_for_subtraction_trick_; @@ -46,7 +45,7 @@ void TestAddHistRows(bool is_distributed) { nodes_for_subtraction_trick_.emplace_back(5, tree.GetDepth(5), 0.0f); nodes_for_subtraction_trick_.emplace_back(6, tree.GetDepth(6), 0.0f); - HistogramBuilder histogram_builder; + HistogramBuilder histogram_builder; histogram_builder.Reset(gmat.cut.TotalBins(), {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); histogram_builder.AddHistRows(&starting_index, &sync_count, @@ -66,14 +65,10 @@ void TestAddHistRows(bool is_distributed) { TEST(CPUHistogram, AddRows) { - TestAddHistRows(true); - TestAddHistRows(true); - - TestAddHistRows(false); - TestAddHistRows(false); + TestAddHistRows(true); + TestAddHistRows(false); } -template void TestSyncHist(bool is_distributed) { size_t constexpr kNRows = 8, kNCols = 16; int32_t constexpr kMaxBins = 4; @@ -88,7 +83,7 @@ void TestSyncHist(bool is_distributed) { RandomDataGenerator(kNRows, kNCols, 0.8).Seed(3).GenerateDMatrix(); auto const &gmat = *(p_fmat->GetBatches(BatchParam{kMaxBins, 0.5}).begin()); - HistogramBuilder histogram; + HistogramBuilder histogram; uint32_t total_bins = gmat.cut.Ptrs().back(); histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); @@ -153,7 +148,7 @@ void TestSyncHist(bool is_distributed) { }, 256); - std::vector> target_hists(n_nodes); + std::vector target_hists(n_nodes); for (size_t i = 0; i < nodes_for_explicit_hist_build_.size(); ++i) { const int32_t nid = nodes_for_explicit_hist_build_[i].nid; target_hists[i] = histogram.Histogram()[nid]; @@ -163,7 +158,7 @@ void TestSyncHist(bool is_distributed) { std::vector n_ids = {1, 2}; for (size_t i : n_ids) { auto this_hist = histogram.Histogram()[i]; - GradientSumT *p_hist = reinterpret_cast(this_hist.data()); + double *p_hist = reinterpret_cast(this_hist.data()); for (size_t bin_id = 0; bin_id < 2 * total_bins; ++bin_id) { p_hist[bin_id] = 2 * bin_id; } @@ -172,7 +167,7 @@ void TestSyncHist(bool is_distributed) { n_ids[1] = 5; for (size_t i : n_ids) { auto this_hist = histogram.Histogram()[i]; - GradientSumT *p_hist = reinterpret_cast(this_hist.data()); + double *p_hist = reinterpret_cast(this_hist.data()); for (size_t bin_id = 0; bin_id < 2 * total_bins; ++bin_id) { p_hist[bin_id] = bin_id; } @@ -190,15 +185,12 @@ void TestSyncHist(bool is_distributed) { sync_count); } - using GHistRowT = common::GHistRow; - auto check_hist = [](const GHistRowT parent, const GHistRowT left, - const GHistRowT right, size_t begin, size_t end) { - const GradientSumT *p_parent = - reinterpret_cast(parent.data()); - const GradientSumT *p_left = - reinterpret_cast(left.data()); - const GradientSumT *p_right = - reinterpret_cast(right.data()); + using GHistRowT = common::GHistRow; + auto check_hist = [](const GHistRowT parent, const GHistRowT left, const GHistRowT right, + size_t begin, size_t end) { + const double *p_parent = reinterpret_cast(parent.data()); + const double *p_left = reinterpret_cast(left.data()); + const double *p_right = reinterpret_cast(right.data()); for (size_t i = 2 * begin; i < 2 * end; ++i) { ASSERT_EQ(p_parent[i], p_left[i] + p_right[i]); } @@ -230,14 +222,10 @@ void TestSyncHist(bool is_distributed) { } TEST(CPUHistogram, SyncHist) { - TestSyncHist(true); - TestSyncHist(true); - - TestSyncHist(false); - TestSyncHist(false); + TestSyncHist(true); + TestSyncHist(false); } -template void TestBuildHistogram(bool is_distributed) { size_t constexpr kNRows = 8, kNCols = 16; int32_t constexpr kMaxBins = 4; @@ -252,7 +240,7 @@ void TestBuildHistogram(bool is_distributed) { {0.27f, 0.29f}, {0.37f, 0.39f}, {0.47f, 0.49f}, {0.57f, 0.59f}}; bst_node_t nid = 0; - HistogramBuilder histogram; + HistogramBuilder histogram; histogram.Reset(total_bins, {kMaxBins, 0.5}, omp_get_max_threads(), 1, is_distributed); RegTree tree; @@ -296,11 +284,8 @@ void TestBuildHistogram(bool is_distributed) { } TEST(CPUHistogram, BuildHist) { - TestBuildHistogram(true); - TestBuildHistogram(true); - - TestBuildHistogram(false); - TestBuildHistogram(false); + TestBuildHistogram(true); + TestBuildHistogram(false); } namespace { @@ -329,7 +314,7 @@ void TestHistogramCategorical(size_t n_categories) { /** * Generate hist with cat data. */ - HistogramBuilder cat_hist; + HistogramBuilder cat_hist; for (auto const &gidx : cat_m->GetBatches({kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); cat_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); @@ -342,7 +327,7 @@ void TestHistogramCategorical(size_t n_categories) { */ auto x_encoded = OneHotEncodeFeature(x, n_categories); auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories); - HistogramBuilder onehot_hist; + HistogramBuilder onehot_hist; for (auto const &gidx : encode_m->GetBatches({kBins, 0.5})) { auto total_bins = gidx.cut.TotalBins(); onehot_hist.Reset(total_bins, {kBins, 0.5}, omp_get_max_threads(), 1, false); @@ -382,8 +367,8 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) { std::vector nodes; nodes.emplace_back(0, tree.GetDepth(0), 0.0f); - common::GHistRow multi_page; - HistogramBuilder multi_build; + common::GHistRow multi_page; + HistogramBuilder multi_build; { /** * Multi page @@ -417,8 +402,8 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) { multi_page = multi_build.Histogram()[0]; } - HistogramBuilder single_build; - common::GHistRow single_page; + HistogramBuilder single_build; + common::GHistRow single_page; { /** * Single page diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h index 4b1a320319fb..c8e0f577e9fe 100644 --- a/tests/cpp/tree/test_evaluate_splits.h +++ b/tests/cpp/tree/test_evaluate_splits.h @@ -22,7 +22,7 @@ class TestPartitionBasedSplit : public ::testing::Test { MetaInfo info_; float best_score_{-std::numeric_limits::infinity()}; common::HistogramCuts cuts_; - common::HistCollection hist_; + common::HistCollection hist_; GradientPairPrecise total_gpair_; void SetUp() override { @@ -55,7 +55,7 @@ class TestPartitionBasedSplit : public ::testing::Test { total_gpair_ += e; } - auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow hist, + auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow hist, GradientPairPrecise parent_sum) { int32_t best_thresh = -1; float best_score{-std::numeric_limits::infinity()}; From 2035b460ca0b7affbcac2ba7c5a60e190e81e972 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Thu, 12 May 2022 17:07:10 +0800 Subject: [PATCH 2/6] amalgamation. --- amalgamation/xgboost-all0.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index c684e6309de8..ef3c2ffde8a7 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2019 by Contributors. + * Copyright 2015-2022 by Contributors. * \brief XGBoost Amalgamation. * This offers an alternative way to compile the entire library from this single file. * @@ -50,7 +50,6 @@ // trees #include "../src/tree/constraints.cc" -#include "../src/tree/hist/param.cc" #include "../src/tree/param.cc" #include "../src/tree/tree_model.cc" #include "../src/tree/tree_updater.cc" From 19e375995634fafe5ff77bb7c98c372bbf4ed483 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Thu, 12 May 2022 17:29:44 +0800 Subject: [PATCH 3/6] Force inline. --- include/xgboost/base.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 04cb6ddb7609..4a5c82283b66 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -157,7 +157,8 @@ class GradientPairInternal { hess_ += hess; } - inline static void Reduce(GradientPairInternal& a, const GradientPairInternal& b) { // NOLINT(*) + DMLC_ALWAYS_INLINE static void Reduce(GradientPairInternal &a, // NOLINT + const GradientPairInternal &b) { a += b; } From 04dd0dd07cfa66e0bdb7494f85f53872a0a87d32 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Thu, 12 May 2022 19:31:18 +0800 Subject: [PATCH 4/6] instantiation. --- include/xgboost/base.h | 3 +-- src/tree/hist/histogram.h | 1 + src/tree/updater_quantile_hist.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/include/xgboost/base.h b/include/xgboost/base.h index 4a5c82283b66..04cb6ddb7609 100644 --- a/include/xgboost/base.h +++ b/include/xgboost/base.h @@ -157,8 +157,7 @@ class GradientPairInternal { hess_ += hess; } - DMLC_ALWAYS_INLINE static void Reduce(GradientPairInternal &a, // NOLINT - const GradientPairInternal &b) { + inline static void Reduce(GradientPairInternal& a, const GradientPairInternal& b) { // NOLINT(*) a += b; } diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 11c7b385a172..6528a4953a44 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -51,6 +51,7 @@ class HistogramBuilder { buffer_.Init(total_bins); builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; + auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce; } template diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index c1811764f217..7f88f0ec7924 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -235,7 +235,7 @@ inline BatchParam HistBatch(TrainParam const& param) { class QuantileHistMaker: public TreeUpdater { public: explicit QuantileHistMaker(GenericParameter const* ctx, ObjInfo task) - : task_{task}, TreeUpdater(ctx) {} + : TreeUpdater(ctx), task_{task} {} void Configure(const Args& args) override; void Update(HostDeviceVector* gpair, DMatrix* dmat, From 0b60da8a15086e95ab9d346d04b606273793ddfe Mon Sep 17 00:00:00 2001 From: jiamingy Date: Thu, 12 May 2022 19:51:12 +0800 Subject: [PATCH 5/6] Remove. --- src/tree/hist/histogram.h | 1 + src/tree/updater_quantile_hist.cc | 4 ++-- src/tree/updater_quantile_hist.h | 1 - 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 6528a4953a44..266086af11a4 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -51,6 +51,7 @@ class HistogramBuilder { buffer_.Init(total_bins); builder_ = common::GHistBuilder(total_bins); is_distributed_ = is_distributed; + // Workaround s390x gcc 7.5.0 auto DMLC_ATTRIBUTE_UNUSED __force_instantiation = &GradientPairPrecise::Reduce; } diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 3f44e33b4b02..a4a8ace83b5c 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -82,7 +82,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot( } { - GradientPairT grad_stat; + GradientPairPrecise grad_stat; if (p_fmat->IsDense()) { /** * Specialized code for dense data: For dense data (with no missing value), the sum @@ -96,7 +96,7 @@ CPUExpandEntry QuantileHistMaker::Builder::InitRoot( auto hist = this->histogram_builder_->Histogram()[RegTree::kRoot]; auto begin = hist.data(); for (uint32_t i = ibegin; i < iend; ++i) { - GradientPairT const &et = begin[i]; + GradientPairPrecise const &et = begin[i]; grad_stat.Add(et.GetGrad(), et.GetHess()); } } else { diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h index 7f88f0ec7924..d7c2b4dec3ef 100644 --- a/src/tree/updater_quantile_hist.h +++ b/src/tree/updater_quantile_hist.h @@ -267,7 +267,6 @@ class QuantileHistMaker: public TreeUpdater { // actual builder that runs the algorithm struct Builder { public: - using GradientPairT = xgboost::GradientPairPrecise; // constructor explicit Builder(const size_t n_trees, const TrainParam& param, DMatrix const* fmat, ObjInfo task, GenericParameter const* ctx) From eae0c2c32ee41c9b38663d729055393a20a9365c Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 13 May 2022 14:43:02 +0800 Subject: [PATCH 6/6] Update tests/cpp/tree/hist/test_histogram.cc Co-authored-by: Philip Hyunsu Cho --- tests/cpp/tree/hist/test_histogram.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index 1669caaaa0fa..c0bd62629899 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -167,7 +167,7 @@ void TestSyncHist(bool is_distributed) { n_ids[1] = 5; for (size_t i : n_ids) { auto this_hist = histogram.Histogram()[i]; - double *p_hist = reinterpret_cast(this_hist.data()); + double *p_hist = reinterpret_cast(this_hist.data()); for (size_t bin_id = 0; bin_id < 2 * total_bins; ++bin_id) { p_hist[bin_id] = bin_id; }