Summary of this patch (text before the first "diff --git" line is not part of
any hunk):

  * src/common/quantile.cu: removes the out-of-line SketchContainer::Unique()
    and replaces the hand-written dh::SegmentedUnique call in
    SketchContainer::Merge() with this->Unique(...), passing the same
    comparator lambda (equal feature index and IsCat(d_feature_types, idx)).
  * src/common/quantile.cuh: removes the Unique() declaration and adds an
    in-class template Unique(KeyComp key_comp) whose key_comp is forwarded as
    the last argument of dh::SegmentedUnique and defaults to thrust::equal_to.
    It copies scan_out into columns_ptr_, resizes Current() to the number of
    unique entries and returns that number.
  * The removed Unique() called dh::SegmentedUnique without an execution
    policy; the new template always passes thrust::cuda::par(alloc).

NOTE(review): this patch had been flattened onto a few lines and had lost its
template argument lists (e.g. "template >", "Span entries"). Line breaks are
restored so that every hunk matches its @@ line counts. Span<SketchEntry> and
Span<OffsetT> are confirmed by the 80-column truncation of the first hunk
heading; XGBCachingDeviceAllocator<char>, Span<OffsetT const> in the second
hunk heading, and the exact whitespace are assumptions - confirm against the
upstream tree before applying.

diff --git a/src/common/quantile.cu b/src/common/quantile.cu
index a686be9c11e4..607f8fa9fce0 100644
--- a/src/common/quantile.cu
+++ b/src/common/quantile.cu
@@ -409,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
   return n_uniques;
 }
 
-size_t SketchContainer::Unique() {
-  timer_.Start(__func__);
-  dh::safe_cuda(cudaSetDevice(device_));
-  this->columns_ptr_.SetDevice(device_);
-  Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
-  CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
-  Span<SketchEntry> entries = dh::ToSpan(this->Current());
-  HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
-  scan_out.SetDevice(device_);
-  auto d_scan_out = scan_out.DeviceSpan();
-
-  d_column_scan = this->columns_ptr_.DeviceSpan();
-  size_t n_uniques = dh::SegmentedUnique(
-      d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
-      entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
-      entries.data(),
-      detail::SketchUnique{});
-  this->columns_ptr_.Copy(scan_out);
-  CHECK(!this->columns_ptr_.HostCanRead());
-
-  this->Current().resize(n_uniques);
-  timer_.Stop(__func__);
-  return n_uniques;
-}
-
 void SketchContainer::Prune(size_t to) {
   timer_.Start(__func__);
   dh::safe_cuda(cudaSetDevice(device_));
@@ -500,27 +475,10 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
   this->Alternate();
 
   if (this->HasCategorical()) {
-    Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
-    CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
-    Span<SketchEntry> entries = dh::ToSpan(this->Current());
-    HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
-    scan_out.SetDevice(device_);
-    auto d_scan_out = scan_out.DeviceSpan();
-
-    d_column_scan = this->columns_ptr_.DeviceSpan();
     auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
-    dh::XGBCachingDeviceAllocator<char> alloc;
-    size_t n_uniques = dh::SegmentedUnique(
-        thrust::cuda::par(alloc),
-        dh::tbegin(d_column_scan), dh::tend(d_column_scan), dh::tbegin(entries),
-        dh::tend(entries), scan_out.DevicePointer(), dh::tbegin(entries),
-        detail::SketchUnique{}, [d_feature_types]__device__(size_t l_fidx, size_t r_fidx) {
-          return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
-        });
-    this->columns_ptr_.Copy(scan_out);
-    CHECK(!this->columns_ptr_.HostCanRead());
-
-    this->Current().resize(n_uniques);
+    this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
+      return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
+    });
   }
   timer_.Stop(__func__);
 }
diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh
index 540bddef906d..eab7290bf733 100644
--- a/src/common/quantile.cuh
+++ b/src/common/quantile.cuh
@@ -132,8 +132,6 @@ class SketchContainer {
   bool HasCategorical() const { return has_categorical_; }
   /* \brief Accumulate weights of duplicated entries in input. */
   size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
-  /* \brief Removes all the duplicated elements in quantile structure. */
-  size_t Unique();
   /* Fix rounding error and re-establish invariance. The error is mostly generated by the
    * addition inside `RMinNext` and subtraction in `RMaxPrev`. */
   void FixError();
@@ -178,6 +176,34 @@ class SketchContainer {
 
   SketchContainer(const SketchContainer&) = delete;
   SketchContainer& operator=(const SketchContainer&) = delete;
+
+  /* \brief Removes all the duplicated elements in quantile structure. */
+  template <typename KeyComp = thrust::equal_to<size_t>>
+  size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
+    timer_.Start(__func__);
+    dh::safe_cuda(cudaSetDevice(device_));
+    this->columns_ptr_.SetDevice(device_);
+    Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
+    CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
+    Span<SketchEntry> entries = dh::ToSpan(this->Current());
+    HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
+    scan_out.SetDevice(device_);
+    auto d_scan_out = scan_out.DeviceSpan();
+    dh::XGBCachingDeviceAllocator<char> alloc;
+
+    d_column_scan = this->columns_ptr_.DeviceSpan();
+    size_t n_uniques = dh::SegmentedUnique(
+        thrust::cuda::par(alloc), d_column_scan.data(),
+        d_column_scan.data() + d_column_scan.size(), entries.data(),
+        entries.data() + entries.size(), scan_out.DevicePointer(),
+        entries.data(), detail::SketchUnique{}, key_comp);
+    this->columns_ptr_.Copy(scan_out);
+    CHECK(!this->columns_ptr_.HostCanRead());
+
+    this->Current().resize(n_uniques);
+    timer_.Stop(__func__);
+    return n_uniques;
+  }
 };
 }  // namespace common
 }  // namespace xgboost