diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index e899925cd146..8d0efd41872e 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -193,7 +193,7 @@ struct GPUHistMakerDevice {
   std::unique_ptr sampler;
   std::unique_ptr feature_groups;
 
-  // Storing split categories for 1 node.
+  // Storing split categories for last node.
   dh::caching_device_vector node_categories;
 
   GPUHistMakerDevice(int _device_id,
@@ -605,10 +605,7 @@ struct GPUHistMakerDevice {
       std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)));
       LBitField32 cats_bits(split_cats);
       cats_bits.Set(cat);
-      node_categories.resize(split_cats.size());
-      dh::safe_cuda(cudaMemcpyAsync(
-          node_categories.data().get(), split_cats.data(),
-          split_cats.size() * sizeof(uint32_t), cudaMemcpyHostToDevice));
+      dh::CopyToD(split_cats, &node_categories);
       tree.ExpandCategorical(
           candidate.nid, candidate.split.findex, split_cats,
           candidate.split.dir == kLeftDir, base_weight, left_weight,
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 2c4b6be34dd6..4a96cb2877a3 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -130,6 +130,47 @@ TEST(GpuHist, BuildHistSharedMem) {
   TestBuildHist(true);
 }
 
+TEST(GpuHist, ApplySplit) {
+  RegTree tree;
+  ExpandEntry candidate;
+  candidate.left_weight = 1.0f;
+  candidate.right_weight = 2.0f;
+  candidate.base_weight = 3.0f;
+  candidate.split.is_cat = true;
+  candidate.split.fvalue = 1.0f;  // at cat 1
+
+  size_t n_rows = 10;
+  size_t n_cols = 10;
+
+  auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
+  GenericParameter p;
+  p.InitAllowUnknown(Args{});
+
+  TrainParam tparam;
+  tparam.InitAllowUnknown(Args{});
+  BatchParam bparam;
+  bparam.gpu_id = 0;
+  bparam.max_bin = 3;
+  bparam.gpu_page_size = 0;
+
+  for (auto& ellpack : m->GetBatches(bparam)){
+    auto impl = ellpack.Impl();
+    HostDeviceVector feature_types(10, FeatureType::kCategorical);
+    feature_types.SetDevice(bparam.gpu_id);
+    tree::GPUHistMakerDevice updater(
+        0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
+    updater.ApplySplit(candidate, &tree);
+
+    ASSERT_EQ(tree.GetSplitTypes().size(), 3);
+    ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
+    ASSERT_EQ(tree.GetSplitCategories().size(), 1);
+    uint32_t bits = 1u << 30;  // bits: 0, 1, 0, 0, 0, ..., 0
+    ASSERT_EQ(tree.GetSplitCategories().back(), bits);
+
+    ASSERT_EQ(updater.node_categories.size(), 1);
+  }
+}
+
 HistogramCutsWrapper GetHostCutMatrix () {
   HistogramCutsWrapper cmat;
   cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});