From 210c131ce71e8cabe5c5f9e7018f9ad2b1de5ad3 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 21 Sep 2020 13:53:06 +0800 Subject: [PATCH] Support categorical data in GPU sketching. (#6137) --- src/common/hist_util.cu | 68 ++++++++++++++++++++++++---- src/common/quantile.cu | 65 ++++++++++++++++---------- src/common/quantile.cuh | 23 ++++++++-- src/data/iterative_device_dmatrix.cu | 8 +++- tests/cpp/common/test_hist_util.cu | 43 ++++++++++++++++-- tests/cpp/common/test_quantile.cu | 51 +++++++++++++-------- 6 files changed, 196 insertions(+), 62 deletions(-) diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index f44f416a1628..b60c2a5d52eb 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -24,6 +24,7 @@ #include "hist_util.cuh" #include "math.h" // NOLINT #include "quantile.h" +#include "categorical.h" #include "xgboost/host_device_vector.h" @@ -121,11 +122,59 @@ void SortByWeight(dh::XGBCachingDeviceAllocator* alloc, return a.index == b.index; }); } + +struct IsCatOp { + XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; } +}; + +void RemoveDuplicatedCategories( + int32_t device, MetaInfo const &info, Span d_cuts_ptr, + dh::device_vector *p_sorted_entries, + dh::caching_device_vector const &column_sizes_scan) { + auto d_feature_types = info.feature_types.ConstDeviceSpan(); + if (!info.feature_types.Empty() && + thrust::any_of(dh::tbegin(d_feature_types), dh::tend(d_feature_types), + IsCatOp{})) { + auto& sorted_entries = *p_sorted_entries; + // Removing duplicated entries in categorical features. + dh::caching_device_vector new_column_scan(column_sizes_scan.size()); + dh::SegmentedUnique(column_sizes_scan.data().get(), + column_sizes_scan.data().get() + + column_sizes_scan.size(), + sorted_entries.begin(), sorted_entries.end(), + new_column_scan.data().get(), sorted_entries.begin(), + [=] __device__(Entry const &l, Entry const &r) { + if (l.index == r.index) { + if (IsCat(d_feature_types, l.index)) { + return l.fvalue == r.fvalue; + } + } + return false; + }); + + // Renew the column scan and cut scan based on categorical data. + dh::caching_device_vector new_cuts_size( + info.num_col_ + 1); + auto d_new_cuts_size = dh::ToSpan(new_cuts_size); + auto d_new_columns_ptr = dh::ToSpan(new_column_scan); + CHECK_EQ(new_column_scan.size(), new_cuts_size.size()); + dh::LaunchN(device, new_column_scan.size() - 1, [=] __device__(size_t idx) { + if (IsCat(d_feature_types, idx)) { + d_new_cuts_size[idx] = + d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx]; + } else { + d_new_cuts_size[idx] = d_cuts_ptr[idx] - d_cuts_ptr[idx]; + } + }); + thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), + new_cuts_size.cend(), d_cuts_ptr.data()); + } +} } // namespace detail -void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, - SketchContainer *sketch_container, int num_cuts_per_feature, - size_t num_columns) { +void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page, + size_t begin, size_t end, SketchContainer *sketch_container, + int num_cuts_per_feature, size_t num_columns) { dh::XGBCachingDeviceAllocator alloc; const auto& host_data = page.data.ConstHostVector(); dh::device_vector sorted_entries(host_data.begin() + begin, @@ -145,9 +194,10 @@ void ProcessBatch(int device, const SparsePage &page, size_t begin, size_t end, batch_it, dummy_is_valid, 0, sorted_entries.size(), &cuts_ptr, &column_sizes_scan); - + auto d_cuts_ptr = cuts_ptr.DeviceSpan(); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, &sorted_entries, + column_sizes_scan); auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); - auto d_cuts_ptr = cuts_ptr.ConstDeviceSpan(); CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); // add cuts into sketches @@ -221,6 +271,8 @@ void ProcessWeightedBatch(int device, const SparsePage& page, HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements) { + dmat->Info().feature_types.SetDevice(device); + dmat->Info().feature_types.ConstDevicePointer(); // pull to device early // Configure batch size based on available memory bool has_weights = dmat->Info().weights_.Size() > 0; size_t num_cuts_per_feature = @@ -233,7 +285,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, device, num_cuts_per_feature, has_weights); HistogramCuts cuts; - SketchContainer sketch_container(max_bins, dmat->Info().num_col_, + SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_, dmat->Info().num_row_, device); dmat->Info().weights_.SetDevice(device); @@ -253,8 +305,8 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, dmat->Info().num_col_, is_ranking, dh::ToSpan(groups)); } else { - ProcessBatch(device, batch, begin, end, &sketch_container, num_cuts_per_feature, - dmat->Info().num_col_); + ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container, + num_cuts_per_feature, dmat->Info().num_col_); } } } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 3b4d846ac2a7..42dd8837a29f 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -15,6 +15,7 @@ #include "quantile.cuh" #include "hist_util.h" #include "device_helpers.cuh" +#include "categorical.h" #include "common.h" namespace xgboost { @@ -57,6 +58,7 @@ void PruneImpl(int device, common::Span cuts_ptr, Span sorted_data, Span columns_ptr_in, // could be ptr for data or cuts + Span feature_types, Span out_cuts, ToSketchEntry to_sketch_entry) { dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { @@ -71,7 +73,8 @@ void PruneImpl(int device, auto front = to_sketch_entry(0ul, in_column, column_id); auto back = to_sketch_entry(in_column.size() - 1, in_column, column_id); - if (in_column.size() <= to) { + auto is_cat = IsCat(feature_types, column_id); + if (in_column.size() <= to || is_cat) { // cut idx equals sample idx out_column[idx] = to_sketch_entry(idx, in_column, column_id); return; @@ -316,7 +319,7 @@ void SketchContainer::Push(Span entries, Span columns_ptr, this->Current().resize(total_cuts); out = dh::ToSpan(this->Current()); } - + auto ft = this->feature_types_.ConstDeviceSpan(); if (weights.empty()) { auto to_sketch_entry = [] __device__(size_t sample_idx, Span const &column, @@ -325,7 +328,7 @@ void SketchContainer::Push(Span entries, Span columns_ptr, float rmax = sample_idx + 1; return SketchEntry{rmin, rmax, 1, column[sample_idx].fvalue}; }; // NOLINT - PruneImpl(device_, cuts_ptr, entries, columns_ptr, out, + PruneImpl(device_, cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry); } else { auto to_sketch_entry = [weights, columns_ptr] __device__( @@ -340,7 +343,7 @@ void SketchContainer::Push(Span entries, Span columns_ptr, wmin = wmin < 0 ? kRtEps : wmin; // GPU scan can generate floating error. return SketchEntry{rmin, rmax, wmin, column[sample_idx].fvalue}; }; // NOLINT - PruneImpl(device_, cuts_ptr, entries, columns_ptr, out, + PruneImpl(device_, cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry); } @@ -388,26 +391,31 @@ void SketchContainer::Prune(size_t to) { this->Unique(); OffsetT to_total = 0; - HostDeviceVector new_columns_ptr{to_total}; + auto& h_columns_ptr = columns_ptr_b_.HostVector(); + h_columns_ptr[0] = to_total; + auto const& h_feature_types = feature_types_.ConstHostSpan(); for (bst_feature_t i = 0; i < num_columns_; ++i) { size_t length = this->Column(i).size(); length = std::min(length, to); + if (IsCat(h_feature_types, i)) { + length = this->Column(i).size(); + } to_total += length; - new_columns_ptr.HostVector().emplace_back(to_total); + h_columns_ptr[i+1] = to_total; } - new_columns_ptr.SetDevice(device_); this->Other().resize(to_total); auto d_columns_ptr_in = this->columns_ptr_.ConstDeviceSpan(); - auto d_columns_ptr_out = new_columns_ptr.ConstDeviceSpan(); + auto d_columns_ptr_out = columns_ptr_b_.ConstDeviceSpan(); auto out = dh::ToSpan(this->Other()); auto in = dh::ToSpan(this->Current()); auto no_op = [] __device__(size_t sample_idx, Span const &entries, size_t) { return entries[sample_idx]; }; // NOLINT - PruneImpl(device_, d_columns_ptr_out, in, d_columns_ptr_in, out, - no_op); - this->columns_ptr_.HostVector() = new_columns_ptr.HostVector(); + auto ft = this->feature_types_.ConstDeviceSpan(); + PruneImpl(device_, d_columns_ptr_out, in, d_columns_ptr_in, ft, + out, no_op); + this->columns_ptr_.Copy(columns_ptr_b_); this->Alternate(); timer_.Stop(__func__); } @@ -433,15 +441,11 @@ void SketchContainer::Merge(Span d_that_columns_ptr, this->Other().resize(this->Current().size() + that.size()); CHECK_EQ(d_that_columns_ptr.size(), this->columns_ptr_.Size()); - HostDeviceVector new_columns_ptr; - new_columns_ptr.SetDevice(device_); - new_columns_ptr.Resize(this->ColumnsPtr().size()); MergeImpl(device_, this->Data(), this->ColumnsPtr(), that, d_that_columns_ptr, - dh::ToSpan(this->Other()), new_columns_ptr.DeviceSpan()); - this->columns_ptr_ = std::move(new_columns_ptr); + dh::ToSpan(this->Other()), columns_ptr_b_.DeviceSpan()); + this->columns_ptr_.Copy(columns_ptr_b_); CHECK_EQ(this->columns_ptr_.Size(), num_columns_ + 1); - CHECK_EQ(new_columns_ptr.Size(), 0); this->Alternate(); timer_.Stop(__func__); } @@ -528,7 +532,8 @@ void SketchContainer::AllReduce() { } // Merge them into a new sketch. - SketchContainer new_sketch(num_bins_, this->num_columns_, global_sum_rows, + SketchContainer new_sketch(this->feature_types_, num_bins_, + this->num_columns_, global_sum_rows, this->device_); for (size_t i = 0; i < allworkers.size(); ++i) { auto worker = allworkers[i]; @@ -568,11 +573,16 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { auto& h_out_columns_ptr = p_cuts->cut_ptrs_.HostVector(); h_out_columns_ptr.clear(); h_out_columns_ptr.push_back(0); + auto const& h_feature_types = this->feature_types_.ConstHostSpan(); for (bst_feature_t i = 0; i < num_columns_; ++i) { - h_out_columns_ptr.push_back( - std::min(static_cast(std::max(static_cast(1ul), - this->Column(i).size())), - static_cast(num_bins_))); + size_t column_size = std::max(static_cast(1ul), + this->Column(i).size()); + if (IsCat(h_feature_types, i)) { + h_out_columns_ptr.push_back(static_cast(column_size)); + } else { + h_out_columns_ptr.push_back(std::min(static_cast(column_size), + static_cast(num_bins_))); + } } std::partial_sum(h_out_columns_ptr.begin(), h_out_columns_ptr.end(), h_out_columns_ptr.begin()); @@ -583,6 +593,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { p_cuts->cut_values_.SetDevice(device_); p_cuts->cut_values_.Resize(total_bins); auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); + auto d_ft = feature_types_.ConstDeviceSpan(); dh::LaunchN(0, total_bins, [=] __device__(size_t idx) { auto column_id = dh::SegmentId(d_out_columns_ptr, idx); @@ -605,11 +616,17 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { return; } - // First thread is responsible for setting min values. - if (idx == 0) { + if (idx == 0 && !IsCat(d_ft, column_id)) { auto mval = in_column[idx].value; d_min_values[column_id] = mval - (fabs(mval) + 1e-5); } + + if (IsCat(d_ft, column_id)) { + assert(out_column.size() == in_column.size()); + out_column[idx] = in_column[idx].value; + return; + } + // Last thread is responsible for setting a value that's greater than other cuts. if (idx == out_column.size() - 1) { const bst_float cpt = in_column.back().value; diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 00cc193293e2..81ba92de7364 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -4,6 +4,7 @@ #include #include "xgboost/span.h" +#include "xgboost/data.h" #include "device_helpers.cuh" #include "quantile.h" #include "timer.h" @@ -28,6 +29,7 @@ class SketchContainer { private: Monitor timer_; std::unique_ptr reducer_; + HostDeviceVector feature_types_; bst_row_t num_rows_; bst_feature_t num_columns_; int32_t num_bins_; @@ -39,6 +41,7 @@ class SketchContainer { bool current_buffer_ {true}; // The container is just a CSC matrix. HostDeviceVector columns_ptr_; + HostDeviceVector columns_ptr_b_; dh::caching_device_vector& Current() { if (current_buffer_) { @@ -80,12 +83,25 @@ class SketchContainer { * \param num_rows Total number of rows in known dataset (typically the rows in current worker). * \param device GPU ID. */ - SketchContainer(int32_t max_bin, bst_feature_t num_columns, bst_row_t num_rows, int32_t device) : - num_rows_{num_rows}, num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { + SketchContainer(HostDeviceVector const& feature_types, + int32_t max_bin, + bst_feature_t num_columns, bst_row_t num_rows, + int32_t device) + : num_rows_{num_rows}, + num_columns_{num_columns}, num_bins_{max_bin}, device_{device} { + CHECK_GE(device, 0); // Initialize Sketches for this dmatrix this->columns_ptr_.SetDevice(device_); this->columns_ptr_.Resize(num_columns + 1); - CHECK_GE(device, 0); + this->columns_ptr_b_.SetDevice(device_); + this->columns_ptr_b_.Resize(num_columns + 1); + + this->feature_types_.Resize(feature_types.Size()); + this->feature_types_.Copy(feature_types); + // Pull to device. + this->feature_types_.SetDevice(device); + this->feature_types_.ConstDeviceSpan(); + this->feature_types_.ConstHostSpan(); timer_.Init(__func__); } /* \brief Return GPU ID for this container. */ @@ -127,6 +143,7 @@ class SketchContainer { Span Data() const { return {this->Current().data().get(), this->Current().size()}; } + HostDeviceVector const& FeatureTypes() const { return feature_types_; } Span ColumnsPtr() const { return this->columns_ptr_.ConstDeviceSpan(); } diff --git a/src/data/iterative_device_dmatrix.cu b/src/data/iterative_device_dmatrix.cu index b99f99590bc0..eb3d34443659 100644 --- a/src/data/iterative_device_dmatrix.cu +++ b/src/data/iterative_device_dmatrix.cu @@ -79,7 +79,8 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin } else { CHECK_EQ(cols, num_cols()) << "Inconsistent number of columns."; } - sketch_containers.emplace_back(batch_param_.max_bin, cols, num_rows(), get_device()); + sketch_containers.emplace_back(proxy->Info().feature_types, + batch_param_.max_bin, cols, num_rows(), get_device()); auto* p_sketch = &sketch_containers.back(); proxy->Info().weights_.SetDevice(get_device()); Dispatch(proxy, [&](auto const &value) { @@ -101,7 +102,10 @@ void IterativeDeviceDMatrix::Initialize(DataIterHandle iter_handle, float missin } iter.Reset(); dh::safe_cuda(cudaSetDevice(get_device())); - common::SketchContainer final_sketch(batch_param_.max_bin, cols, accumulated_rows, get_device()); + HostDeviceVector ft; + common::SketchContainer final_sketch( + sketch_containers.empty() ? ft : sketch_containers.front().FeatureTypes(), + batch_param_.max_bin, cols, accumulated_rows, get_device()); for (auto const& sketch : sketch_containers) { final_sketch.Merge(sketch.ColumnsPtr(), sketch.Data()); final_sketch.FixError(); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index b225acb2039d..5b35c537dc71 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -108,7 +108,7 @@ TEST(HistUtil, DeviceSketchDeterminism) { } } -TEST(HistUtil, DeviceSketchCategorical) { +TEST(HistUtil, DeviceSketchCategoricalAsNumeric) { int categorical_sizes[] = {2, 6, 8, 12}; int num_bins = 256; int sizes[] = {25, 100, 1000}; @@ -122,6 +122,33 @@ TEST(HistUtil, DeviceSketchCategorical) { } } +void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins) { + auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); + auto dmat = GetDMatrixFromData(x, n, 1); + dmat->Info().feature_types.HostVector().push_back(FeatureType::kCategorical); + ASSERT_EQ(dmat->Info().feature_types.Size(), 1); + auto cuts = DeviceSketch(0, dmat.get(), num_bins); + std::sort(x.begin(), x.end()); + auto n_uniques = std::unique(x.begin(), x.end()) - x.begin(); + ASSERT_NE(n_uniques, x.size()); + ASSERT_EQ(cuts.TotalBins(), n_uniques); + ASSERT_EQ(n_uniques, num_categories); + + auto& values = cuts.cut_values_.HostVector(); + ASSERT_TRUE(std::is_sorted(values.cbegin(), values.cend())); + auto is_unique = (std::unique(values.begin(), values.end()) - values.begin()) == n_uniques; + ASSERT_TRUE(is_unique); + + x.resize(n_uniques); + for (size_t i = 0; i < n_uniques; ++i) { + ASSERT_EQ(x[i], values[i]); + } +} + +TEST(HistUtil, DeviceSketchCategoricalFeatures) { + TestCategoricalSketch(1000, 256, 32); +} + TEST(HistUtil, DeviceSketchMultipleColumns) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100, 1000, 1500}; @@ -237,7 +264,8 @@ TEST(HistUtil, DeviceSketchExternalMemoryWithWeights) { template auto MakeUnweightedCutsForTest(Adapter adapter, int32_t num_bins, float missing, size_t batch_size = 0) { common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, adapter.NumColumns(), adapter.NumRows(), 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, adapter.NumColumns(), adapter.NumRows(), 0); MetaInfo info; AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); @@ -305,7 +333,8 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); HistogramCuts cuts; @@ -332,10 +361,12 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { dh::GlobalMemoryLogger().Clear(); ConsoleLogger::Configure({{"verbosity", "3"}}); common::HistogramCuts batched_cuts; - SketchContainer sketch_container(num_bins, num_columns, num_rows, 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), &sketch_container); + HistogramCuts cuts; sketch_container.MakeCuts(&cuts); ConsoleLogger::Configure({{"verbosity", "0"}}); @@ -477,9 +508,11 @@ void TestAdapterSketchFromWeights(bool with_group) { data::CupyAdapter adapter(m); auto const& batch = adapter.Value(); - SketchContainer sketch_container(kBins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), &sketch_container); + common::HistogramCuts cuts; sketch_container.MakeCuts(&cuts); diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 5279135a920d..8bb426aa4d76 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -8,7 +8,8 @@ namespace xgboost { namespace common { TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; - SketchContainer sketch(kBins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch(ft, kBins, kCols, kRows, 0); dh::caching_device_vector entries; dh::device_vector cuts_ptr(kCols+1); thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); @@ -20,7 +21,8 @@ TEST(GPUQuantile, Basic) { void TestSketchUnique(float sparsity) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) { - SketchContainer sketch(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, sparsity} @@ -93,8 +95,10 @@ void TestQuantileElemRank(int32_t device, Span in, TEST(GPUQuantile, Prune) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { - SketchContainer sketch(n_bins, kCols, kRows, 0); + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, + MetaInfo const &info) { + HostDeviceVector ft; + SketchContainer sketch(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} @@ -111,8 +115,8 @@ TEST(GPUQuantile, Prune) { if (n_bins <= kRows) { ASSERT_EQ(sketch.Data().size(), n_bins * kCols); } else { - // LE because kRows * kCols is pushed into sketch, after removing duplicated entries - // we might not have that much inputs for prune. + // LE because kRows * kCols is pushed into sketch, after removing + // duplicated entries we might not have that much inputs for prune. ASSERT_LE(sketch.Data().size(), kRows * kCols); } // This is not necessarily true for all inputs without calling unique after @@ -127,7 +131,8 @@ TEST(GPUQuantile, Prune) { TEST(GPUQuantile, MergeEmpty) { constexpr size_t kRows = 1000, kCols = 100; size_t n_bins = 10; - SketchContainer sketch_0(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage_0; std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( @@ -166,7 +171,8 @@ TEST(GPUQuantile, MergeEmpty) { TEST(GPUQuantile, MergeBasic) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { - SketchContainer sketch_0(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage_0; std::string interface_str_0 = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -176,7 +182,7 @@ TEST(GPUQuantile, MergeBasic) { AdapterDeviceSketch(adapter_0.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &sketch_0); - SketchContainer sketch_1(n_bins, kCols, kRows * kRows, 0); + SketchContainer sketch_1(ft, n_bins, kCols, kRows * kRows, 0); HostDeviceVector storage_1; std::string interface_str_1 = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -212,7 +218,8 @@ TEST(GPUQuantile, MergeBasic) { void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { MetaInfo info; int32_t seed = 0; - SketchContainer sketch_0(n_bins, cols, rows, 0); + HostDeviceVector ft; + SketchContainer sketch_0(ft, n_bins, cols, rows, 0); HostDeviceVector storage_0; std::string interface_str_0 = RandomDataGenerator{rows, cols, 0} .Device(0) @@ -224,7 +231,7 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) { &sketch_0); size_t f_rows = rows * frac; - SketchContainer sketch_1(n_bins, cols, f_rows, 0); + SketchContainer sketch_1(ft, n_bins, cols, f_rows, 0); HostDeviceVector storage_1; std::string interface_str_1 = RandomDataGenerator{f_rows, cols, 0} .Device(0) @@ -286,12 +293,14 @@ TEST(GPUQuantile, AllReduceBasic) { } constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { - // Set up single node version; - SketchContainer sketch_on_single_node(n_bins, kCols, kRows, 0); + RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, + MetaInfo const &info) { + // Set up single node version + HostDeviceVector ft; + SketchContainer sketch_on_single_node(ft, n_bins, kCols, kRows, 0); - size_t intermediate_num_cuts = - std::min(kRows * world, static_cast(n_bins * WQSketch::kFactor)); + size_t intermediate_num_cuts = std::min( + kRows * world, static_cast(n_bins * WQSketch::kFactor)); std::vector containers; for (auto rank = 0; rank < world; ++rank) { HostDeviceVector storage; @@ -300,12 +309,13 @@ TEST(GPUQuantile, AllReduceBasic) { .Seed(rank + seed) .GenerateArrayInterface(&storage); data::CupyAdapter adapter(interface_str); - containers.emplace_back(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + containers.emplace_back(ft, n_bins, kCols, kRows, 0); AdapterDeviceSketch(adapter.Value(), n_bins, info, std::numeric_limits::quiet_NaN(), &containers.back()); } - for (auto& sketch : containers) { + for (auto &sketch : containers) { sketch.Prune(intermediate_num_cuts); sketch_on_single_node.Merge(sketch.ColumnsPtr(), sketch.Data()); sketch_on_single_node.FixError(); @@ -317,7 +327,7 @@ TEST(GPUQuantile, AllReduceBasic) { // Set up distributed version. We rely on using rank as seed to generate // the exact same copy of data. auto rank = rabit::GetRank(); - SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); + SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} .Device(0) @@ -376,7 +386,8 @@ TEST(GPUQuantile, SameOnAllWorkers) { RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { auto rank = rabit::GetRank(); - SketchContainer sketch_distributed(n_bins, kCols, kRows, 0); + HostDeviceVector ft; + SketchContainer sketch_distributed(ft, n_bins, kCols, kRows, 0); HostDeviceVector storage; std::string interface_str = RandomDataGenerator{kRows, kCols, 0} .Device(0)