From eca7735723a24a9d48f6c5568dee470aaa508234 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Sun, 1 May 2022 00:27:11 +0800 Subject: [PATCH 1/4] Always use partition based categorical splits. --- doc/tutorials/categorical.rst | 27 ++++++++++++--------------- include/xgboost/task.h | 3 --- src/common/categorical.h | 2 +- src/tree/gpu_hist/evaluator.cu | 2 +- 4 files changed, 14 insertions(+), 20 deletions(-) diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst index 3f106962d7af..f5ae16bafc10 100644 --- a/doc/tutorials/categorical.rst +++ b/doc/tutorials/categorical.rst @@ -72,23 +72,20 @@ Optimal Partitioning .. versionadded:: 1.6 Optimal partitioning is a technique for partitioning the categorical predictors for each -node split, the proof of optimality for numerical objectives like ``RMSE`` was first -introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling -regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3] -<#references>`__ brought it to the context of gradient boosting trees and now is also -adopted in XGBoost as an optional feature for handling categorical splits. More -specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to -partition a set of discrete values into groups based on the distances between a measure of -these values, one only needs to look at sorted partitions instead of enumerating all -possible permutations. In the context of decision trees, the discrete values are -categories, and the measure is the output leaf value. Intuitively, we want to group the -categories that output similar leaf values. During split finding, we first sort the -gradient histogram to prepare the contiguous partitions then enumerate the splits +node split, the proof of optimality for numerical output was first introduced by `[1] +<#references>`__. The algorithm is used in decision trees `[2] <#references>`__, later +LightGBM `[3] <#references>`__ brought it to the context of gradient boosting trees and +now is also adopted in XGBoost as an optional feature for handling categorical +splits. More specifically, the proof by Fisher `[1] <#references>`__ states that, when +trying to partition a set of discrete values into groups based on the distances between a +measure of these values, one only needs to look at sorted partitions instead of +enumerating all possible permutations. In the context of decision trees, the discrete +values are categories, and the measure is the output leaf value. Intuitively, we want to +group the categories that output similar leaf values. During split finding, we first sort +the gradient histogram to prepare the contiguous partitions then enumerate the splits according to these sorted values. One of the related parameters for XGBoost is ``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be -used for each feature, see :doc:`/parameter` for details. When objective is not -regression or binary classification, XGBoost will fallback to using onehot encoding -instead. +used for each feature, see :doc:`/parameter` for details. ********************** diff --git a/include/xgboost/task.h b/include/xgboost/task.h index 739207a309d8..8f57383ddf32 100644 --- a/include/xgboost/task.h +++ b/include/xgboost/task.h @@ -38,9 +38,6 @@ struct ObjInfo { ObjInfo(Task t) : task{t} {} // NOLINT ObjInfo(Task t, bool khess, bool zhess) : task{t}, const_hess{khess}, zero_hess(zhess) {} - XGBOOST_DEVICE bool UseOneHot() const { - return (task != ObjInfo::kRegression && task != ObjInfo::kBinary); - } /** * \brief Use adaptive tree if the objective doesn't have valid hessian value. */ diff --git a/src/common/categorical.h b/src/common/categorical.h index 5eff62264cf2..f07d14318cbf 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -83,7 +83,7 @@ inline void InvalidCategory() { * \brief Whether should we use onehot encoding for categorical data. */ XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) { - bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot(); + bool use_one_hot = n_cats < max_cat_to_onehot; return use_one_hot; } diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index bc2027489131..91ee3600dbe7 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -21,7 +21,7 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, int32_t device) { param_ = param; tree_evaluator_ = TreeEvaluator{param, n_features, device}; - if (cuts.HasCategorical() && !task.UseOneHot()) { + if (cuts.HasCategorical()) { dh::XGBCachingDeviceAllocator alloc; auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); auto beg = thrust::make_counting_iterator(1ul); From 8fe53c97c0fb5401d41818ce2321719ec8ddc4d9 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Mon, 2 May 2022 21:42:15 +0800 Subject: [PATCH 2/4] Remove unused parameter. --- src/common/categorical.h | 2 +- src/tree/gpu_hist/evaluate_splits.cu | 32 +++++++++---------- src/tree/gpu_hist/evaluate_splits.cuh | 9 +++--- src/tree/gpu_hist/evaluator.cu | 4 +-- src/tree/hist/evaluate_splits.h | 9 ++---- src/tree/updater_approx.cc | 17 +++++----- src/tree/updater_gpu_hist.cu | 26 +++++++-------- src/tree/updater_quantile_hist.cc | 4 +-- .../cpp/tree/gpu_hist/test_evaluate_splits.cu | 27 ++++++---------- tests/cpp/tree/hist/test_evaluate_splits.cc | 13 ++++---- tests/cpp/tree/test_gpu_hist.cu | 6 ++-- 11 files changed, 66 insertions(+), 83 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index f07d14318cbf..341a887f48a9 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -82,7 +82,7 @@ inline void InvalidCategory() { /*! * \brief Whether should we use onehot encoding for categorical data. */ -XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) { +XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot) { bool use_one_hot = n_cats < max_cat_to_onehot; return use_one_hot; } diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index ce8b13d0def2..26a571f25a9c 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -199,13 +199,11 @@ __device__ void EvaluateFeature( } template -__global__ void EvaluateSplitsKernel( - EvaluateSplitInputs left, - EvaluateSplitInputs right, - ObjInfo task, - common::Span sorted_idx, - TreeEvaluator::SplitEvaluator evaluator, - common::Span out_candidates) { +__global__ void EvaluateSplitsKernel(EvaluateSplitInputs left, + EvaluateSplitInputs right, + common::Span sorted_idx, + TreeEvaluator::SplitEvaluator evaluator, + common::Span out_candidates) { // KeyValuePair here used as threadIdx.x -> gain_value using ArgMaxT = cub::KeyValuePair; using BlockScanT = @@ -241,7 +239,7 @@ __global__ void EvaluateSplitsKernel( if (common::IsCat(inputs.feature_types, fidx)) { auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx]; - if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) { + if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot)) { EvaluateFeature(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage); } else { @@ -310,7 +308,7 @@ __device__ void SortBasedSplit(EvaluateSplitInputs const &input, template void GPUHistEvaluator::EvaluateSplits( - EvaluateSplitInputs left, EvaluateSplitInputs right, ObjInfo task, + EvaluateSplitInputs left, EvaluateSplitInputs right, TreeEvaluator::SplitEvaluator evaluator, common::Span out_splits) { if (!split_cats_.empty()) { @@ -322,8 +320,8 @@ void GPUHistEvaluator::EvaluateSplits( // One block for each feature uint32_t constexpr kBlockThreads = 256; - dh::LaunchKernel {static_cast(combined_num_features), kBlockThreads, 0}( - EvaluateSplitsKernel, left, right, task, this->SortedIdx(left), + dh::LaunchKernel{static_cast(combined_num_features), kBlockThreads, 0}( + EvaluateSplitsKernel, left, right, this->SortedIdx(left), evaluator, dh::ToSpan(feature_best_splits)); // Reduce to get best candidate for left and right child over all features @@ -365,7 +363,7 @@ void GPUHistEvaluator::CopyToHost(EvaluateSplitInputs -void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task, +void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, EvaluateSplitInputs left, EvaluateSplitInputs right, common::Span out_entries) { @@ -373,7 +371,7 @@ void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, Ob dh::TemporaryArray splits_out_storage(2); auto out_splits = dh::ToSpan(splits_out_storage); - this->EvaluateSplits(left, right, task, evaluator, out_splits); + this->EvaluateSplits(left, right, evaluator, out_splits); auto d_sorted_idx = this->SortedIdx(left); auto d_entries = out_entries; @@ -385,7 +383,7 @@ void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, Ob auto fidx = out_splits[i].findex; if (split.is_cat && - !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) { bool is_left = i == 0; auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2); SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]); @@ -405,11 +403,11 @@ void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, Ob template GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit( - EvaluateSplitInputs input, float weight, ObjInfo task) { + EvaluateSplitInputs input, float weight) { dh::TemporaryArray splits_out(1); auto out_split = dh::ToSpan(splits_out); auto evaluator = tree_evaluator_.GetEvaluator(); - this->EvaluateSplits(input, {}, task, evaluator, out_split); + this->EvaluateSplits(input, {}, evaluator, out_split); auto cats_out = this->DeviceCatStorage(input.nidx); auto d_sorted_idx = this->SortedIdx(input); @@ -421,7 +419,7 @@ GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit( auto fidx = out_split[i].findex; if (split.is_cat && - !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) { SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]); } diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index b03fd7b41b51..ab4d2d97f2c8 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -114,7 +114,7 @@ class GPUHistEvaluator { /** * \brief Reset the evaluator, should be called before any use. */ - void Reset(common::HistogramCuts const &cuts, common::Span ft, ObjInfo task, + void Reset(common::HistogramCuts const &cuts, common::Span ft, bst_feature_t n_features, TrainParam const ¶m, int32_t device); /** @@ -150,21 +150,20 @@ class GPUHistEvaluator { // impl of evaluate splits, contains CUDA kernels so it's public void EvaluateSplits(EvaluateSplitInputs left, - EvaluateSplitInputs right, ObjInfo task, + EvaluateSplitInputs right, TreeEvaluator::SplitEvaluator evaluator, common::Span out_splits); /** * \brief Evaluate splits for left and right nodes. */ - void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task, + void EvaluateSplits(GPUExpandEntry candidate, EvaluateSplitInputs left, EvaluateSplitInputs right, common::Span out_splits); /** * \brief Evaluate splits for root node. */ - GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs input, float weight, - ObjInfo task); + GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs input, float weight); }; } // namespace tree } // namespace xgboost diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu index 91ee3600dbe7..e94acdf494ce 100644 --- a/src/tree/gpu_hist/evaluator.cu +++ b/src/tree/gpu_hist/evaluator.cu @@ -16,7 +16,7 @@ namespace xgboost { namespace tree { template void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, - common::Span ft, ObjInfo task, + common::Span ft, bst_feature_t n_features, TrainParam const ¶m, int32_t device) { param_ = param; @@ -34,7 +34,7 @@ void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, auto idx = i - 1; if (common::IsCat(ft, idx)) { auto n_bins = ptrs[i] - ptrs[idx]; - bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); + bool use_sort = !common::UseOneHot(n_bins, to_onehot); return use_sort; } return false; diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 4e445a0680e5..e53e39eefc16 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -39,7 +39,6 @@ template class HistEvaluator { int32_t n_threads_ {0}; FeatureInteractionConstraintHost interaction_constraints_; std::vector snode_; - ObjInfo task_; // if sum of statistics for non-missing values in the node // is equal to sum of statistics for all values: @@ -244,7 +243,7 @@ template class HistEvaluator { } if (is_cat) { auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx]; - if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) { + if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) { EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); } else { @@ -345,7 +344,6 @@ template class HistEvaluator { auto Evaluator() const { return tree_evaluator_.GetEvaluator(); } auto const& Stats() const { return snode_; } - auto Task() const { return task_; } float InitRoot(GradStats const& root_sum) { snode_.resize(1); @@ -363,12 +361,11 @@ template class HistEvaluator { // The column sampler must be constructed by caller since we need to preserve the rng // for the entire training session. explicit HistEvaluator(TrainParam const ¶m, MetaInfo const &info, int32_t n_threads, - std::shared_ptr sampler, ObjInfo task) + std::shared_ptr sampler) : param_{param}, column_sampler_{std::move(sampler)}, tree_evaluator_{param, static_cast(info.num_col_), GenericParameter::kCpuId}, - n_threads_{n_threads}, - task_{task} { + n_threads_{n_threads} { interaction_constraints_.Configure(param, info.num_col_); column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode, param_.colsample_bylevel, param_.colsample_bytree); diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 4222cddb1ee9..51d8b8deddb1 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -29,10 +29,8 @@ DMLC_REGISTRY_FILE_TAG(updater_approx); namespace { // Return the BatchParam used by DMatrix. -template -auto BatchSpec(TrainParam const &p, common::Span hess, - HistEvaluator const &evaluator) { - return BatchParam{p.max_bin, hess, !evaluator.Task().const_hess}; +auto BatchSpec(TrainParam const &p, common::Span hess, ObjInfo const task) { + return BatchParam{p.max_bin, hess, !task.const_hess}; } auto BatchSpec(TrainParam const &p, common::Span hess) { @@ -47,7 +45,8 @@ class GloablApproxBuilder { std::shared_ptr col_sampler_; HistEvaluator evaluator_; HistogramBuilder histogram_builder_; - GenericParameter const *ctx_; + Context const *ctx_; + ObjInfo const task_; std::vector partitioner_; // Pointer to last updated tree, used for update prediction cache. @@ -65,8 +64,7 @@ class GloablApproxBuilder { int32_t n_total_bins = 0; partitioner_.clear(); // Generating the GHistIndexMatrix is quite slow, is there a way to speed it up? - for (auto const &page : - p_fmat->GetBatches(BatchSpec(param_, hess, evaluator_))) { + for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess, task_))) { if (n_total_bins == 0) { n_total_bins = page.cut.TotalBins(); feature_values_ = page.cut; @@ -158,7 +156,7 @@ class GloablApproxBuilder { void LeafPartition(RegTree const &tree, common::Span hess, std::vector *p_out_position) { monitor_->Start(__func__); - if (!evaluator_.Task().UpdateTreeLeaf()) { + if (!task_.UpdateTreeLeaf()) { return; } for (auto const &part : partitioner_) { @@ -173,8 +171,9 @@ class GloablApproxBuilder { common::Monitor *monitor) : param_{std::move(param)}, col_sampler_{std::move(column_sampler)}, - evaluator_{param_, info, ctx->Threads(), col_sampler_, task}, + evaluator_{param_, info, ctx->Threads(), col_sampler_}, ctx_{ctx}, + task_{task}, monitor_{monitor} {} void UpdateTree(DMatrix *p_fmat, std::vector const &gpair, common::Span hess, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 569188fd5374..8768e22a65e5 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -235,16 +235,14 @@ struct GPUHistMakerDevice { // Reset values for each update iteration // Note that the column sampler must be passed by value because it is not // thread safe - void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns, - ObjInfo task) { + void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { auto const& info = dmat->Info(); this->column_sampler.Init(num_columns, info.feature_weights.HostVector(), param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); - this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param, - ctx_->gpu_id); + this->evaluator_.Reset(page->Cuts(), feature_types, dmat->Info().num_col_, param, ctx_->gpu_id); this->interaction_constraints.Reset(); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{}); @@ -266,7 +264,7 @@ struct GPUHistMakerDevice { hist.Reset(); } - GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight, ObjInfo task) { + GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight) { int nidx = RegTree::kRoot; GPUTrainingParam gpu_param(param); auto sampled_features = column_sampler.GetFeatureSet(0); @@ -283,12 +281,12 @@ struct GPUHistMakerDevice { matrix.gidx_fvalue_map, matrix.min_fvalue, hist.GetNodeHistogram(nidx)}; - auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight, task); + auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight); return split; } - void EvaluateLeftRightSplits(GPUExpandEntry candidate, ObjInfo task, int left_nidx, - int right_nidx, const RegTree& tree, + void EvaluateLeftRightSplits(GPUExpandEntry candidate, int left_nidx, int right_nidx, + const RegTree& tree, common::Span pinned_candidates_out) { dh::TemporaryArray splits_out(2); GPUTrainingParam gpu_param(param); @@ -322,7 +320,7 @@ struct GPUHistMakerDevice { hist.GetNodeHistogram(right_nidx)}; dh::TemporaryArray entries(2); - this->evaluator_.EvaluateSplits(candidate, task, left, right, dh::ToSpan(entries)); + this->evaluator_.EvaluateSplits(candidate, left, right, dh::ToSpan(entries)); dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(), sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); } @@ -609,7 +607,7 @@ struct GPUHistMakerDevice { tree[candidate.nid].RightChild()); } - GPUExpandEntry InitRoot(RegTree* p_tree, ObjInfo task, dh::AllReducer* reducer) { + GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) { constexpr bst_node_t kRootNIdx = 0; dh::XGBCachingDeviceAllocator alloc; auto gpair_it = dh::MakeTransformIterator( @@ -630,7 +628,7 @@ struct GPUHistMakerDevice { (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight); // Generate first split - auto root_entry = this->EvaluateRootSplit(root_sum, weight, task); + auto root_entry = this->EvaluateRootSplit(root_sum, weight); return root_entry; } @@ -641,11 +639,11 @@ struct GPUHistMakerDevice { Driver driver(static_cast(param.grow_policy)); monitor.Start("Reset"); - this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task); + this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); monitor.Stop("Reset"); monitor.Start("InitRoot"); - driver.Push({ this->InitRoot(p_tree, task, reducer) }); + driver.Push({ this->InitRoot(p_tree, reducer) }); monitor.Stop("InitRoot"); auto num_leaves = 1; @@ -682,7 +680,7 @@ struct GPUHistMakerDevice { monitor.Stop("BuildHist"); monitor.Start("EvaluateSplits"); - this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree, + this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree, new_candidates.subspan(i * 2, 2)); monitor.Stop("EvaluateSplits"); } else { diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 011733b4582a..d82a122be6fe 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -178,7 +178,7 @@ void QuantileHistMaker::Builder::LeafPartition( RegTree const &tree, common::Span gpair, std::vector *p_out_position) { monitor_->Start(__func__); - if (!evaluator_->Task().UpdateTreeLeaf()) { + if (!task_.UpdateTreeLeaf()) { return; } for (auto const &part : partitioner_) { @@ -363,7 +363,7 @@ void QuantileHistMaker::Builder::InitData(DMatrix *fmat, const Reg // store a pointer to the tree p_last_tree_ = &tree; evaluator_.reset(new HistEvaluator{ - param_, info, this->ctx_->Threads(), column_sampler_, task_}); + param_, info, this->ctx_->Threads(), column_sampler_}); monitor_->Stop(__func__); } diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index 0cbfc9f2a6cf..d49a256ce9ed 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -57,8 +57,7 @@ void TestEvaluateSingleSplit(bool is_categorical) { GPUHistEvaluator evaluator{ tparam, static_cast(feature_min_values.size()), 0}; dh::device_vector out_cats; - DeviceSplitCandidate result = - evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split; EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.fvalue, 11.0); @@ -101,8 +100,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { dh::ToSpan(feature_histogram)}; GPUHistEvaluator evaluator(tparam, feature_set.size(), 0); - DeviceSplitCandidate result = - evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split; EXPECT_EQ(result.findex, 0); EXPECT_EQ(result.fvalue, 1.0); @@ -114,10 +112,8 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { TEST(GpuHist, EvaluateSingleSplitEmpty) { TrainParam tparam = ZeroParam(); GPUHistEvaluator evaluator(tparam, 1, 0); - DeviceSplitCandidate result = evaluator - .EvaluateSingleSplit(EvaluateSplitInputs{}, 0, - ObjInfo{ObjInfo::kRegression}) - .split; + DeviceSplitCandidate result = + evaluator.EvaluateSingleSplit(EvaluateSplitInputs{}, 0).split; EXPECT_EQ(result.findex, -1); EXPECT_LT(result.loss_chg, 0.0f); } @@ -152,8 +148,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { dh::ToSpan(feature_histogram)}; GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0); - DeviceSplitCandidate result = - evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split; EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.fvalue, 11.0); @@ -191,8 +186,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) { dh::ToSpan(feature_histogram)}; GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0); - DeviceSplitCandidate result = - evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, 0).split; EXPECT_EQ(result.findex, 0); EXPECT_EQ(result.fvalue, 1.0); @@ -243,8 +237,8 @@ TEST(GpuHist, EvaluateSplits) { GPUHistEvaluator evaluator{ tparam, static_cast(feature_min_values.size()), 0}; - evaluator.EvaluateSplits(input_left, input_right, ObjInfo{ObjInfo::kRegression}, - evaluator.GetEvaluator(), dh::ToSpan(out_splits)); + evaluator.EvaluateSplits(input_left, input_right, evaluator.GetEvaluator(), + dh::ToSpan(out_splits)); DeviceSplitCandidate result_left = out_splits[0]; EXPECT_EQ(result_left.findex, 1); @@ -264,8 +258,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) { cuts_.cut_values_.SetDevice(0); cuts_.min_vals_.SetDevice(0); - ObjInfo task{ObjInfo::kRegression}; - evaluator.Reset(cuts_, dh::ToSpan(ft), task, info_.num_col_, param_, 0); + evaluator.Reset(cuts_, dh::ToSpan(ft), info_.num_col_, param_, 0); dh::device_vector d_hist(hist_[0].size()); auto node_hist = hist_[0]; @@ -282,7 +275,7 @@ TEST_F(TestPartitionBasedSplit, GpuHist) { cuts_.cut_values_.ConstDeviceSpan(), cuts_.min_vals_.ConstDeviceSpan(), dh::ToSpan(d_hist)}; - auto split = evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + auto split = evaluator.EvaluateSingleSplit(input, 0).split; ASSERT_NEAR(split.loss_chg, best_score_, 1e-16); } } // namespace tree diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index e46726e7401c..8de84b2a1076 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -24,8 +24,8 @@ template void TestEvaluateSplits() { auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix(); - auto evaluator = HistEvaluator{ - param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}}; + auto evaluator = + HistEvaluator{param, dmat->Info(), n_threads, sampler}; common::HistCollection hist; std::vector row_gpairs = { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f}, @@ -97,8 +97,7 @@ TEST(HistEvaluator, Apply) { param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0.0"}}); auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix(); auto sampler = std::make_shared(); - auto evaluator_ = HistEvaluator{param, dmat->Info(), 4, sampler, - ObjInfo{ObjInfo::kRegression}}; + auto evaluator_ = HistEvaluator{param, dmat->Info(), 4, sampler}; CPUExpandEntry entry{0, 0, 10.0f}; entry.split.left_sum = GradStats{0.4, 0.6f}; @@ -125,7 +124,7 @@ TEST_F(TestPartitionBasedSplit, CPUHist) { std::vector ft{FeatureType::kCategorical}; auto sampler = std::make_shared(); HistEvaluator evaluator{param_, info_, common::OmpGetNumThreads(0), - sampler, ObjInfo{ObjInfo::kRegression}}; + sampler}; evaluator.InitRoot(GradStats{total_gpair_}); RegTree tree; std::vector entries(1); @@ -156,8 +155,8 @@ auto CompareOneHotAndPartition(bool onehot) { int32_t n_threads = 16; auto sampler = std::make_shared(); - auto evaluator = HistEvaluator{ - param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}}; + auto evaluator = + HistEvaluator{param, dmat->Info(), n_threads, sampler}; std::vector entries(1); for (auto const &gmat : dmat->GetBatches({32, param.sparse_threshold})) { diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu index 3c93c283917a..409a800379f1 100644 --- a/tests/cpp/tree/test_gpu_hist.cu +++ b/tests/cpp/tree/test_gpu_hist.cu @@ -264,7 +264,7 @@ TEST(GpuHist, EvaluateRootSplit) { info.num_col_ = kNCols; DeviceSplitCandidate res = - maker.EvaluateRootSplit({6.4f, 12.8f}, 0, ObjInfo{ObjInfo::kRegression}).split; + maker.EvaluateRootSplit({6.4f, 12.8f}, 0).split; ASSERT_EQ(res.findex, 7); ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps); @@ -302,11 +302,11 @@ void TestHistogramIndexImpl() { const auto &maker = hist_maker.maker; auto grad = GenerateRandomGradients(kNRows); grad.SetDevice(0); - maker->Reset(&grad, hist_maker_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression}); + maker->Reset(&grad, hist_maker_dmat.get(), kNCols); std::vector h_gidx_buffer(maker->page->gidx_buffer.HostVector()); const auto &maker_ext = hist_maker_ext.maker; - maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression}); + maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols); std::vector h_gidx_buffer_ext(maker_ext->page->gidx_buffer.HostVector()); ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins()); From 8a03ca804ef91bcba35c21c7ce841e8846d873b4 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Mon, 2 May 2022 21:47:08 +0800 Subject: [PATCH 3/4] Remove includes. --- src/common/categorical.h | 1 - src/tree/hist/evaluate_splits.h | 1 - 2 files changed, 2 deletions(-) diff --git a/src/common/categorical.h b/src/common/categorical.h index 341a887f48a9..1e3f5c960971 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -12,7 +12,6 @@ #include "xgboost/data.h" #include "xgboost/parameter.h" #include "xgboost/span.h" -#include "xgboost/task.h" namespace xgboost { namespace common { diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index e53e39eefc16..c7acdd383515 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -11,7 +11,6 @@ #include #include -#include "xgboost/task.h" #include "../param.h" #include "../constraints.h" #include "../split_evaluator.h" From ba40f5cf3e07c497d14f029af99687e540fcc072 Mon Sep 17 00:00:00 2001 From: jiamingy Date: Mon, 2 May 2022 22:04:06 +0800 Subject: [PATCH 4/4] lint. --- src/tree/gpu_hist/evaluate_splits.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 26a571f25a9c..144af3201ad4 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -320,7 +320,7 @@ void GPUHistEvaluator::EvaluateSplits( // One block for each feature uint32_t constexpr kBlockThreads = 256; - dh::LaunchKernel{static_cast(combined_num_features), kBlockThreads, 0}( + dh::LaunchKernel {static_cast(combined_num_features), kBlockThreads, 0}( EvaluateSplitsKernel, left, right, this->SortedIdx(left), evaluator, dh::ToSpan(feature_best_splits));