Skip to content

Commit

Permalink
Add apply split test.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 25, 2020
1 parent cca9ec0 commit 8c01055
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 5 deletions.
7 changes: 2 additions & 5 deletions src/tree/updater_gpu_hist.cu
Expand Up @@ -193,7 +193,7 @@ struct GPUHistMakerDevice {
std::unique_ptr<GradientBasedSampler> sampler;

std::unique_ptr<FeatureGroups> feature_groups;
// Storing split categories for 1 node.
// Storing split categories for last node.
dh::caching_device_vector<uint32_t> node_categories;

GPUHistMakerDevice(int _device_id,
Expand Down Expand Up @@ -605,10 +605,7 @@ struct GPUHistMakerDevice {
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(std::max(cat+1, 1)));
LBitField32 cats_bits(split_cats);
cats_bits.Set(cat);
node_categories.resize(split_cats.size());
dh::safe_cuda(cudaMemcpyAsync(
node_categories.data().get(), split_cats.data(),
split_cats.size() * sizeof(uint32_t), cudaMemcpyHostToDevice));
dh::CopyToD(split_cats, &node_categories);
tree.ExpandCategorical(
candidate.nid, candidate.split.findex, split_cats,
candidate.split.dir == kLeftDir, base_weight, left_weight,
Expand Down
41 changes: 41 additions & 0 deletions tests/cpp/tree/test_gpu_hist.cu
Expand Up @@ -130,6 +130,47 @@ TEST(GpuHist, BuildHistSharedMem) {
TestBuildHist<GradientPair>(true);
}

TEST(GpuHist, ApplySplit) {
RegTree tree;
ExpandEntry candidate;
candidate.left_weight = 1.0f;
candidate.right_weight = 2.0f;
candidate.base_weight = 3.0f;
candidate.split.is_cat = true;
candidate.split.fvalue = 1.0f; // at cat 1

size_t n_rows = 10;
size_t n_cols = 10;

auto m = RandomDataGenerator{n_rows, n_cols, 0}.GenerateDMatrix(true);
GenericParameter p;
p.InitAllowUnknown(Args{});

TrainParam tparam;
tparam.InitAllowUnknown(Args{});
BatchParam bparam;
bparam.gpu_id = 0;
bparam.max_bin = 3;
bparam.gpu_page_size = 0;

for (auto& ellpack : m->GetBatches<EllpackPage>(bparam)){
auto impl = ellpack.Impl();
HostDeviceVector<FeatureType> feature_types(10, FeatureType::kCategorical);
feature_types.SetDevice(bparam.gpu_id);
tree::GPUHistMakerDevice<GradientPairPrecise> updater(
0, impl, feature_types.ConstDeviceSpan(), n_rows, tparam, 0, n_cols, true, bparam);
updater.ApplySplit(candidate, &tree);

ASSERT_EQ(tree.GetSplitTypes().size(), 3);
ASSERT_EQ(tree.GetSplitTypes()[0], FeatureType::kCategorical);
ASSERT_EQ(tree.GetSplitCategories().size(), 1);
uint32_t bits = 1u << 30; // bits: 0, 1, 0, 0, 0, ..., 0
ASSERT_EQ(tree.GetSplitCategories().back(), bits);

ASSERT_EQ(updater.node_categories.size(), 1);
}
}

HistogramCutsWrapper GetHostCutMatrix () {
HistogramCutsWrapper cmat;
cmat.SetPtrs({0, 3, 6, 9, 12, 15, 18, 21, 24});
Expand Down

0 comments on commit 8c01055

Please sign in to comment.