From 444131a2e6efd21825b10074e62d2b33735058e6 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Tue, 29 Sep 2020 11:27:25 +0800
Subject: [PATCH] Add categorical data support to GPU Hist. (#6164)

---
 src/common/device_helpers.cuh                 |  15 ++
 src/common/hist_util.cu                       |   2 +-
 src/tree/gpu_hist/evaluate_splits.cu          | 127 ++++++++++----
 src/tree/gpu_hist/evaluate_splits.cuh         |   1 +
 src/tree/updater_gpu_common.cuh               |   4 +
 src/tree/updater_gpu_hist.cu                  | 156 ++++++++++++------
 tests/cpp/common/test_quantile.cu             |   8 +-
 .../cpp/tree/gpu_hist/test_evaluate_splits.cu |  24 ++-
 tests/cpp/tree/test_gpu_hist.cu               |  72 ++++++--
 9 files changed, 306 insertions(+), 103 deletions(-)

diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 471ec31f42ac..57201deb490a 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -536,6 +536,21 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
                               cudaMemcpyDeviceToHost));
 }
 
+template <typename HContainer, typename DContainer>
+void CopyToD(HContainer const &h, DContainer *d) {
+  if (h.empty()) {
+    d->clear();
+    return;
+  }
+  d->resize(h.size());
+  using HVT = std::remove_cv_t<typename HContainer::value_type>;
+  using DVT = std::remove_cv_t<typename DContainer::value_type>;
+  static_assert(std::is_same<HVT, DVT>::value,
+                "Host and device containers must have same value type.");
+  dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
+                                cudaMemcpyHostToDevice));
+}
+
 // Keep track of pinned memory allocation
 struct PinnedMemory {
   void *temp_storage{nullptr};
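Reviewer note (not part of the patch): a minimal usage sketch of the new dh::CopyToD helper, assuming a CUDA translation unit that already includes device_helpers.cuh. The function and variable names below are illustrative only.

    // Hypothetical host-side usage of dh::CopyToD (illustrative only).
    #include <cstdint>
    #include <vector>

    void CopyBitsExample() {
      std::vector<uint32_t> h_bits{0x40000000u};   // host-side category bitset
      dh::caching_device_vector<uint32_t> d_bits;  // device-side destination
      dh::CopyToD(h_bits, &d_bits);                // resize, then async host-to-device copy
    }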
diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu
index b60c2a5d52eb..376a64742abb 100644
--- a/src/common/hist_util.cu
+++ b/src/common/hist_util.cu
@@ -178,7 +178,7 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
   dh::XGBCachingDeviceAllocator<char> alloc;
   const auto& host_data = page.data.ConstHostVector();
   dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
-                                         host_data.begin() + end);
+                                          host_data.begin() + end);
   thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
                sorted_entries.end(), detail::EntryCompareOp());
diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu
index ef1c6946cf3b..7d7edcb55e36 100644
--- a/src/tree/gpu_hist/evaluate_splits.cu
+++ b/src/tree/gpu_hist/evaluate_splits.cu
@@ -1,8 +1,9 @@
 /*!
  * Copyright 2020 by XGBoost Contributors
  */
-#include "evaluate_splits.cuh"
 #include <limits>
+#include "evaluate_splits.cuh"
+#include "../../common/categorical.h"
 
 namespace xgboost {
 namespace tree {
@@ -66,13 +67,84 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
   if (threadIdx.x == 0) {
     shared_sum = local_sum;
   }
-  __syncthreads();
+  cub::CTA_SYNC();
   return shared_sum;
 }
 
+template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
+  GradientSumT __device__ operator()(
+      bool thread_active, uint32_t scan_begin,
+      SumCallbackOp<GradientSumT>*,
+      GradientSumT const &missing,
+      EvaluateSplitInputs<GradientSumT> const &inputs, TempStorageT *) {
+    GradientSumT bin = thread_active
+                           ? inputs.gradient_histogram[scan_begin + threadIdx.x]
+                           : GradientSumT();
+    auto rest = inputs.parent_sum - bin - missing;
+    return rest;
+  }
+};
+
+template <typename GradientSumT>
+struct UpdateOneHot {
+  void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
+                             bst_feature_t fidx, GradientSumT const &missing,
+                             GradientSumT const &bin,
+                             EvaluateSplitInputs<GradientSumT> const &inputs,
+                             DeviceSplitCandidate *best_split) {
+    int split_gidx = (scan_begin + threadIdx.x);
+    float fvalue = inputs.feature_values[split_gidx];
+    GradientSumT left = missing_left ? bin + missing : bin;
+    GradientSumT right = inputs.parent_sum - left;
+    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx,
+                       GradientPair(left), GradientPair(right), true,
+                       inputs.param);
+  }
+};
+
+template <typename GradientSumT, typename ScanT, typename TempStorageT>
+struct NumericBin {
+  GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
+                                     SumCallbackOp<GradientSumT>* prefix_callback,
+                                     GradientSumT const &missing,
+                                     EvaluateSplitInputs<GradientSumT> inputs,
+                                     TempStorageT *temp_storage) {
+    GradientSumT bin = thread_active
+                           ? inputs.gradient_histogram[scan_begin + threadIdx.x]
+                           : GradientSumT();
+    ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback);
+    return bin;
+  }
+};
+
+template <typename GradientSumT>
+struct UpdateNumeric {
+  void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
+                             bst_feature_t fidx, GradientSumT const &missing,
+                             GradientSumT const &bin,
+                             EvaluateSplitInputs<GradientSumT> const &inputs,
+                             DeviceSplitCandidate *best_split) {
+    // Use pointer from cut to indicate begin and end of bins for each feature.
+    uint32_t gidx_begin = inputs.feature_segments[fidx];  // beginning bin
+    int split_gidx = (scan_begin + threadIdx.x) - 1;
+    float fvalue;
+    if (split_gidx < static_cast<int>(gidx_begin)) {
+      fvalue = inputs.min_fvalue[fidx];
+    } else {
+      fvalue = inputs.feature_values[split_gidx];
+    }
+    GradientSumT left = missing_left ? bin + missing : bin;
+    GradientSumT right = inputs.parent_sum - left;
+    best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
+                       fidx, GradientPair(left), GradientPair(right),
+                       false, inputs.param);
+  }
+};
+
 /*! \brief Find the thread with best gain. */
 template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
-          typename MaxReduceT, typename TempStorageT, typename GradientSumT>
+          typename MaxReduceT, typename TempStorageT, typename GradientSumT,
+          typename BinFn, typename UpdateFn>
 __device__ void EvaluateFeature(
     int fidx, EvaluateSplitInputs<GradientSumT> inputs,
     TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
@@ -83,12 +155,14 @@ __device__ void EvaluateFeature(
   uint32_t gidx_begin = inputs.feature_segments[fidx];  // beginning bin
   uint32_t gidx_end = inputs.feature_segments[fidx + 1];  // end bin for i^th feature
 
+  auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
+  auto bin_fn = BinFn();
+  auto update_fn = UpdateFn();
   // Sum histogram bins for current feature
   GradientSumT const feature_sum = ReduceFeature<BLOCK_THREADS, ReduceT>(
-      inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin),
-      temp_storage);
+      feature_hist, temp_storage);
 
   GradientSumT const missing = inputs.parent_sum - feature_sum;
   float const null_gain = -std::numeric_limits<bst_float>::infinity();
 
@@ -97,12 +171,7 @@ __device__ void EvaluateFeature(
   for (int scan_begin = gidx_begin; scan_begin < gidx_end;
        scan_begin += BLOCK_THREADS) {
     bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
-
-    // Gradient value for current bin.
-    GradientSumT bin = thread_active
-                           ? inputs.gradient_histogram[scan_begin + threadIdx.x]
-                           : GradientSumT();
-    ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
+    auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage);
 
     // Whether the gradient of missing values is put to the left side.
     bool missing_left = true;
@@ -127,24 +196,14 @@ __device__ void EvaluateFeature(
       block_max = best;
     }
 
-    __syncthreads();
+    cub::CTA_SYNC();
 
     // Best thread updates split
    if (threadIdx.x == block_max.key) {
-      int split_gidx = (scan_begin + threadIdx.x) - 1;
-      float fvalue;
-      if (split_gidx < static_cast<int>(gidx_begin)) {
-        fvalue = inputs.min_fvalue[fidx];
-      } else {
-        fvalue = inputs.feature_values[split_gidx];
-      }
-      GradientSumT left = missing_left ? bin + missing : bin;
-      GradientSumT right = inputs.parent_sum - left;
-      best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
-                         fidx, GradientPair(left), GradientPair(right),
-                         inputs.param);
+      update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs,
+                best_split);
     }
-    __syncthreads();
+    cub::CTA_SYNC();
   }
 }
 
@@ -186,11 +245,21 @@ __global__ void EvaluateSplitsKernel(
   // One block for each feature. Features are sampled, so fidx != blockIdx.x
   int fidx = inputs.feature_set[is_left ? blockIdx.x
                                         : blockIdx.x - left.feature_set.size()];
 
+  if (common::IsCat(inputs.feature_types, fidx)) {
+    EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
+                    TempStorage, GradientSumT,
+                    OneHotBin<GradientSumT, TempStorage>,
+                    UpdateOneHot<GradientSumT>>(fidx, inputs, evaluator,
+                                                &best_split, &temp_storage);
+  } else {
+    EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
+                    TempStorage, GradientSumT,
+                    NumericBin<GradientSumT, BlockScanT, TempStorage>,
+                    UpdateNumeric<GradientSumT>>(fidx, inputs, evaluator,
+                                                 &best_split, &temp_storage);
+  }
 
-  EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage>(
-      fidx, inputs, evaluator, &best_split, &temp_storage);
-
-  __syncthreads();
+  cub::CTA_SYNC();
 
   if (threadIdx.x == 0) {
     // Record best loss for each feature
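Reviewer note (not part of the patch): EvaluateFeature now delegates per-bin work to the BinFn/UpdateFn policies above. For a categorical feature, each bin k is tried as a one-hot split "category k goes right, everything else goes left". A host-side sketch of the arithmetic OneHotBin and UpdateOneHot perform, using a simplified gradient type (all names below are illustrative):

    // Simplified host-side model of the one-hot split sums (illustrative only).
    #include <cstddef>
    #include <vector>

    struct GradSum {
      double grad{0}, hess{0};
    };
    GradSum operator+(GradSum a, GradSum b) { return {a.grad + b.grad, a.hess + b.hess}; }
    GradSum operator-(GradSum a, GradSum b) { return {a.grad - b.grad, a.hess - b.hess}; }

    // For candidate category k: OneHotBin returns everything except bin k and the
    // missing values; UpdateOneHot assigns missing to one side and derives the rest.
    void OneHotSplitSums(std::vector<GradSum> const& hist, GradSum parent_sum,
                         GradSum missing, size_t k, bool missing_left,
                         GradSum* left, GradSum* right) {
      GradSum rest = parent_sum - hist[k] - missing;  // OneHotBin's return value
      *left = missing_left ? rest + missing : rest;   // UpdateOneHot
      *right = parent_sum - *left;                    // bin k (plus missing if right)
    }

Note that, unlike the numeric path, no exclusive scan is needed: each category is evaluated against the rest independently.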
diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh
index f847518dba68..e30901134cda 100644
--- a/src/tree/gpu_hist/evaluate_splits.cuh
+++ b/src/tree/gpu_hist/evaluate_splits.cuh
@@ -18,6 +18,7 @@ struct EvaluateSplitInputs {
   GradientSumT parent_sum;
   GPUTrainingParam param;
   common::Span<const bst_feature_t> feature_set;
+  common::Span<FeatureType const> feature_types;
   common::Span<const uint32_t> feature_segments;
   common::Span<const float> feature_values;
   common::Span<const float> min_fvalue;
diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh
index e48e0d4f2cfc..4219a3399e40 100644
--- a/src/tree/updater_gpu_common.cuh
+++ b/src/tree/updater_gpu_common.cuh
@@ -59,6 +59,7 @@ struct DeviceSplitCandidate {
   DefaultDirection dir {kLeftDir};
   int findex {-1};
   float fvalue {0};
+  bool is_cat { false };
 
   GradientPair left_sum;
   GradientPair right_sum;
@@ -79,6 +80,7 @@ struct DeviceSplitCandidate {
                               float fvalue_in, int findex_in,
                               GradientPair left_sum_in,
                               GradientPair right_sum_in,
+                              bool cat,
                               const GPUTrainingParam& param) {
     if (loss_chg_in > loss_chg &&
         left_sum_in.GetHess() >= param.min_child_weight &&
@@ -86,6 +88,7 @@ struct DeviceSplitCandidate {
       loss_chg = loss_chg_in;
      dir = dir_in;
       fvalue = fvalue_in;
+      is_cat = cat;
       left_sum = left_sum_in;
       right_sum = right_sum_in;
       findex = findex_in;
@@ -98,6 +101,7 @@ struct DeviceSplitCandidate {
        << "dir: " << c.dir << ", "
        << "findex: " << c.findex << ", "
        << "fvalue: " << c.fvalue << ", "
+       << "is_cat: " << c.is_cat << ", "
        << "left sum: " << c.left_sum << ", "
        << "right sum: " << c.right_sum << std::endl;
     return os;
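Reviewer note (not part of the patch): the new bool slots in just before the param argument of DeviceSplitCandidate::Update, so callers now look roughly like the following sketch (assumes the definitions above are in scope; all values are made up):

    // Illustrative call shape for the extended Update (not real tuning values).
    void UpdateExample(GPUTrainingParam const& param) {
      DeviceSplitCandidate best;
      GradientPair left_sum{1.0f, 2.0f}, right_sum{-1.0f, 3.0f};  // made-up sums
      best.Update(/*loss_chg_in=*/0.5f, kRightDir,
                  /*fvalue_in=*/3.0f,  // for a one-hot split, this is the category
                  /*findex_in=*/1, left_sum, right_sum,
                  /*cat=*/true, param);
    }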
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 19698695544e..394dfd5d5c1a 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -19,7 +19,9 @@
 #include "../common/io.h"
 #include "../common/device_helpers.cuh"
 #include "../common/hist_util.h"
+#include "../common/bitfield.h"
 #include "../common/timer.h"
+#include "../common/categorical.h"
 #include "../data/ellpack_page.cuh"
 #include "param.h"
 
@@ -161,6 +163,7 @@ template <typename GradientSumT>
 struct GPUHistMakerDevice {
   int device_id;
   EllpackPageImpl* page;
+  common::Span<FeatureType const> feature_types;
   BatchParam batch_param;
 
   std::unique_ptr<RowPartitioner> row_partitioner;
@@ -169,7 +172,6 @@ struct GPUHistMakerDevice {
   common::Span<GradientPair> gpair;
 
   dh::caching_device_vector<int> monotone_constraints;
-  dh::caching_device_vector<bst_float> prediction_cache;
 
   /*! \brief Sum gradient for each node. */
   std::vector<GradientPair> node_sum_gradients;
@@ -191,9 +193,12 @@ struct GPUHistMakerDevice {
 
   std::unique_ptr<GradientBasedSampler> sampler;
 
   std::unique_ptr<FeatureGroups> feature_groups;
+  // Storing split categories for last node.
+  dh::caching_device_vector<uint32_t> node_categories;
 
   GPUHistMakerDevice(int _device_id,
                      EllpackPageImpl* _page,
+                     common::Span<FeatureType const> _feature_types,
                      bst_uint _n_rows,
                      TrainParam _param,
                      uint32_t column_sampler_seed,
@@ -202,6 +207,7 @@ struct GPUHistMakerDevice {
                      BatchParam _batch_param)
       : device_id(_device_id),
         page(_page),
+        feature_types{_feature_types},
         param(std::move(_param)),
         tree_evaluator(param, n_features, _device_id),
         column_sampler(column_sampler_seed),
@@ -293,6 +299,7 @@ struct GPUHistMakerDevice {
         {root_sum.GetGrad(), root_sum.GetHess()},
         gpu_param,
         feature_set,
+        feature_types,
         matrix.feature_segments,
         matrix.gidx_fvalue_map,
         matrix.min_fvalue,
@@ -331,6 +338,7 @@ struct GPUHistMakerDevice {
          candidate.split.left_sum.GetHess()},
         gpu_param,
         left_feature_set,
+        feature_types,
         matrix.feature_segments,
         matrix.gidx_fvalue_map,
         matrix.min_fvalue,
@@ -341,6 +349,7 @@ struct GPUHistMakerDevice {
          candidate.split.right_sum.GetHess()},
         gpu_param,
         right_feature_set,
+        feature_types,
         matrix.feature_segments,
         matrix.gidx_fvalue_map,
         matrix.min_fvalue,
@@ -399,8 +408,11 @@ struct GPUHistMakerDevice {
            hist.HistogramExists(nidx_parent);
   }
 
-  void UpdatePosition(int nidx, RegTree::Node split_node) {
+  void UpdatePosition(int nidx, RegTree* p_tree) {
+    RegTree::Node split_node = (*p_tree)[nidx];
+    auto split_type = p_tree->NodeSplitType(nidx);
     auto d_matrix = page->GetDeviceAccessor(device_id);
+    auto node_cats = dh::ToSpan(node_categories);
 
     row_partitioner->UpdatePosition(
         nidx, split_node.LeftChild(), split_node.RightChild(),
@@ -409,11 +421,17 @@ struct GPUHistMakerDevice {
           bst_float cut_value =
              d_matrix.GetFvalue(ridx, split_node.SplitIndex());
           // Missing value
-          int new_position = 0;
+          bst_node_t new_position = 0;
           if (isnan(cut_value)) {
             new_position = split_node.DefaultChild();
           } else {
-            if (cut_value <= split_node.SplitCond()) {
+            bool go_left = true;
+            if (split_type == FeatureType::kCategorical) {
+              go_left = common::Decision(node_cats, common::AsCat(cut_value));
+            } else {
+              go_left = cut_value <= split_node.SplitCond();
+            }
+            if (go_left) {
               new_position = split_node.LeftChild();
             } else {
               new_position = split_node.RightChild();
@@ -428,59 +446,84 @@ struct GPUHistMakerDevice {
   // prediction cache
   void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat) {
     dh::TemporaryArray<RegTree::Node> d_nodes(p_tree->GetNodes().size());
-    dh::safe_cuda(cudaMemcpy(d_nodes.data().get(), p_tree->GetNodes().data(),
-                             d_nodes.size() * sizeof(RegTree::Node),
-                             cudaMemcpyHostToDevice));
+    dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(),
+                                  d_nodes.size() * sizeof(RegTree::Node),
+                                  cudaMemcpyHostToDevice));
+    auto const& h_split_types = p_tree->GetSplitTypes();
+    auto const& categories = p_tree->GetSplitCategories();
+    auto const& categories_segments = p_tree->GetSplitCategoriesPtr();
+
+    dh::caching_device_vector<FeatureType> d_split_types;
+    dh::caching_device_vector<uint32_t> d_categories;
+    dh::caching_device_vector<RegTree::Segment> d_categories_segments;
+
+    if (!categories.empty()) {
+      dh::CopyToD(h_split_types, &d_split_types);
+      dh::CopyToD(categories, &d_categories);
+      dh::CopyToD(categories_segments, &d_categories_segments);
+    }
 
     if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) {
       row_partitioner.reset();  // Release the device memory first before reallocating
       row_partitioner.reset(new RowPartitioner(device_id, p_fmat->Info().num_row_));
     }
     if (page->n_rows == p_fmat->Info().num_row_) {
-      FinalisePositionInPage(page, dh::ToSpan(d_nodes));
+      FinalisePositionInPage(page, dh::ToSpan(d_nodes),
+                             dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
+                             dh::ToSpan(d_categories_segments));
     } else {
       for (auto& batch : p_fmat->GetBatches<EllpackPage>(batch_param)) {
-        FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes));
+        FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes),
+                               dh::ToSpan(d_split_types), dh::ToSpan(d_categories),
+                               dh::ToSpan(d_categories_segments));
       }
     }
   }
 
-  void FinalisePositionInPage(EllpackPageImpl* page, const common::Span<RegTree::Node> d_nodes) {
+  void FinalisePositionInPage(EllpackPageImpl *page,
+                              const common::Span<RegTree::Node> d_nodes,
+                              common::Span<FeatureType const> d_feature_types,
+                              common::Span<uint32_t const> categories,
+                              common::Span<RegTree::Segment const> categories_segments) {
     auto d_matrix = page->GetDeviceAccessor(device_id);
     row_partitioner->FinalisePosition(
         [=] __device__(size_t row_id, int position) {
-          if (!d_matrix.IsInRange(row_id)) {
-            return RowPartitioner::kIgnoredTreePosition;
-          }
-          auto node = d_nodes[position];
+          // What happens if the user prunes the tree?
+          if (!d_matrix.IsInRange(row_id)) {
+            return RowPartitioner::kIgnoredTreePosition;
+          }
+          auto node = d_nodes[position];
 
-          while (!node.IsLeaf()) {
-            bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
-            // Missing value
-            if (isnan(element)) {
-              position = node.DefaultChild();
-            } else {
-              if (element <= node.SplitCond()) {
-                position = node.LeftChild();
-              } else {
-                position = node.RightChild();
+          while (!node.IsLeaf()) {
+            bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex());
+            // Missing value
+            if (isnan(element)) {
+              position = node.DefaultChild();
+            } else {
+              bool go_left = true;
+              if (common::IsCat(d_feature_types, position)) {
+                auto node_cats =
+                    categories.subspan(categories_segments[position].beg,
+                                       categories_segments[position].size);
+                go_left = common::Decision(node_cats, common::AsCat(element));
+              } else {
+                go_left = element <= node.SplitCond();
+              }
+              if (go_left) {
+                position = node.LeftChild();
+              } else {
+                position = node.RightChild();
+              }
+            }
+            node = d_nodes[position];
           }
-            }
-            node = d_nodes[position];
-          }
-          return position;
-        });
+          return position;
+        });
   }
 
-  void UpdatePredictionCache(bst_float* out_preds_d) {
+  void UpdatePredictionCache(common::Span<bst_float> out_preds_d) {
     dh::safe_cuda(cudaSetDevice(device_id));
     auto d_ridx = row_partitioner->GetRows();
-    if (prediction_cache.size() != d_ridx.size()) {
-      prediction_cache.resize(d_ridx.size());
-      dh::safe_cuda(cudaMemcpyAsync(prediction_cache.data().get(), out_preds_d,
-                                    prediction_cache.size() * sizeof(bst_float),
-                                    cudaMemcpyDefault));
-    }
 
     GPUTrainingParam param_d(param);
     dh::TemporaryArray<GradientPair> device_node_sum_gradients(node_sum_gradients.size());
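Reviewer note (not part of the patch): both UpdatePosition and FinalisePositionInPage route rows through common::Decision. A sketch of the rule, under the assumptions that LBitField32 stores bit k MSB-first within each 32-bit word and that a set bit sends the category to the right child (this matches UpdateOneHot above and the ApplySplit test below); the helper name is hypothetical:

    // Sketch of the categorical routing decision (illustrative only).
    #include <cstdint>

    inline bool DecisionSketch(uint32_t const* bits, int32_t cat) {
      uint32_t word = bits[cat / 32];
      uint32_t mask = 1u << (31 - cat % 32);  // MSB-first layout assumed
      bool in_set = (word & mask) != 0;
      return !in_set;  // true => go left, i.e. matching categories go right
    }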
@@ -491,21 +534,16 @@ struct GPUHistMakerDevice {
                                   cudaMemcpyHostToDevice));
     auto d_position = row_partitioner->GetPosition();
     auto d_node_sum_gradients = device_node_sum_gradients.data().get();
-    auto d_prediction_cache = prediction_cache.data().get();
     auto evaluator = tree_evaluator.GetEvaluator();
 
     dh::LaunchN(
-        device_id, prediction_cache.size(), [=] __device__(int local_idx) {
+        device_id, out_preds_d.size(), [=] __device__(int local_idx) {
          int pos = d_position[local_idx];
           bst_float weight = evaluator.CalcWeight(
               pos, param_d, GradStats{d_node_sum_gradients[pos]});
-          d_prediction_cache[d_ridx[local_idx]] +=
+          out_preds_d[d_ridx[local_idx]] +=
               weight * param_d.learning_rate;
         });
-
-    dh::safe_cuda(cudaMemcpyAsync(
-        out_preds_d, prediction_cache.data().get(),
-        prediction_cache.size() * sizeof(bst_float), cudaMemcpyDefault));
     row_partitioner.reset();
   }
 
@@ -561,11 +599,27 @@ struct GPUHistMakerDevice {
     auto left_weight = candidate.left_weight * param.learning_rate;
     auto right_weight = candidate.right_weight * param.learning_rate;
 
-    tree.ExpandNode(candidate.nid, candidate.split.findex,
-                    candidate.split.fvalue, candidate.split.dir == kLeftDir,
-                    base_weight, left_weight, right_weight,
-                    candidate.split.loss_chg, parent_sum.GetHess(),
-                    candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
+    auto is_cat = candidate.split.is_cat;
+    if (is_cat) {
+      auto cat = common::AsCat(candidate.split.fvalue);
+      std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
+      LBitField32 cats_bits(split_cats);
+      cats_bits.Set(cat);
+      dh::CopyToD(split_cats, &node_categories);
+      tree.ExpandCategorical(
+          candidate.nid, candidate.split.findex, split_cats,
+          candidate.split.dir == kLeftDir, base_weight, left_weight,
+          right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.left_sum.GetHess(),
+          candidate.split.right_sum.GetHess());
+    } else {
+      tree.ExpandNode(candidate.nid, candidate.split.findex,
+                      candidate.split.fvalue, candidate.split.dir == kLeftDir,
+                      base_weight, left_weight, right_weight,
+                      candidate.split.loss_chg, parent_sum.GetHess(),
+                      candidate.split.left_sum.GetHess(),
+                      candidate.split.right_sum.GetHess());
+    }
 
     // Set up child constraints
     auto left_child = tree[candidate.nid].LeftChild();
@@ -664,7 +718,7 @@ struct GPUHistMakerDevice {
     if (ExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
                                   num_leaves)) {
       monitor.Start("UpdatePosition");
-      this->UpdatePosition(candidate.nid, (*p_tree)[candidate.nid]);
+      this->UpdatePosition(candidate.nid, p_tree);
       monitor.Stop("UpdatePosition");
 
       monitor.Start("BuildHist");
@@ -752,8 +806,10 @@ class GPUHistMakerSpecialised {
     };
     auto page = (*dmat->GetBatches<EllpackPage>(batch_param).begin()).Impl();
     dh::safe_cuda(cudaSetDevice(device_));
+    info_->feature_types.SetDevice(device_);
     maker.reset(new GPUHistMakerDevice<GradientSumT>(device_,
                                                      page,
+                                                     info_->feature_types.ConstDeviceSpan(),
                                                      info_->num_row_,
                                                      param_,
                                                      column_sampling_seed,
@@ -804,7 +860,7 @@ class GPUHistMakerSpecialised {
     }
     monitor_.Start("UpdatePredictionCache");
     p_out_preds->SetDevice(device_);
-    maker->UpdatePredictionCache(p_out_preds->DevicePointer());
+    maker->UpdatePredictionCache(p_out_preds->DeviceSpan());
     monitor_.Stop("UpdatePredictionCache");
     return true;
   }
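Reviewer note (not part of the patch): ApplySplit encodes a one-hot split as a bitset with exactly one bit set. Under the same MSB-first layout assumption as above, the storage size and bit position work out as in this hypothetical standalone sketch:

    // Sketch of the one-hot bitset ApplySplit builds (illustrative only).
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    std::vector<uint32_t> EncodeOneHot(int32_t cat) {
      // LBitField32::ComputeStorageSize(n) is ceil(n / 32) 32-bit words.
      size_t n_words = (static_cast<size_t>(std::max(cat + 1, 1)) + 31) / 32;
      std::vector<uint32_t> bits(n_words, 0);
      bits[cat / 32] |= 1u << (31 - cat % 32);  // e.g. cat == 1 -> 1u << 30
      return bits;
    }

This is consistent with the ApplySplit test below, which expects category 1 to be stored as 1u << 30.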
diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu
index 8bb426aa4d76..5704de70b7a1 100644
--- a/tests/cpp/common/test_quantile.cu
+++ b/tests/cpp/common/test_quantile.cu
@@ -95,8 +95,7 @@ void TestQuantileElemRank(int32_t device, Span<SketchEntry> in,
 
 TEST(GPUQuantile, Prune) {
   constexpr size_t kRows = 1000, kCols = 100;
-  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
-                                 MetaInfo const &info) {
+  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
     HostDeviceVector<FeatureType> ft;
     SketchContainer sketch(ft, n_bins, kCols, kRows, 0);
 
@@ -293,9 +292,8 @@ TEST(GPUQuantile, AllReduceBasic) {
   }
   constexpr size_t kRows = 1000, kCols = 100;
-  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins,
-                                 MetaInfo const &info) {
-    // Set up single node version
+  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
+    // Set up single node version;
     HostDeviceVector<FeatureType> ft;
     SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0);
diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
index 84b2d13c7fb9..d90d47b150a7 100644
--- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
+++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
@@ -15,7 +15,7 @@ auto ZeroParam() {
 }
 }  // anonymous namespace
 
-TEST(GpuHist, EvaluateSingleSplit) {
+void TestEvaluateSingleSplit(bool is_categorical) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(1);
   GradientPair parent_sum(0.0, 1.0);
   TrainParam tparam = ZeroParam();
@@ -33,11 +33,19 @@ TEST(GpuHist, EvaluateSingleSplit) {
   thrust::device_vector<GradientPair> feature_histogram =
       std::vector<GradientPair>{
          {-0.5, 0.5}, {0.5, 0.5}, {-1.0, 0.5}, {1.0, 0.5}};
+
   thrust::device_vector<int> monotonic_constraints(feature_set.size(), 0);
+  dh::device_vector<FeatureType> feature_types(feature_set.size(),
+                                               FeatureType::kCategorical);
+  common::Span<FeatureType> d_feature_types;
+  if (is_categorical) {
+    d_feature_types = dh::ToSpan(feature_types);
+  }
   EvaluateSplitInputs<GradientPair> input{1,
                          parent_sum,
                          param,
                          dh::ToSpan(feature_set),
+                         d_feature_types,
                          dh::ToSpan(feature_segments),
                         dh::ToSpan(feature_values),
                          dh::ToSpan(feature_min_values),
@@ -55,6 +63,14 @@ TEST(GpuHist, EvaluateSingleSplit) {
                             parent_sum.GetHess());
 }
 
+TEST(GpuHist, EvaluateSingleSplit) {
+  TestEvaluateSingleSplit(false);
+}
+
+TEST(GpuHist, EvaluateCategoricalSplit) {
+  TestEvaluateSingleSplit(true);
+}
+
 TEST(GpuHist, EvaluateSingleSplitMissing) {
   thrust::device_vector<DeviceSplitCandidate> out_splits(1);
   GradientPair parent_sum(1.0, 1.5);
@@ -74,6 +90,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
                          parent_sum,
                          param,
                          dh::ToSpan(feature_set),
+                         {},
                          dh::ToSpan(feature_segments),
                          dh::ToSpan(feature_values),
                          dh::ToSpan(feature_min_values),
@@ -134,6 +151,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
                          parent_sum,
                          param,
                          dh::ToSpan(feature_set),
+                         {},
                          dh::ToSpan(feature_segments),
                          dh::ToSpan(feature_values),
                          dh::ToSpan(feature_min_values),
@@ -174,6 +192,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
                          parent_sum,
                          param,
                          dh::ToSpan(feature_set),
+                         {},
                          dh::ToSpan(feature_segments),
                          dh::ToSpan(feature_values),
                          dh::ToSpan(feature_min_values),
@@ -215,6 +234,7 @@ TEST(GpuHist, EvaluateSplits) {
                           parent_sum,
                           param,
                           dh::ToSpan(feature_set),
+                          {},
                          dh::ToSpan(feature_segments),
                           dh::ToSpan(feature_values),
                           dh::ToSpan(feature_min_values),
@@ -224,6 +244,7 @@ TEST(GpuHist, EvaluateSplits) {
                           parent_sum,
                           param,
                           dh::ToSpan(feature_set),
+                          {},
                           dh::ToSpan(feature_segments),
                           dh::ToSpan(feature_values),
                           dh::ToSpan(feature_min_values),
@@ -241,6 +262,5 @@ TEST(GpuHist, EvaluateSplits) {
   EXPECT_EQ(result_right.findex, 0);
   EXPECT_EQ(result_right.fvalue, 1.0);
 }
-
 }  // namespace tree
 }  // namespace xgboost
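Reviewer note (not part of the patch): these tests rely on the convention that an empty feature_types span means "treat every feature as numeric", while a filled span flags individual features. A sketch of the predicate this implies, mirroring what common::IsCat is expected to do (the helper name is hypothetical; assumes the xgboost headers are in scope):

    // Empty span of feature types => no feature is categorical (illustrative only).
    bool IsCatSketch(common::Span<FeatureType const> ft, bst_feature_t fidx) {
      return !ft.empty() && ft[fidx] == FeatureType::kCategorical;
    }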
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index ec598c5fce56..37f738c24ae7 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -80,7 +80,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   param.Init(args);
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
-  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), kNRows, param, kNCols, kNCols,
+  GPUHistMakerDevice<GradientSumT> maker(0, page.get(), {}, kNRows, param, kNCols, kNCols,
                                          true, batch_param);
   xgboost::SimpleLCG gen;
   xgboost::SimpleRealUniformDistribution<bst_float> dist(0.0f, 1.0f);
@@ -130,6 +130,48 @@ TEST(GpuHist, BuildHistSharedMem) {
   TestBuildHist<GradientPairPrecise>(true);
 }
 
+TEST(GpuHist, ApplySplit) {
+  RegTree tree;
+  ExpandEntry candidate;
+  candidate.nid = 0;
+  candidate.left_weight = 1.0f;
+  candidate.right_weight = 2.0f;
+  candidate.base_weight = 3.0f;
+  candidate.split.is_cat = true;
+  candidate.split.fvalue = 1.0f;  // at cat 1
+
+  size_t n_rows = 10;
+  size_t n_cols = 10;
+
+  auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
+  GenericParameter p;
+  p.InitAllowUnknown(Args{});
+
+  TrainParam tparam;
+  tparam.InitAllowUnknown(Args{});
+  BatchParam bparam;
+  bparam.gpu_id = 0;
+  bparam.max_bin = 3;
+  bparam.gpu_page_size = 0;
+
+  for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)) {
+    auto impl = ellpack.Impl();
+    HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
+    feature_types.SetDevice(bparam.gpu_id);
+    tree::GPUHistMakerDevice<GradientPairPrecise> updater(
+        0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
+    updater.ApplySplit(candidate, &tree);
+
+    ASSERT_EQ(tree.GetSplitTypes().size(), 3);
+    ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
+    ASSERT_EQ(tree.GetSplitCategories().size(), 1);
+    uint32_t bits = 1u << 30;  // bits: 0, 1, 0, 0, 0, ..., 0
+    ASSERT_EQ(tree.GetSplitCategories().back(), bits);
+
+    ASSERT_EQ(updater.node_categories.size(), 1);
+  }
+}
+
 HistogramCutsWrapper GetHostCutMatrix () {
   HistogramCutsWrapper cmat;
   cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
@@ -154,19 +196,18 @@ TEST(GpuHist, EvaluateRootSplit) {
   TrainParam param;
 
-  std::vector<std::pair<std::string, std::string>> args {
-      {"max_depth", "1"},
-      {"max_leaves", "0"},
-
-      // Disable all other parameters.
-      {"colsample_bynode", "1"},
-      {"colsample_bylevel", "1"},
-      {"colsample_bytree", "1"},
-      {"min_child_weight", "0.01"},
-      {"reg_alpha", "0"},
-      {"reg_lambda", "0"},
-      {"max_delta_step", "0"}
-  };
+  std::vector<std::pair<std::string, std::string>> args{
+      {"max_depth", "1"},
+      {"max_leaves", "0"},
+
+      // Disable all other parameters.
+      {"colsample_bynode", "1"},
+      {"colsample_bylevel", "1"},
+      {"colsample_bytree", "1"},
+      {"min_child_weight", "0.01"},
+      {"reg_alpha", "0"},
+      {"reg_lambda", "0"},
+      {"max_delta_step", "0"}};
   param.Init(args);
   for (size_t i = 0; i < kNCols; ++i) {
     param.monotone_constraints.emplace_back(0);
@@ -178,7 +219,7 @@ TEST(GpuHist, EvaluateRootSplit) {
   auto page = BuildEllpackPage(kNRows, kNCols);
   BatchParam batch_param{};
   GPUHistMakerDevice<GradientPairPrecise>
-      maker(0, page.get(), kNRows, param, kNCols, kNCols, true, batch_param);
+      maker(0, page.get(), {}, kNRows, param, kNCols, kNCols, true, batch_param);
 
   // Initialize GPUHistMakerDevice::node_sum_gradients
   maker.node_sum_gradients = {};
@@ -257,7 +298,6 @@ void TestHistogramIndexImpl() {
 
   ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());
   ASSERT_EQ(maker->page->gidx_buffer.Size(), maker_ext->page->gidx_buffer.Size());
-
 }
 
 TEST(GpuHist, TestHistogramIndex) {