Skip to content

Commit

Permalink
Remove duplicated code.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 16, 2021
1 parent 103a430 commit de27bd3
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 47 deletions.
48 changes: 3 additions & 45 deletions src/common/quantile.cu
Expand Up @@ -409,31 +409,6 @@ size_t SketchContainer::ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_col
return n_uniques;
}

size_t SketchContainer::Unique() {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->columns_ptr_.SetDevice(device_);
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();

d_column_scan = this->columns_ptr_.DeviceSpan();
size_t n_uniques = dh::SegmentedUnique(
d_column_scan.data(), d_column_scan.data() + d_column_scan.size(),
entries.data(), entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(),
detail::SketchUnique{});
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());

this->Current().resize(n_uniques);
timer_.Stop(__func__);
return n_uniques;
}

void SketchContainer::Prune(size_t to) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
Expand Down Expand Up @@ -500,27 +475,10 @@ void SketchContainer::Merge(Span<OffsetT const> d_that_columns_ptr,
this->Alternate();

if (this->HasCategorical()) {
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();

d_column_scan = this->columns_ptr_.DeviceSpan();
auto d_feature_types = this->FeatureTypes().ConstDeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;
size_t n_uniques = dh::SegmentedUnique(
thrust::cuda::par(alloc),
dh::tbegin(d_column_scan), dh::tend(d_column_scan), dh::tbegin(entries),
dh::tend(entries), scan_out.DevicePointer(), dh::tbegin(entries),
detail::SketchUnique{}, [d_feature_types]__device__(size_t l_fidx, size_t r_fidx) {
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
});
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());

this->Current().resize(n_uniques);
this->Unique([d_feature_types] __device__(size_t l_fidx, size_t r_fidx) {
return l_fidx == r_fidx && IsCat(d_feature_types, l_fidx);
});
}
timer_.Stop(__func__);
}
Expand Down
30 changes: 28 additions & 2 deletions src/common/quantile.cuh
Expand Up @@ -132,8 +132,6 @@ class SketchContainer {
bool HasCategorical() const { return has_categorical_; }
/* \brief Accumulate weights of duplicated entries in input. */
size_t ScanInput(Span<SketchEntry> entries, Span<OffsetT> d_columns_ptr_in);
/* \brief Removes all the duplicated elements in quantile structure. */
size_t Unique();
/* Fix rounding error and re-establish invariance. The error is mostly generated by the
* addition inside `RMinNext` and subtraction in `RMaxPrev`. */
void FixError();
Expand Down Expand Up @@ -178,6 +176,34 @@ class SketchContainer {

SketchContainer(const SketchContainer&) = delete;
SketchContainer& operator=(const SketchContainer&) = delete;

/* \brief Removes all the duplicated elements in quantile structure. */
template <typename KeyComp = thrust::equal_to<size_t>>
size_t Unique(KeyComp key_comp = thrust::equal_to<size_t>{}) {
timer_.Start(__func__);
dh::safe_cuda(cudaSetDevice(device_));
this->columns_ptr_.SetDevice(device_);
Span<OffsetT> d_column_scan = this->columns_ptr_.DeviceSpan();
CHECK_EQ(d_column_scan.size(), num_columns_ + 1);
Span<SketchEntry> entries = dh::ToSpan(this->Current());
HostDeviceVector<OffsetT> scan_out(d_column_scan.size());
scan_out.SetDevice(device_);
auto d_scan_out = scan_out.DeviceSpan();
dh::XGBCachingDeviceAllocator<char> alloc;

d_column_scan = this->columns_ptr_.DeviceSpan();
size_t n_uniques = dh::SegmentedUnique(
thrust::cuda::par(alloc), d_column_scan.data(),
d_column_scan.data() + d_column_scan.size(), entries.data(),
entries.data() + entries.size(), scan_out.DevicePointer(),
entries.data(), detail::SketchUnique{}, key_comp);
this->columns_ptr_.Copy(scan_out);
CHECK(!this->columns_ptr_.HostCanRead());

this->Current().resize(n_uniques);
timer_.Stop(__func__);
return n_uniques;
}
};
} // namespace common
} // namespace xgboost
Expand Down

0 comments on commit de27bd3

Please sign in to comment.