diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc index ff721cf129bd..8acf0959fe70 100644 --- a/tests/cpp/tree/hist/test_histogram.cc +++ b/tests/cpp/tree/hist/test_histogram.cc @@ -2,7 +2,11 @@ * Copyright 2018-2021 by Contributors */ #include + #include "../../helpers.h" +#include "../../categorical_helpers.h" + +#include "../../../../src/common/categorical.h" #include "../../../../src/tree/hist/histogram.h" #include "../../../../src/tree/updater_quantile_hist.h" @@ -311,9 +315,72 @@ TEST(CPUHistogram, BuildHist) { TestBuildHistogram(false); } +namespace { +/** Builds the histogram for the node created at RegTree::kRoot twice, using the same gradients and the same row set: once from a 1-column DMatrix whose feature type is set to kCategorical, and once from the OneHotEncodeFeature() output loaded as n_categories columns. Histogram()[0] of both builders is then passed to ValidateCategoricalHistogram(). */ void TestHistogramCategorical(size_t n_categories) { + size_t constexpr kRows = 340; + int32_t constexpr kBins = 256; + auto x = GenerateRandomCategoricalSingleColumn(kRows, n_categories); + auto cat_m = GetDMatrixFromData(x, kRows, 1); + cat_m->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); + BatchParam batch_param{0, static_cast(kBins)}; /* NOTE(review): batch_param is not referenced again in this function. */ + + RegTree tree; + CPUExpandEntry node(RegTree::kRoot, tree.GetDepth(0), 0.0f); + std::vector nodes_for_explicit_hist_build; + nodes_for_explicit_hist_build.push_back(node); + + auto gpair = GenerateRandomGradients(kRows, 0, 2); + + RowSetCollection row_set_collection; + row_set_collection.Clear(); + std::vector &row_indices = *row_set_collection.Data(); + row_indices.resize(kRows); + std::iota(row_indices.begin(), row_indices.end(), 0); /* row_indices now holds 0, 1, ..., kRows - 1 */ + row_set_collection.Init(); + + /** + * Generate hist with cat data. + */ + HistogramBuilder cat_hist; + for (auto const &gidx : cat_m->GetBatches( + {GenericParameter::kCpuId, kBins})) { + auto total_bins = gidx.cut.TotalBins(); + cat_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), 1, false); + cat_hist.BuildHist(0, gidx, &tree, row_set_collection, + nodes_for_explicit_hist_build, {}, gpair.HostVector()); + } + + /** + * Generate hist with one hot encoded data. 
+ */ + auto x_encoded = OneHotEncodeFeature(x, n_categories); + auto encode_m = GetDMatrixFromData(x_encoded, kRows, n_categories); + HistogramBuilder onehot_hist; + for (auto const &gidx : encode_m->GetBatches( + {GenericParameter::kCpuId, kBins})) { + auto total_bins = gidx.cut.TotalBins(); + onehot_hist.Reset(total_bins, {GenericParameter::kCpuId, kBins}, + omp_get_max_threads(), 1, false); + onehot_hist.BuildHist(0, gidx, &tree, row_set_collection, + nodes_for_explicit_hist_build, {}, + gpair.HostVector()); + } + + /* Index 0 is the id of the single node pushed into nodes_for_explicit_hist_build -- presumably; confirm against Histogram()'s indexing. */ auto cat = cat_hist.Histogram()[0]; + auto onehot = onehot_hist.Histogram()[0]; + ValidateCategoricalHistogram(n_categories, onehot, cat); +} +} // anonymous namespace + +/** Runs TestHistogramCategorical for n_categories = 2 through 7 inclusive. */ TEST(CPUHistogram, Categorical) { + for (size_t n_categories = 2; n_categories < 8; ++n_categories) { + TestHistogramCategorical(n_categories); + } +} + TEST(CPUHistogram, ExternalMemory) { size_t constexpr kEntries = 1 << 16; - int32_t constexpr kBins = 32; auto m = CreateSparsePageDMatrix(kEntries, "cache"); std::vector partition_size(1, 0); @@ -358,8 +425,8 @@ TEST(CPUHistogram, ExternalMemory) { size_t page_idx{0}; for (auto const &page : m->GetBatches( {GenericParameter::kCpuId, kBins, hess})) { - multi_build.BuildHist(page_idx, space, page, &tree, - rows_set.at(page_idx), nodes, {}, h_gpair); + multi_build.BuildHist(page_idx, space, page, &tree, rows_set.at(page_idx), nodes, {}, + h_gpair); ++page_idx; } ASSERT_EQ(page_idx, 2);