diff --git a/CMakeLists.txt b/CMakeLists.txt index 7953a10dd990..ede6c5b755af 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -149,6 +149,8 @@ if (USE_CUDA) set(GEN_CODE "") format_gencode_flags("${GPU_COMPUTE_VER}" GEN_CODE) add_subdirectory(${PROJECT_SOURCE_DIR}/gputreeshap) + + find_package(CUDAToolkit REQUIRED) endif (USE_CUDA) if (FORCE_COLORED_OUTPUT AND (CMAKE_GENERATOR STREQUAL "Ninja") AND diff --git a/cmake/Utils.cmake b/cmake/Utils.cmake index 57a45ca420e9..dc523d03a8a7 100644 --- a/cmake/Utils.cmake +++ b/cmake/Utils.cmake @@ -124,13 +124,6 @@ function(format_gencode_flags flags out) endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18") endfunction(format_gencode_flags flags) -macro(enable_nvtx target) - find_package(NVTX REQUIRED) - target_include_directories(${target} PRIVATE "${NVTX_INCLUDE_DIR}") - target_link_libraries(${target} PRIVATE "${NVTX_LIBRARY}") - target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_NVTX=1) -endmacro() - # Set CUDA related flags to target. Must be used after code `format_gencode_flags`. function(xgboost_set_cuda_flags target) target_compile_options(${target} PRIVATE @@ -162,11 +155,14 @@ function(xgboost_set_cuda_flags target) endif (USE_DEVICE_DEBUG) if (USE_NVTX) - enable_nvtx(${target}) + target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_NVTX=1) endif (USE_NVTX) target_compile_definitions(${target} PRIVATE -DXGBOOST_USE_CUDA=1) - target_include_directories(${target} PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap) + target_include_directories( + ${target} PRIVATE + ${xgboost_SOURCE_DIR}/gputreeshap + ${CUDAToolkit_INCLUDE_DIRS}) if (MSVC) target_compile_options(${target} PRIVATE @@ -289,7 +285,7 @@ macro(xgboost_target_link_libraries target) endif (USE_NCCL) if (USE_NVTX) - enable_nvtx(${target}) + target_link_libraries(${target} PRIVATE CUDA::nvToolsExt) endif (USE_NVTX) if (RABIT_BUILD_MPI) diff --git a/cmake/modules/FindNVTX.cmake b/cmake/modules/FindNVTX.cmake deleted file mode 100644 index 173e255c8951..000000000000 --- a/cmake/modules/FindNVTX.cmake +++ /dev/null @@ -1,26 +0,0 @@ -if (NVTX_LIBRARY) - unset(NVTX_LIBRARY CACHE) -endif (NVTX_LIBRARY) - -set(NVTX_LIB_NAME nvToolsExt) - - -find_path(NVTX_INCLUDE_DIR - NAMES nvToolsExt.h - PATHS ${CUDA_HOME}/include ${CUDA_INCLUDE} /usr/local/cuda/include) - - -find_library(NVTX_LIBRARY - NAMES nvToolsExt - PATHS ${CUDA_HOME}/lib64 /usr/local/cuda/lib64) - -message(STATUS "Using nvtx library: ${NVTX_LIBRARY}") - -include(FindPackageHandleStandardArgs) -find_package_handle_standard_args(NVTX DEFAULT_MSG - NVTX_INCLUDE_DIR NVTX_LIBRARY) - -mark_as_advanced( - NVTX_INCLUDE_DIR - NVTX_LIBRARY -) diff --git a/include/xgboost/span.h b/include/xgboost/span.h index 0b543b5372c2..aea7ee0ad10b 100644 --- a/include/xgboost/span.h +++ b/include/xgboost/span.h @@ -1,5 +1,5 @@ -/*! - * Copyright 2018 XGBoost contributors +/** + * Copyright 2018-2023, XGBoost contributors * \brief span class based on ISO++20 span * * About NOLINTs in this file: @@ -32,11 +32,12 @@ #include #include -#include // size_t -#include // numeric_limits +#include // size_t +#include #include +#include // numeric_limits #include -#include +#include // for move #if defined(__CUDACC__) #include @@ -668,6 +669,42 @@ XGBOOST_DEVICE auto as_writable_bytes(Span s) __span_noexcept -> // NOLIN Span::value> { return {reinterpret_cast(s.data()), s.size_bytes()}; } + +// A simple custom Span type that uses general iterator instead of pointer. 
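+// It can wrap transform or permutation iterators (for example, the sorted-entry iterators used by the sketching code below) where a contiguous Span is not available.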
+template +class IterSpan { + public: + using element_type = typename std::iterator_traits::value_type; // NOLINT + using index_type = std::size_t; // NOLINT + using iterator = It; // NOLINT + + private: + It it_; + index_type size_{0}; + + public: + IterSpan() = default; + XGBOOST_DEVICE IterSpan(It it, index_type size) : it_{std::move(it)}, size_{size} {} + XGBOOST_DEVICE explicit IterSpan(common::Span span) + : it_{span.data()}, size_{span.size()} {} + + XGBOOST_DEVICE index_type size() const { return size_; } // NOLINT + XGBOOST_DEVICE decltype(auto) operator[](index_type i) const { return it_[i]; } + XGBOOST_DEVICE decltype(auto) operator[](index_type i) { return it_[i]; } + XGBOOST_DEVICE bool empty() const { return size() == 0; } // NOLINT + XGBOOST_DEVICE It data() const { return it_; } // NOLINT + XGBOOST_DEVICE IterSpan subspan( // NOLINT + index_type _offset, index_type _count = dynamic_extent) const { + SPAN_CHECK((_count == dynamic_extent) ? (_offset <= size()) : (_offset + _count <= size())); + return {data() + _offset, _count == dynamic_extent ? size() - _offset : _count}; + } + XGBOOST_DEVICE constexpr iterator begin() const __span_noexcept { // NOLINT + return {this, 0}; + } + XGBOOST_DEVICE constexpr iterator end() const __span_noexcept { // NOLINT + return {this, size()}; + } +}; } // namespace common } // namespace xgboost diff --git a/python-package/xgboost/data.py b/python-package/xgboost/data.py index 2ebde84f0519..5e1a309e04da 100644 --- a/python-package/xgboost/data.py +++ b/python-package/xgboost/data.py @@ -882,7 +882,7 @@ def _transform_cupy_array(data: DataType) -> CupyT: if not hasattr(data, "__cuda_array_interface__") and hasattr(data, "__array__"): data = cupy.array(data, copy=False) - if data.dtype.hasobject or data.dtype in [cupy.float16, cupy.bool_]: + if data.dtype.hasobject or data.dtype in [cupy.bool_]: data = data.astype(cupy.float32, copy=False) return data diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 08ef98ea10ac..73ade6d37a49 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -11,6 +11,7 @@ #include #include #include +#include #include // for size_t #include @@ -26,270 +27,233 @@ #include "quantile.h" #include "xgboost/host_device_vector.h" -namespace xgboost { -namespace common { - +namespace xgboost::common { constexpr float SketchContainer::kFactor; - namespace detail { size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows) { double eps = 1.0 / (WQSketch::kFactor * max_bins); size_t dummy_nlevel; size_t num_cuts; - WQuantileSketch::LimitSizeLevel( - num_rows, eps, &dummy_nlevel, &num_cuts); + WQuantileSketch::LimitSizeLevel(num_rows, eps, &dummy_nlevel, &num_cuts); return std::min(num_cuts, num_rows); } -size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns, - size_t max_bins, size_t nnz) { +size_t RequiredSampleCuts(bst_row_t num_rows, bst_feature_t num_columns, size_t max_bins, + size_t nnz) { auto per_column = RequiredSampleCutsPerColumn(max_bins, num_rows); auto if_dense = num_columns * per_column; auto result = std::min(nnz, if_dense); return result; } -size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, - size_t num_bins, bool with_weights) { - size_t peak = 0; - // 0. Allocate cut pointer in quantile container by increasing: n_columns + 1 - size_t total = (num_columns + 1) * sizeof(SketchContainer::OffsetT); - // 1. 
Copy and sort: 2 * bytes_per_element * shape - total += BytesPerElement(with_weights) * num_rows * num_columns; - peak = std::max(peak, total); - // 2. Deallocate bytes_per_element * shape due to reusing memory in sort. - total -= BytesPerElement(with_weights) * num_rows * num_columns / 2; - // 3. Allocate colomn size scan by increasing: n_columns + 1 - total += (num_columns + 1) * sizeof(SketchContainer::OffsetT); - // 4. Allocate cut pointer by increasing: n_columns + 1 - total += (num_columns + 1) * sizeof(SketchContainer::OffsetT); - // 5. Allocate cuts: assuming rows is greater than bins: n_columns * limit_size - total += RequiredSampleCuts(num_rows, num_bins, num_bins, nnz) * sizeof(SketchEntry); - // 6. Deallocate copied entries by reducing: bytes_per_element * shape. - peak = std::max(peak, total); - total -= (BytesPerElement(with_weights) * num_rows * num_columns) / 2; - // 7. Deallocate column size scan. - peak = std::max(peak, total); - total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT); - // 8. Deallocate cut size scan. - total -= (num_columns + 1) * sizeof(SketchContainer::OffsetT); - // 9. Allocate final cut values, min values, cut ptrs: std::min(rows, bins + 1) * - // n_columns + n_columns + n_columns + 1 - total += std::min(num_rows, num_bins) * num_columns * sizeof(float); - total += num_columns * - sizeof(std::remove_reference_t().MinValues())>::value_type); - total += (num_columns + 1) * - sizeof(std::remove_reference_t().Ptrs())>::value_type); - peak = std::max(peak, total); +std::size_t ArgSortEntryMemUsage(std::size_t n) { + std::size_t bytes{0}; + cub_argsort::DeviceRadixSort::Argsort( + nullptr, bytes, Span::const_iterator{}, static_cast(nullptr), n); + return bytes; +} + +std::size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, + bst_bin_t num_bins, bool with_weights, bool d_can_read) { + std::size_t peak = 0; // peak memory consumption + std::size_t running = 0; // running memory consumption + + std::size_t n_entries = std::min(nnz, num_rows * num_columns); + + if (!d_can_read) { + // Pull data from host to device + running += sizeof(Entry) * n_entries; + if (with_weights) { + // Row offset. + running += num_rows + 1; + } + } + // Allocate sorted idx + running += sizeof(SortedIdxT) * n_entries; + // Extra memory used by sort. 
+ running += ArgSortEntryMemUsage(n_entries); + peak = std::max(peak, running); + // Deallocate memory used by sort + running -= ArgSortEntryMemUsage(std::min(nnz, num_rows * num_columns)); + if (with_weights) { + // temp weight + running += n_entries * sizeof(float); + } + peak = std::max(peak, running); + // Allocate cut pointer in quantile container by increasing: n_columns + 1 + running += (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // Allocate column size scan by increasing: n_columns + 1 + running += (num_columns + 1) * sizeof(SketchContainer::OffsetT); + // Allocate cuts: assuming rows is greater than bins: n_columns * limit_size + running += RequiredSampleCuts(num_rows, num_columns, num_bins, nnz) * sizeof(SketchEntry); + peak = std::max(peak, running); return peak; } -size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_row_t num_rows, bst_feature_t columns, - size_t nnz, int device, - size_t num_cuts, bool has_weight) { +std::size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows, + bst_feature_t columns, size_t nnz, int device, size_t num_cuts, + bool has_weight, bool d_can_read) { #if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 - // device available memory is not accurate when rmm is used. - return nnz; + // Device available memory is not accurate when rmm is used. + return std::min(nnz, static_cast( + std::numeric_limits::max())); #endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 if (sketch_batch_num_elements == 0) { - auto required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight); - // use up to 80% of available space + // Use up to 80% of available space auto avail = dh::AvailableMemory(device) * 0.8; - if (required_memory > avail) { - sketch_batch_num_elements = avail / BytesPerElement(has_weight); - } else { - sketch_batch_num_elements = std::min(num_rows * static_cast(columns), nnz); + nnz = std::min(num_rows * static_cast(columns), nnz); + std::size_t required_memory{0ul}; + + if (nnz <= 2) { + // short cut + return kMaxNumEntrySort; } - } - return sketch_batch_num_elements; -} -void SortByWeight(dh::device_vector* weights, - dh::device_vector* sorted_entries) { - // Sort both entries and wegihts. - dh::XGBDeviceAllocator alloc; - thrust::sort_by_key(thrust::cuda::par(alloc), sorted_entries->begin(), - sorted_entries->end(), weights->begin(), - detail::EntryCompareOp()); + do { + required_memory = RequiredMemory(num_rows, columns, nnz, num_cuts, has_weight, d_can_read); + if (required_memory > avail) { + LOG(WARNING) << "Insufficient memory, dividing the data into smaller batches."; + } + sketch_batch_num_elements = nnz; + if (required_memory > avail) { + nnz = nnz / 2; + } + } while (required_memory > avail && nnz >= 2); - // Scan weights - dh::XGBCachingDeviceAllocator caching; - thrust::inclusive_scan_by_key(thrust::cuda::par(caching), - sorted_entries->begin(), sorted_entries->end(), - weights->begin(), weights->begin(), - [=] __device__(const Entry& a, const Entry& b) { - return a.index == b.index; - }); -} + if (nnz <= 2) { + LOG(WARNING) << "Unable to finish sketching due to memory limit."; + // let it OOM.
+ return kMaxNumEntrySort; + } + } -void RemoveDuplicatedCategories( - int32_t device, MetaInfo const &info, Span d_cuts_ptr, - dh::device_vector *p_sorted_entries, - dh::caching_device_vector *p_column_sizes_scan) { - info.feature_types.SetDevice(device); - auto d_feature_types = info.feature_types.ConstDeviceSpan(); - CHECK(!d_feature_types.empty()); - auto &column_sizes_scan = *p_column_sizes_scan; - auto &sorted_entries = *p_sorted_entries; - // Removing duplicated entries in categorical features. - dh::caching_device_vector new_column_scan(column_sizes_scan.size()); - dh::SegmentedUnique(column_sizes_scan.data().get(), - column_sizes_scan.data().get() + column_sizes_scan.size(), - sorted_entries.begin(), sorted_entries.end(), - new_column_scan.data().get(), sorted_entries.begin(), - [=] __device__(Entry const &l, Entry const &r) { - if (l.index == r.index) { - if (IsCat(d_feature_types, l.index)) { - return l.fvalue == r.fvalue; - } - } - return false; - }); - - // Renew the column scan and cut scan based on categorical data. - auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan); - dh::caching_device_vector new_cuts_size( - info.num_col_ + 1); - CHECK_EQ(new_column_scan.size(), new_cuts_size.size()); - dh::LaunchN( - new_column_scan.size(), - [=, d_new_cuts_size = dh::ToSpan(new_cuts_size), - d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan), - d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) { - d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx]; - if (idx == d_new_columns_ptr.size() - 1) { - return; - } - if (IsCat(d_feature_types, idx)) { - // Cut size is the same as number of categories in input. - d_new_cuts_size[idx] = - d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx]; - } else { - d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx]; - } - }); - // Turn size into ptr. 
- thrust::exclusive_scan(thrust::device, new_cuts_size.cbegin(), - new_cuts_size.cend(), d_cuts_ptr.data()); + return std::min(sketch_batch_num_elements, kMaxNumEntrySort); } } // namespace detail -void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page, - size_t begin, size_t end, SketchContainer *sketch_container, - int num_cuts_per_feature, size_t num_columns) { - dh::XGBCachingDeviceAllocator alloc; - dh::device_vector sorted_entries; +void ProcessBatch(std::int32_t device, MetaInfo const& info, const SparsePage& page, + std::size_t begin, std::size_t end, SketchContainer* sketch_container, + bst_bin_t num_cuts_per_feature, std::size_t num_columns) { + std::size_t n = end - begin; + dh::device_vector tmp_entries; + Span entries_view; if (page.data.DeviceCanRead()) { - const auto& device_data = page.data.ConstDevicePointer(); - sorted_entries = dh::device_vector(device_data + begin, device_data + end); + entries_view = page.data.ConstDeviceSpan().subspan(begin, n); } else { const auto& host_data = page.data.ConstHostVector(); - sorted_entries = dh::device_vector(host_data.begin() + begin, - host_data.begin() + end); + tmp_entries = dh::device_vector(host_data.begin() + begin, host_data.begin() + end); + entries_view = dh::ToSpan(tmp_entries); } - thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), detail::EntryCompareOp()); + + dh::device_vector sorted_idx(n); + detail::ArgSortEntry(std::as_const(entries_view).data(), &sorted_idx); + auto d_sorted_idx = dh::ToSpan(sorted_idx); HostDeviceVector cuts_ptr; dh::caching_device_vector column_sizes_scan; data::IsValidFunctor dummy_is_valid(std::numeric_limits::quiet_NaN()); - auto batch_it = dh::MakeTransformIterator( - sorted_entries.data().get(), - [] __device__(Entry const &e) -> data::COOTuple { - return {0, e.index, e.fvalue}; // row_idx is not needed for scanning column size. + auto d_sorted_entry_it = + thrust::make_permutation_iterator(entries_view.data(), dh::tcbegin(d_sorted_idx)); + auto sorted_batch_it = dh::MakeTransformIterator( + d_sorted_entry_it, [=] __device__(Entry const& e) -> data::COOTuple { + return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size. 
}); + detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, - batch_it, dummy_is_valid, - 0, sorted_entries.size(), + IterSpan{sorted_batch_it, sorted_idx.size()}, dummy_is_valid, &cuts_ptr, &column_sizes_scan); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); if (sketch_container->HasCategorical()) { - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, - &sorted_entries, &column_sizes_scan); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, sorted_batch_it, &sorted_idx, + &column_sizes_scan); } auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); CHECK_EQ(d_cuts_ptr.size(), column_sizes_scan.size()); // add cuts into sketches - sketch_container->Push(dh::ToSpan(sorted_entries), dh::ToSpan(column_sizes_scan), - d_cuts_ptr, h_cuts_ptr.back()); - sorted_entries.clear(); - sorted_entries.shrink_to_fit(); - CHECK_EQ(sorted_entries.capacity(), 0); - CHECK_NE(cuts_ptr.Size(), 0); + sketch_container->Push(d_sorted_entry_it, dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back()); } -void ProcessWeightedBatch(int device, const SparsePage& page, - MetaInfo const& info, size_t begin, size_t end, - SketchContainer* sketch_container, int num_cuts_per_feature, - size_t num_columns, +void ProcessWeightedBatch(int device, MetaInfo const& info, const SparsePage& page, + std::size_t begin, std::size_t end, SketchContainer* sketch_container, + bst_bin_t num_cuts_per_feature, bst_feature_t num_columns, bool is_ranking, Span d_group_ptr) { auto weights = info.weights_.ConstDeviceSpan(); + std::size_t n = end - begin; - dh::XGBCachingDeviceAllocator alloc; - const auto& host_data = page.data.ConstHostVector(); - dh::device_vector sorted_entries(host_data.begin() + begin, - host_data.begin() + end); + dh::device_vector tmp_entries; + common::Span entries_view; + if (page.data.DeviceCanRead()) { + entries_view = page.data.ConstDeviceSpan().subspan(begin, n); + } else { + const auto& host_data = page.data.ConstHostVector(); + tmp_entries = dh::device_vector(host_data.begin() + begin, host_data.begin() + end); + entries_view = dh::ToSpan(tmp_entries); + } + dh::device_vector sorted_idx(n); // Binary search to assign weights to each element - dh::device_vector temp_weights(sorted_entries.size()); - auto d_temp_weights = temp_weights.data().get(); + dh::device_vector temp_weights(sorted_idx.size()); + auto d_temp_weights = dh::ToSpan(temp_weights); page.offset.SetDevice(device); auto row_ptrs = page.offset.ConstDeviceSpan(); size_t base_rowid = page.base_rowid; if (is_ranking) { - CHECK_GE(d_group_ptr.size(), 2) - << "Must have at least 1 group for ranking."; + CHECK_GE(d_group_ptr.size(), 2) << "Must have at least 1 group for ranking."; CHECK_EQ(weights.size(), d_group_ptr.size() - 1) << "Weight size should equal to number of groups."; dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { - size_t element_idx = idx + begin; - size_t ridx = dh::SegmentId(row_ptrs, element_idx); - bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid); - d_temp_weights[idx] = weights[group_idx]; - }); + std::size_t element_idx = idx + begin; + std::size_t ridx = dh::SegmentId(row_ptrs, element_idx); + bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid); + d_temp_weights[idx] = weights[group_idx]; + }); } else { dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { - size_t element_idx = idx + begin; - size_t ridx = dh::SegmentId(row_ptrs, element_idx); - d_temp_weights[idx] = weights[ridx + base_rowid]; - }); + std::size_t element_idx = idx + 
begin; + std::size_t ridx = dh::SegmentId(row_ptrs, element_idx); + d_temp_weights[idx] = weights[ridx + base_rowid]; + }); } - detail::SortByWeight(&temp_weights, &sorted_entries); + + detail::ArgSortEntry(std::as_const(entries_view).data(), &sorted_idx); + auto d_sorted_entry_it = + thrust::make_permutation_iterator(entries_view.data(), sorted_idx.cbegin()); + auto d_sorted_weight_it = + thrust::make_permutation_iterator(dh::tbegin(d_temp_weights), sorted_idx.cbegin()); + + dh::XGBCachingDeviceAllocator caching; + thrust::inclusive_scan_by_key( + thrust::cuda::par(caching), d_sorted_entry_it, d_sorted_entry_it + sorted_idx.size(), + d_sorted_weight_it, d_sorted_weight_it, + [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; }); HostDeviceVector cuts_ptr; dh::caching_device_vector column_sizes_scan; data::IsValidFunctor dummy_is_valid(std::numeric_limits::quiet_NaN()); - auto batch_it = dh::MakeTransformIterator( - sorted_entries.data().get(), - [] __device__(Entry const &e) -> data::COOTuple { + auto sorted_batch_it = dh::MakeTransformIterator( + d_sorted_entry_it, [=] __device__(Entry const& e) -> data::COOTuple { return {0, e.index, e.fvalue}; // row_idx is not needed for scaning column size. }); detail::GetColumnSizesScan(device, num_columns, num_cuts_per_feature, - batch_it, dummy_is_valid, - 0, sorted_entries.size(), + IterSpan{sorted_batch_it, sorted_idx.size()}, dummy_is_valid, &cuts_ptr, &column_sizes_scan); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); if (sketch_container->HasCategorical()) { - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, - &sorted_entries, &column_sizes_scan); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, sorted_batch_it, &sorted_idx, + &column_sizes_scan); } auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); // Extract cuts - sketch_container->Push(dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), d_cuts_ptr, - h_cuts_ptr.back(), dh::ToSpan(temp_weights)); - sorted_entries.clear(); - sorted_entries.shrink_to_fit(); + sketch_container->Push(d_sorted_entry_it, dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back(), IterSpan{d_sorted_weight_it, sorted_idx.size()}); } HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, @@ -300,12 +264,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, bool has_weights = dmat->Info().weights_.Size() > 0; size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(max_bins, dmat->Info().num_row_); - sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - dmat->Info().num_row_, - dmat->Info().num_col_, - dmat->Info().num_nonzero_, - device, num_cuts_per_feature, has_weights); HistogramCuts cuts; SketchContainer sketch_container(dmat->Info().feature_types, max_bins, dmat->Info().num_col_, @@ -315,18 +273,21 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, for (const auto& batch : dmat->GetBatches()) { size_t batch_nnz = batch.data.Size(); auto const& info = dmat->Info(); + + sketch_batch_num_elements = detail::SketchBatchNumElements( + sketch_batch_num_elements, dmat->Info().num_row_, dmat->Info().num_col_, + dmat->Info().num_nonzero_, device, num_cuts_per_feature, has_weights, + batch.data.DeviceCanRead()); + for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { size_t end = std::min(batch_nnz, static_cast(begin + sketch_batch_num_elements)); if (has_weights) { bool is_ranking = HostSketchContainer::UseGroup(dmat->Info()); 
dh::caching_device_vector groups(info.group_ptr_.cbegin(), info.group_ptr_.cend()); - ProcessWeightedBatch( - device, batch, dmat->Info(), begin, end, - &sketch_container, - num_cuts_per_feature, - dmat->Info().num_col_, - is_ranking, dh::ToSpan(groups)); + ProcessWeightedBatch(device, dmat->Info(), batch, begin, end, &sketch_container, + num_cuts_per_feature, dmat->Info().num_col_, is_ranking, + dh::ToSpan(groups)); } else { ProcessBatch(device, dmat->Info(), batch, begin, end, &sketch_container, num_cuts_per_feature, dmat->Info().num_col_); @@ -336,5 +297,4 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, sketch_container.MakeCuts(&cuts); return cuts; } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 856404107099..16e22a7a0b30 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -9,17 +9,18 @@ #include -#include // for size_t +#include // for size_t, byte +#include "../cub_sort/device/device_radix_sort.cuh" #include "../data/device_adapter.cuh" +#include "cuda_context.cuh" #include "device_helpers.cuh" #include "hist_util.h" #include "quantile.cuh" #include "timer.h" -namespace xgboost { -namespace common { -namespace cuda { +namespace xgboost::common { +namespace cuda_impl { /** * copy and paste of the host version, we can't make it a __host__ __device__ function as * the fn might be a host only or device only callable object, which is not allowed by nvcc. @@ -40,68 +41,177 @@ auto __device__ DispatchBinType(BinTypeSize type, Fn&& fn) { SPAN_CHECK(false); return fn(uint32_t{}); } -} // namespace cuda +} // namespace cuda_impl namespace detail { -struct EntryCompareOp { - __device__ bool operator()(const Entry& a, const Entry& b) { - if (a.index == b.index) { - return a.fvalue < b.fvalue; +// Get column size from adapter batch and for output cuts. +template +__global__ void GetColumnSizeSharedMemKernel(IterSpan batch_iter, + data::IsValidFunctor is_valid, + Span out_column_size) { + extern __shared__ char smem[]; + + auto smem_cs_ptr = reinterpret_cast(smem); + + dh::BlockFill(smem_cs_ptr, out_column_size.size(), 0); + + cub::CTA_SYNC(); + + auto n = batch_iter.size(); + + for (auto idx : dh::GridStrideRange(static_cast(0), n)) { + auto e = batch_iter[idx]; + if (is_valid(e)) { + atomicAdd(&smem_cs_ptr[e.column_idx], static_cast(1)); } - return a.index < b.index; } -}; -// Get column size from adapter batch and for output cuts. -template -void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feature, - Iter batch_iter, data::IsValidFunctor is_valid, - size_t begin, size_t end, - HostDeviceVector *cuts_ptr, + cub::CTA_SYNC(); + + auto out_global_ptr = out_column_size; + for (auto i : dh::BlockStrideRange(static_cast(0), out_column_size.size())) { + atomicAdd(&out_global_ptr[i], static_cast(smem_cs_ptr[i])); + } +} + +template +std::uint32_t EstimateGridSize(std::int32_t device, Kernel kernel, std::size_t shared_mem) { + int n_mps = 0; + dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); + int n_blocks_per_mp = 0; + dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel, + kBlockThreads, shared_mem)); + std::uint32_t grid_size = n_blocks_per_mp * n_mps; + return grid_size; +} + +/** + * \brief Get the size of each column. This is a histogram with additional handling of + * invalid values. + * + * \tparam BatchIt Type of input adapter batch. 
+ * \tparam force_use_global_memory Used for testing. Force global atomic add. + * \tparam force_use_u64 Used for testing. Force u64 as counter in shared memory. + * + * \param device CUDA device ordinal. + * \param batch_iter Iterator for input data from adapter batch. + * \param is_valid Whether an element is considered missing. + * \param out_column_size Output buffer for the size of each column. + */ +template +void LaunchGetColumnSizeKernel(std::int32_t device, IterSpan batch_iter, + data::IsValidFunctor is_valid, Span out_column_size) { + thrust::fill_n(thrust::device, dh::tbegin(out_column_size), out_column_size.size(), 0); + + std::size_t max_shared_memory = dh::MaxSharedMemory(device); + // Not strictly correct as we should use number of samples to determine the type of + // counter. However, the sample size is not known due to sliding window on number of + // elements. + std::size_t n = batch_iter.size(); + + std::size_t required_shared_memory = 0; + bool use_u32{false}; + if (!force_use_u64 && n < static_cast(std::numeric_limits::max())) { + required_shared_memory = out_column_size.size() * sizeof(std::uint32_t); + use_u32 = true; + } else { + required_shared_memory = out_column_size.size() * sizeof(std::size_t); + use_u32 = false; + } + bool use_shared = required_shared_memory <= max_shared_memory && required_shared_memory != 0; + + if (!force_use_global_memory && use_shared) { + CHECK_NE(required_shared_memory, 0); + std::uint32_t constexpr kBlockThreads = 512; + if (use_u32) { + CHECK(!force_use_u64); + auto kernel = GetColumnSizeSharedMemKernel; + auto grid_size = EstimateGridSize(device, kernel, required_shared_memory); + dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}( + kernel, batch_iter, is_valid, out_column_size); + } else { + auto kernel = GetColumnSizeSharedMemKernel; + auto grid_size = EstimateGridSize(device, kernel, required_shared_memory); + dh::LaunchKernel{grid_size, kBlockThreads, required_shared_memory, dh::DefaultStream()}( + kernel, batch_iter, is_valid, out_column_size); + } + } else { + auto d_out_column_size = out_column_size; + dh::LaunchN(batch_iter.size(), [=] __device__(size_t idx) { + auto e = batch_iter[idx]; + if (is_valid(e)) { + atomicAdd(&d_out_column_size[e.column_idx], static_cast(1)); + } + }); + } +} + +template +void GetColumnSizesScan(int device, size_t num_columns, std::size_t num_cuts_per_feature, + IterSpan batch_iter, data::IsValidFunctor is_valid, + HostDeviceVector* cuts_ptr, dh::caching_device_vector* column_sizes_scan) { - column_sizes_scan->resize(num_columns + 1, 0); + column_sizes_scan->resize(num_columns + 1); cuts_ptr->SetDevice(device); cuts_ptr->Resize(num_columns + 1, 0); dh::XGBCachingDeviceAllocator alloc; - auto d_column_sizes_scan = column_sizes_scan->data().get(); - dh::LaunchN(end - begin, [=] __device__(size_t idx) { - auto e = batch_iter[begin + idx]; - if (is_valid(e)) { - atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast(1)); - } - }); + auto d_column_sizes_scan = dh::ToSpan(*column_sizes_scan); + LaunchGetColumnSizeKernel(device, batch_iter, is_valid, d_column_sizes_scan); // Calculate cuts CSC pointer auto cut_ptr_it = dh::MakeTransformIterator( column_sizes_scan->begin(), [=] __device__(size_t column_size) { return thrust::min(num_cuts_per_feature, column_size); }); thrust::exclusive_scan(thrust::cuda::par(alloc), cut_ptr_it, - cut_ptr_it + column_sizes_scan->size(), - cuts_ptr->DevicePointer()); + cut_ptr_it + column_sizes_scan->size(),
cuts_ptr->DevicePointer()); thrust::exclusive_scan(thrust::cuda::par(alloc), column_sizes_scan->begin(), column_sizes_scan->end(), column_sizes_scan->begin()); } -inline size_t constexpr BytesPerElement(bool has_weight) { - // Double the memory usage for sorting. We need to assign weight for each element, so - // sizeof(float) is added to all elements. - return (has_weight ? sizeof(Entry) + sizeof(float) : sizeof(Entry)) * 2; +/** + * \brief Type for sorted index. + */ +using SortedIdxT = std::uint32_t; + +/** + * \brief Maximum number of elements for each batch, limited by the type of the sorted index. + */ +inline constexpr std::size_t kMaxNumEntrySort = std::numeric_limits::max(); + +/** + * \brief Return sorted index of input entries. KeyIt is an iterator that returns `xgboost::Entry`. + */ +template +void ArgSortEntry(KeyIt key_it, dh::device_vector* p_sorted_idx) { + auto& sorted_idx = *p_sorted_idx; + std::size_t n = sorted_idx.size(); + CHECK_LE(n, kMaxNumEntrySort); + + std::size_t bytes{0}; + std::byte* ptr{nullptr}; + cub_argsort::DeviceRadixSort::Argsort(ptr, bytes, key_it, + sorted_idx.data().get(), n); + dh::device_vector alloc(bytes); + ptr = alloc.data().get(); + cub_argsort::DeviceRadixSort::Argsort(ptr, bytes, key_it, + sorted_idx.data().get(), n); } -/* \brief Calcuate the length of sliding window. Returns `sketch_batch_num_elements` +/** + * \brief Calculate the length of sliding window. Returns `sketch_batch_num_elements` * directly if it's not 0. */ -size_t SketchBatchNumElements(size_t sketch_batch_num_elements, - bst_row_t num_rows, bst_feature_t columns, - size_t nnz, int device, - size_t num_cuts, bool has_weight); +std::size_t SketchBatchNumElements(size_t sketch_batch_num_elements, bst_row_t num_rows, + bst_feature_t columns, size_t nnz, int device, size_t num_cuts, + bool has_weight, bool d_can_read); // Compute number of sample cuts needed on local node to maintain accuracy // We take more cuts than needed and then reduce them later size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows); -/* \brief Estimate required memory for each sliding window. +/** + * \brief Estimate required memory for each sliding window. * * It's not precise as to obtain exact memory usage for sparse dataset we need to walk * through the whole dataset first. Also if data is from host DMatrix, we copy the @@ -113,46 +223,104 @@ size_t RequiredSampleCutsPerColumn(int max_bins, size_t num_rows); * cols if nnz is unknown. * \param num_bins Number of histogram bins. * \param with_weights Whether weight is used, works the same for ranking and other models. + * \param d_can_read Whether the device already has read access to the data. * * \return The estimated bytes */ -size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, size_t nnz, - size_t num_bins, bool with_weights); - -// Count the valid entries in each column and copy them out.
-template -void MakeEntriesFromAdapter(AdapterBatch const& batch, BatchIter batch_iter, - Range1d range, float missing, - size_t columns, size_t cuts_per_feature, int device, - HostDeviceVector* cut_sizes_scan, - dh::caching_device_vector* column_sizes_scan, - dh::device_vector* sorted_entries) { - auto entry_iter = dh::MakeTransformIterator( - thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { - return Entry(batch.GetElement(idx).column_idx, - batch.GetElement(idx).value); - }); - data::IsValidFunctor is_valid(missing); +std::size_t RequiredMemory(bst_row_t num_rows, bst_feature_t num_columns, std::size_t nnz, + bst_bin_t num_bins, bool with_weights, bool d_can_read); + +/** + * \brief Count the valid entries in each column and sort them. + * + * \param batch_iter Iterator to data batch, with value_type as data::COOTuple. + * \param range Boundary of the current sliding window. + * \param is_valid Specify the missing value. + * \param columns Number of features. + * \param cuts_per_feature Number of required cuts for each feature, which is estimated by + * sketching container. + * \param device CUDA ordinal. + * \param p_cut_sizes_scan Output cuts ptr. + * \param p_column_sizes_scan Output feature ptr. + * \param p_sorted_idx Output sorted index of input data (batch_iter). + */ +template +void MakeEntriesFromAdapter(BatchIter batch_iter, Range1d range, data::IsValidFunctor is_valid, + std::size_t columns, std::size_t cuts_per_feature, int device, + HostDeviceVector* p_cut_sizes_scan, + dh::caching_device_vector* p_column_sizes_scan, + dh::device_vector* p_sorted_idx) { + auto n = range.end() - range.begin(); + auto span = IterSpan{batch_iter + range.begin(), n}; // Work out how many valid entries we have in each column - GetColumnSizesScan(device, columns, cuts_per_feature, - batch_iter, is_valid, - range.begin(), range.end(), - cut_sizes_scan, - column_sizes_scan); - size_t num_valid = column_sizes_scan->back(); - // Copy current subset of valid elements into temporary storage and sort - sorted_entries->resize(num_valid); - dh::CopyIf(entry_iter + range.begin(), entry_iter + range.end(), - sorted_entries->begin(), is_valid); + GetColumnSizesScan(device, columns, cuts_per_feature, span, is_valid, p_cut_sizes_scan, + p_column_sizes_scan); + // Sort the current subset of valid elements. + dh::device_vector& sorted_idx = *p_sorted_idx; + sorted_idx.resize(span.size()); + + std::size_t n_valids = p_column_sizes_scan->back(); + + auto key_it = dh::MakeTransformIterator( + span.data(), [=] XGBOOST_DEVICE(data::COOTuple const& tup) -> Entry { + if (is_valid(tup)) { + return {static_cast(tup.column_idx), tup.value}; + } + // Push invalid elements to the end + return {std::numeric_limits::max(), std::numeric_limits::max()}; + }); + ArgSortEntry(key_it, &sorted_idx); + + sorted_idx.resize(n_valids); } -void SortByWeight(dh::device_vector* weights, - dh::device_vector* sorted_entries); +template +void RemoveDuplicatedCategories(int32_t device, MetaInfo const& info, Span d_cuts_ptr, + BatchIter batch_iter, dh::device_vector* p_sorted_idx, + dh::caching_device_vector* p_column_sizes_scan) { + info.feature_types.SetDevice(device); + auto d_feature_types = info.feature_types.ConstDeviceSpan(); + CHECK(!d_feature_types.empty()); + auto& column_sizes_scan = *p_column_sizes_scan; + // Removing duplicated entries in categorical features. 
+ dh::caching_device_vector new_column_scan(column_sizes_scan.size()); + auto d_sorted_idx = dh::ToSpan(*p_sorted_idx); + dh::SegmentedUnique( + column_sizes_scan.data().get(), column_sizes_scan.data().get() + column_sizes_scan.size(), + dh::tcbegin(d_sorted_idx), dh::tcend(d_sorted_idx), new_column_scan.data().get(), + dh::tbegin(d_sorted_idx), [=] __device__(SortedIdxT l, SortedIdxT r) { + data::COOTuple const& le = batch_iter[l]; + data::COOTuple const& re = batch_iter[r]; + if (le.column_idx == re.column_idx) { + if (IsCat(d_feature_types, le.column_idx)) { + return le.value == re.value; + } + } + return false; + }); -void RemoveDuplicatedCategories( - int32_t device, MetaInfo const &info, Span d_cuts_ptr, - dh::device_vector *p_sorted_entries, - dh::caching_device_vector *p_column_sizes_scan); + // Renew the column scan and cut scan based on categorical data. + auto d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan); + dh::caching_device_vector new_cuts_size(info.num_col_ + 1); + CHECK_EQ(new_column_scan.size(), new_cuts_size.size()); + dh::LaunchN(new_column_scan.size(), + [=, d_new_cuts_size = dh::ToSpan(new_cuts_size), + d_old_column_sizes_scan = dh::ToSpan(column_sizes_scan), + d_new_columns_ptr = dh::ToSpan(new_column_scan)] __device__(size_t idx) { + d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx]; + if (idx == d_new_columns_ptr.size() - 1) { + return; + } + if (IsCat(d_feature_types, idx)) { + // Cut size is the same as number of categories in input. + d_new_cuts_size[idx] = d_new_columns_ptr[idx + 1] - d_new_columns_ptr[idx]; + } else { + d_new_cuts_size[idx] = d_cuts_ptr[idx + 1] - d_cuts_ptr[idx]; + } + }); + // Turn size into ptr. + thrust::exclusive_scan(new_cuts_size.cbegin(), new_cuts_size.cend(), d_cuts_ptr.data()); +} } // namespace detail // Compute sketch on DMatrix. @@ -160,70 +328,75 @@ void RemoveDuplicatedCategories( HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, size_t sketch_batch_num_elements = 0); +// Quantile sketching on DMatrix. Exposed for tests. +void ProcessBatch(std::int32_t device, MetaInfo const& info, const SparsePage& page, + std::size_t begin, std::size_t end, SketchContainer* sketch_container, + bst_bin_t num_cuts_per_feature, std::size_t num_columns); + +// Quantile sketching on DMatrix with weighted samples. Exposed for tests. 
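+// Per-element weights are derived from the sample (or group) weights and prefix-summed over the sorted entries before being pushed into the sketch.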
+void ProcessWeightedBatch(int device, MetaInfo const& info, const SparsePage& page, + std::size_t begin, std::size_t end, SketchContainer* sketch_container, + bst_bin_t num_cuts_per_feature, bst_feature_t num_columns, + bool is_ranking, Span d_group_ptr); + template -void ProcessSlidingWindow(AdapterBatch const &batch, MetaInfo const &info, - int device, size_t columns, size_t begin, size_t end, - float missing, SketchContainer *sketch_container, - int num_cuts) { - // Copy current subset of valid elements into temporary storage and sort - dh::device_vector sorted_entries; +void ProcessSlidingWindow(AdapterBatch const& batch, MetaInfo const& info, int device, + std::size_t columns, std::size_t begin, std::size_t end, float missing, + SketchContainer* sketch_container, int num_cuts) { dh::caching_device_vector column_sizes_scan; auto batch_iter = dh::MakeTransformIterator( thrust::make_counting_iterator(0llu), - [=] __device__(size_t idx) { return batch.GetElement(idx); }); + [=] __device__(std::size_t idx) { return batch.GetElement(idx); }); HostDeviceVector cuts_ptr; cuts_ptr.SetDevice(device); - detail::MakeEntriesFromAdapter(batch, batch_iter, {begin, end}, missing, - columns, num_cuts, device, - &cuts_ptr, - &column_sizes_scan, - &sorted_entries); - dh::XGBDeviceAllocator alloc; - thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(), - sorted_entries.end(), detail::EntryCompareOp()); + + dh::device_vector sorted_idx; + data::IsValidFunctor is_valid(missing); + detail::MakeEntriesFromAdapter(batch_iter, {begin, end}, is_valid, columns, num_cuts, device, + &cuts_ptr, &column_sizes_scan, &sorted_idx); if (sketch_container->HasCategorical()) { auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, - &sorted_entries, &column_sizes_scan); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, batch_iter + begin, &sorted_idx, + &column_sizes_scan); } + auto entry_it = dh::MakeTransformIterator( + batch_iter + begin, [=] __device__(data::COOTuple const& tup) { + return Entry{static_cast(tup.column_idx), tup.value}; + }); + auto d_sorted_entry_it = thrust::make_permutation_iterator(entry_it, sorted_idx.cbegin()); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - auto const &h_cuts_ptr = cuts_ptr.HostVector(); + auto const& h_cuts_ptr = cuts_ptr.HostVector(); // Extract the cuts from all columns concurrently - sketch_container->Push(dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), d_cuts_ptr, + sketch_container->Push(d_sorted_entry_it, dh::ToSpan(column_sizes_scan), d_cuts_ptr, h_cuts_ptr.back()); - sorted_entries.clear(); - sorted_entries.shrink_to_fit(); } template -void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, - int num_cuts_per_feature, - bool is_ranking, float missing, int device, - size_t columns, size_t begin, size_t end, - SketchContainer *sketch_container) { - dh::XGBCachingDeviceAllocator alloc; +void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, int num_cuts_per_feature, + bool is_ranking, float missing, int device, size_t columns, + std::size_t begin, std::size_t end, + SketchContainer* sketch_container) { dh::safe_cuda(cudaSetDevice(device)); info.weights_.SetDevice(device); auto weights = info.weights_.ConstDeviceSpan(); + dh::XGBCachingDeviceAllocator alloc; + auto batch_iter = dh::MakeTransformIterator( - thrust::make_counting_iterator(0llu), - [=] __device__(size_t idx) { return batch.GetElement(idx); }); - dh::device_vector sorted_entries; + 
thrust::make_counting_iterator(0llu), + [=] __device__(std::size_t idx) { return batch.GetElement(idx); }); dh::caching_device_vector column_sizes_scan; HostDeviceVector cuts_ptr; - detail::MakeEntriesFromAdapter(batch, batch_iter, - {begin, end}, missing, - columns, num_cuts_per_feature, device, - &cuts_ptr, - &column_sizes_scan, - &sorted_entries); + + dh::device_vector sorted_idx; data::IsValidFunctor is_valid(missing); + detail::MakeEntriesFromAdapter(batch_iter, {begin, end}, is_valid, columns, num_cuts_per_feature, + device, &cuts_ptr, &column_sizes_scan, &sorted_idx); - dh::device_vector temp_weights(sorted_entries.size()); + // sorted_idx.size() represents the number of valid elements. + dh::device_vector temp_weights(sorted_idx.size()); auto d_temp_weights = dh::ToSpan(temp_weights); if (is_ranking) { @@ -238,8 +411,7 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx); return weights[group_idx]; }); - auto retit = thrust::copy_if(thrust::cuda::par(alloc), - weight_iter + begin, weight_iter + end, + auto retit = thrust::copy_if(thrust::cuda::par(alloc), weight_iter + begin, weight_iter + end, batch_iter + begin, d_temp_weights.data(), // output is_valid); @@ -248,38 +420,43 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, CHECK_EQ(batch.NumRows(), weights.size()); auto const weight_iter = dh::MakeTransformIterator( thrust::make_counting_iterator(0lu), - [=]__device__(size_t idx) -> float { - return weights[batch.GetElement(idx).row_idx]; - }); - auto retit = thrust::copy_if(thrust::cuda::par(alloc), - weight_iter + begin, weight_iter + end, + [=] __device__(size_t idx) -> float { return weights[batch.GetElement(idx).row_idx]; }); + auto retit = thrust::copy_if(thrust::cuda::par(alloc), weight_iter + begin, weight_iter + end, batch_iter + begin, d_temp_weights.data(), // output is_valid); CHECK_EQ(retit - d_temp_weights.data(), d_temp_weights.size()); } - detail::SortByWeight(&temp_weights, &sorted_entries); + auto entry_it = dh::MakeTransformIterator( + batch_iter + begin, [=] __device__(data::COOTuple const& tup) { + return Entry{static_cast(tup.column_idx), tup.value}; + }); + auto d_sorted_entry_it = thrust::make_permutation_iterator(entry_it, sorted_idx.cbegin()); + auto d_sorted_weight_it = + thrust::make_permutation_iterator(dh::tbegin(d_temp_weights), sorted_idx.cbegin()); + + thrust::inclusive_scan_by_key( + thrust::cuda::par(alloc), d_sorted_entry_it, d_sorted_entry_it + sorted_idx.size(), + d_sorted_weight_it, d_sorted_weight_it, + [=] __device__(const Entry& a, const Entry& b) { return a.index == b.index; }); if (sketch_container->HasCategorical()) { auto d_cuts_ptr = cuts_ptr.DeviceSpan(); - detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, - &sorted_entries, &column_sizes_scan); + detail::RemoveDuplicatedCategories(device, info, d_cuts_ptr, batch_iter + begin, &sorted_idx, + &column_sizes_scan); } auto const& h_cuts_ptr = cuts_ptr.ConstHostVector(); auto d_cuts_ptr = cuts_ptr.DeviceSpan(); // Extract cuts - sketch_container->Push(dh::ToSpan(sorted_entries), - dh::ToSpan(column_sizes_scan), d_cuts_ptr, - h_cuts_ptr.back(), dh::ToSpan(temp_weights)); - sorted_entries.clear(); - sorted_entries.shrink_to_fit(); + sketch_container->Push(d_sorted_entry_it, dh::ToSpan(column_sizes_scan), d_cuts_ptr, + h_cuts_ptr.back(), IterSpan{d_sorted_weight_it, sorted_idx.size()}); } -/* - * \brief Perform sketching on GPU. +/** + * \brief Perform sketching on GPU in-place. 
* * \param batch A batch from adapter. * \param num_bins Bins per column. @@ -290,43 +467,38 @@ void ProcessWeightedSlidingWindow(Batch batch, MetaInfo const& info, * testing. */ template -void AdapterDeviceSketch(Batch batch, int num_bins, - MetaInfo const& info, +void AdapterDeviceSketch(Batch const& batch, bst_bin_t num_bins, MetaInfo const& info, float missing, SketchContainer* sketch_container, - size_t sketch_batch_num_elements = 0) { + std::size_t sketch_batch_num_elements = 0) { size_t num_rows = batch.NumRows(); size_t num_cols = batch.NumCols(); - size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); + std::size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); int32_t device = sketch_container->DeviceIdx(); bool weighted = !info.weights_.Empty(); if (weighted) { sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - num_rows, num_cols, std::numeric_limits::max(), - device, num_cuts_per_feature, true); + sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits::max(), device, + num_cuts_per_feature, true, true); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); - ProcessWeightedSlidingWindow(batch, info, - num_cuts_per_feature, - HostSketchContainer::UseGroup(info), missing, device, num_cols, begin, end, - sketch_container); + ProcessWeightedSlidingWindow(batch, info, num_cuts_per_feature, + HostSketchContainer::UseGroup(info), missing, device, num_cols, + begin, end, sketch_container); } } else { sketch_batch_num_elements = detail::SketchBatchNumElements( - sketch_batch_num_elements, - num_rows, num_cols, std::numeric_limits::max(), - device, num_cuts_per_feature, false); + sketch_batch_num_elements, num_rows, num_cols, std::numeric_limits::max(), device, + num_cuts_per_feature, false, true); for (auto begin = 0ull; begin < batch.Size(); begin += sketch_batch_num_elements) { size_t end = std::min(batch.Size(), static_cast(begin + sketch_batch_num_elements)); - ProcessSlidingWindow(batch, info, device, num_cols, begin, end, missing, - sketch_container, num_cuts_per_feature); + ProcessSlidingWindow(batch, info, device, num_cols, begin, end, missing, sketch_container, + num_cuts_per_feature); } } } -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // COMMON_HIST_UTIL_CUH_ diff --git a/src/common/quantile.cu b/src/common/quantile.cu index cabdc603b97e..fcdfe8e260c4 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -25,96 +25,6 @@ namespace xgboost { namespace common { - -using WQSketch = HostSketchContainer::WQSketch; -using SketchEntry = WQSketch::Entry; - -// Algorithm 4 in XGBoost's paper, using binary search to find i. 
-template -__device__ SketchEntry BinarySearchQuery(EntryIter beg, EntryIter end, float rank) { - assert(end - beg >= 2); - rank *= 2; - auto front = *beg; - if (rank < front.rmin + front.rmax) { - return *beg; - } - auto back = *(end - 1); - if (rank >= back.rmin + back.rmax) { - return back; - } - - auto search_begin = dh::MakeTransformIterator( - beg, [=] __device__(SketchEntry const &entry) { - return entry.rmin + entry.rmax; - }); - auto search_end = search_begin + (end - beg); - auto i = - thrust::upper_bound(thrust::seq, search_begin + 1, search_end - 1, rank) - - search_begin - 1; - if (rank < (*(beg + i)).RMinNext() + (*(beg + i + 1)).RMaxPrev()) { - return *(beg + i); - } else { - return *(beg + i + 1); - } -} - -template -void PruneImpl(common::Span cuts_ptr, - Span sorted_data, - Span columns_ptr_in, // could be ptr for data or cuts - Span feature_types, Span out_cuts, - ToSketchEntry to_sketch_entry) { - dh::LaunchN(out_cuts.size(), [=] __device__(size_t idx) { - size_t column_id = dh::SegmentId(cuts_ptr, idx); - auto out_column = out_cuts.subspan( - cuts_ptr[column_id], cuts_ptr[column_id + 1] - cuts_ptr[column_id]); - auto in_column = sorted_data.subspan(columns_ptr_in[column_id], - columns_ptr_in[column_id + 1] - - columns_ptr_in[column_id]); - auto to = cuts_ptr[column_id + 1] - cuts_ptr[column_id]; - idx -= cuts_ptr[column_id]; - auto front = to_sketch_entry(0ul, in_column, column_id); - auto back = to_sketch_entry(in_column.size() - 1, in_column, column_id); - - auto is_cat = IsCat(feature_types, column_id); - if (in_column.size() <= to || is_cat) { - // cut idx equals sample idx - out_column[idx] = to_sketch_entry(idx, in_column, column_id); - return; - } - // 1 thread for each output. See A.4 for detail. - auto d_out = out_column; - if (idx == 0) { - d_out.front() = front; - return; - } - if (idx == to - 1) { - d_out.back() = back; - return; - } - - float w = back.rmin - front.rmax; - auto budget = static_cast(d_out.size()); - assert(budget != 0); - auto q = ((static_cast(idx) * w) / (static_cast(to) - 1.0f) + front.rmax); - auto it = dh::MakeTransformIterator( - thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) { - auto e = to_sketch_entry(idx, in_column, column_id); - return e; - }); - d_out[idx] = BinarySearchQuery(it, it + in_column.size(), q); - }); -} - -template -void CopyTo(Span out, Span src) { - CHECK_EQ(out.size(), src.size()); - static_assert(std::is_same, std::remove_cv_t>::value); - dh::safe_cuda(cudaMemcpyAsync(out.data(), src.data(), - out.size_bytes(), - cudaMemcpyDefault)); -} - // Compute the merge path. 
common::Span> MergePath( Span const &d_x, Span const &x_ptr, @@ -306,62 +216,6 @@ void MergeImpl(int32_t device, Span const &d_x, }); } -void SketchContainer::Push(Span entries, Span columns_ptr, - common::Span cuts_ptr, - size_t total_cuts, Span weights) { - dh::safe_cuda(cudaSetDevice(device_)); - Span out; - dh::device_vector cuts; - bool first_window = this->Current().empty(); - if (!first_window) { - cuts.resize(total_cuts); - out = dh::ToSpan(cuts); - } else { - this->Current().resize(total_cuts); - out = dh::ToSpan(this->Current()); - } - auto ft = this->feature_types_.ConstDeviceSpan(); - if (weights.empty()) { - auto to_sketch_entry = [] __device__(size_t sample_idx, - Span const &column, - size_t) { - float rmin = sample_idx; - float rmax = sample_idx + 1; - return SketchEntry{rmin, rmax, 1, column[sample_idx].fvalue}; - }; // NOLINT - PruneImpl(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry); - } else { - auto to_sketch_entry = [weights, columns_ptr] __device__( - size_t sample_idx, - Span const &column, - size_t column_id) { - Span column_weights_scan = - weights.subspan(columns_ptr[column_id], column.size()); - float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; - float rmax = column_weights_scan[sample_idx]; - float wmin = rmax - rmin; - wmin = wmin < 0 ? kRtEps : wmin; // GPU scan can generate floating error. - return SketchEntry{rmin, rmax, wmin, column[sample_idx].fvalue}; - }; // NOLINT - PruneImpl(cuts_ptr, entries, columns_ptr, ft, out, to_sketch_entry); - } - auto n_uniques = this->ScanInput(out, cuts_ptr); - - if (!first_window) { - CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size()); - out = out.subspan(0, n_uniques); - this->Merge(cuts_ptr, out); - this->FixError(); - } else { - this->Current().resize(n_uniques); - this->columns_ptr_.SetDevice(device_); - this->columns_ptr_.Resize(cuts_ptr.size()); - - auto d_cuts_ptr = this->columns_ptr_.DeviceSpan(); - CopyTo(d_cuts_ptr, cuts_ptr); - } -} - size_t SketchContainer::ScanInput(Span entries, Span d_columns_ptr_in) { /* There are 2 types of duplication. First is duplicated feature values, which comes * from user input data. Second is duplicated sketching entries, which is generated by @@ -429,11 +283,11 @@ void SketchContainer::Prune(size_t to) { auto d_columns_ptr_out = columns_ptr_b_.ConstDeviceSpan(); auto out = dh::ToSpan(this->Other()); auto in = dh::ToSpan(this->Current()); - auto no_op = [] __device__(size_t sample_idx, - Span const &entries, - size_t) { return entries[sample_idx]; }; // NOLINT + auto no_op = [] __device__(size_t sample_idx, auto const &entries, size_t) { + return entries[sample_idx]; + }; // NOLINT auto ft = this->feature_types_.ConstDeviceSpan(); - PruneImpl(d_columns_ptr_out, in, d_columns_ptr_in, ft, out, no_op); + PruneImpl(d_columns_ptr_out, in.data(), d_columns_ptr_in, ft, out, no_op); this->columns_ptr_.Copy(columns_ptr_b_); this->Alternate(); diff --git a/src/common/quantile.cuh b/src/common/quantile.cuh index 7ebd4ff51663..7a89c5e7b179 100644 --- a/src/common/quantile.cuh +++ b/src/common/quantile.cuh @@ -1,22 +1,112 @@ +/** + * Copyright 2020~2023, XGBoost contributors + * + * \brief GPU implementation of GK sketching. 
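+ * + * BinarySearchQuery and PruneImpl are defined in this header (moved from quantile.cu) so that the templated SketchContainer::Push can prune directly from entry iterators.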
+ */ #ifndef XGBOOST_COMMON_QUANTILE_CUH_ #define XGBOOST_COMMON_QUANTILE_CUH_ #include -#include "xgboost/span.h" -#include "xgboost/data.h" +#include "categorical.h" #include "device_helpers.cuh" #include "quantile.h" #include "timer.h" -#include "categorical.h" +#include "xgboost/data.h" +#include "xgboost/span.h" // for IterSpan, Span -namespace xgboost { -namespace common { +namespace xgboost::common { class HistogramCuts; using WQSketch = WQuantileSketch; using SketchEntry = WQSketch::Entry; +// Algorithm 4 in XGBoost's paper, using binary search to find i. +template +__device__ SketchEntry BinarySearchQuery(EntryIter beg, EntryIter end, float rank) { + assert(end - beg >= 2); + rank *= 2; + auto front = *beg; + if (rank < front.rmin + front.rmax) { + return *beg; + } + auto back = *(end - 1); + if (rank >= back.rmin + back.rmax) { + return back; + } + + auto search_begin = dh::MakeTransformIterator( + beg, [=] __device__(SketchEntry const &entry) { + return entry.rmin + entry.rmax; + }); + auto search_end = search_begin + (end - beg); + auto i = + thrust::upper_bound(thrust::seq, search_begin + 1, search_end - 1, rank) - + search_begin - 1; + if (rank < (*(beg + i)).RMinNext() + (*(beg + i + 1)).RMaxPrev()) { + return *(beg + i); + } else { + return *(beg + i + 1); + } +} + +template +void PruneImpl(common::Span cuts_ptr, EntryIter sorted_data, + Span columns_ptr_in, // could be ptr for data or cuts + Span feature_types, Span out_cuts, + ToSketchEntry to_sketch_entry) { + dh::LaunchN(out_cuts.size(), [=] __device__(size_t idx) { + size_t column_id = dh::SegmentId(cuts_ptr, idx); + auto out_column = out_cuts.subspan( + cuts_ptr[column_id], cuts_ptr[column_id + 1] - cuts_ptr[column_id]); + + auto in_column_beg = columns_ptr_in[column_id]; + auto in_column = + IterSpan{sorted_data + in_column_beg, columns_ptr_in[column_id + 1] - in_column_beg}; + // auto in_column = sorted_data.subspan(columns_ptr_in[column_id], + // columns_ptr_in[column_id + 1] - columns_ptr_in[column_id]); + auto to = cuts_ptr[column_id + 1] - cuts_ptr[column_id]; + idx -= cuts_ptr[column_id]; + auto front = to_sketch_entry(0ul, in_column, column_id); + auto back = to_sketch_entry(in_column.size() - 1, in_column, column_id); + + auto is_cat = IsCat(feature_types, column_id); + if (in_column.size() <= to || is_cat) { + // cut idx equals sample idx + out_column[idx] = to_sketch_entry(idx, in_column, column_id); + return; + } + // 1 thread for each output. See A.4 for detail. 
+ auto d_out = out_column; + if (idx == 0) { + d_out.front() = front; + return; + } + if (idx == to - 1) { + d_out.back() = back; + return; + } + + float w = back.rmin - front.rmax; + auto budget = static_cast(d_out.size()); + assert(budget != 0); + auto q = ((static_cast(idx) * w) / (static_cast(to) - 1.0f) + front.rmax); + auto it = dh::MakeTransformIterator( + thrust::make_counting_iterator(0ul), [=] __device__(size_t idx) { + auto e = to_sketch_entry(idx, in_column, column_id); + return e; + }); + d_out[idx] = BinarySearchQuery(it, it + in_column.size(), q); + }); +} + +template +void CopyTo(Span out, Span src) { + CHECK_EQ(out.size(), src.size()); + static_assert(std::is_same, std::remove_cv_t>::value); + dh::safe_cuda(cudaMemcpyAsync(out.data(), src.data(), out.size_bytes(), cudaMemcpyDefault)); +} + namespace detail { struct SketchUnique { XGBOOST_DEVICE bool operator()(SketchEntry const& a, SketchEntry const& b) const { @@ -67,10 +157,10 @@ class SketchContainer { return entries_b_; } } - dh::device_vector const& Current() const { + [[nodiscard]] dh::device_vector const& Current() const { return const_cast(this)->Current(); } - dh::device_vector const& Other() const { + [[nodiscard]] dh::device_vector const& Other() const { return const_cast(this)->Other(); } void Alternate() { @@ -88,7 +178,7 @@ class SketchContainer { public: /* \breif GPU quantile structure, with sketch data for each columns. * - * \param max_bin Maximum number of bins per columns + * \param max_bin Maximum number of bins per column. * \param num_columns Total number of columns in dataset. * \param num_rows Total number of rows in known dataset (typically the rows in current worker). * \param device GPU ID. @@ -130,17 +220,65 @@ class SketchContainer { * addition inside `RMinNext` and subtraction in `RMaxPrev`. */ void FixError(); - /* \brief Push sorted entries. + /** + * \brief Push sorted entries. * - * \param entries Sorted entries. + * \param sorted_entries Iterator to sorted entries. * \param columns_ptr CSC pointer for entries. * \param cuts_ptr CSC pointer for cuts. * \param total_cuts Total number of cuts, equal to the back of cuts_ptr. * \param weights (optional) data weights. */ - void Push(Span entries, Span columns_ptr, - common::Span cuts_ptr, size_t total_cuts, - Span weights = {}); + template ::iterator> + void Push(EntryIter sorted_entries, Span columns_ptr, common::Span cuts_ptr, + size_t total_cuts, IterSpan weights = {}) { + dh::safe_cuda(cudaSetDevice(device_)); + Span out; + dh::device_vector cuts; + bool first_window = this->Current().empty(); + if (!first_window) { + cuts.resize(total_cuts); + out = dh::ToSpan(cuts); + } else { + this->Current().resize(total_cuts); + out = dh::ToSpan(this->Current()); + } + auto ft = this->feature_types_.ConstDeviceSpan(); + if (weights.empty()) { + auto to_sketch_entry = [] __device__(size_t sample_idx, auto const& column, size_t) { + float rmin = sample_idx; + float rmax = sample_idx + 1; + return SketchEntry{rmin, rmax, 1, column[sample_idx].fvalue}; + }; // NOLINT + PruneImpl(cuts_ptr, sorted_entries, columns_ptr, ft, out, to_sketch_entry); + } else { + auto to_sketch_entry = [weights, columns_ptr] __device__( + size_t sample_idx, auto const& column, size_t column_id) { + auto column_weights_scan = weights.subspan(columns_ptr[column_id], column.size()); + float rmin = sample_idx > 0 ? column_weights_scan[sample_idx - 1] : 0.0f; + float rmax = column_weights_scan[sample_idx]; + float wmin = rmax - rmin; + wmin = wmin < 0 ? 
kRtEps : wmin; // GPU scan can generate floating error. + return SketchEntry{rmin, rmax, wmin, column[sample_idx].fvalue}; + }; // NOLINT + PruneImpl(cuts_ptr, sorted_entries, columns_ptr, ft, out, to_sketch_entry); + } + auto n_uniques = this->ScanInput(out, cuts_ptr); + + if (!first_window) { + CHECK_EQ(this->columns_ptr_.Size(), cuts_ptr.size()); + out = out.subspan(0, n_uniques); + this->Merge(cuts_ptr, out); + this->FixError(); + } else { + this->Current().resize(n_uniques); + this->columns_ptr_.SetDevice(device_); + this->columns_ptr_.Resize(cuts_ptr.size()); + + auto d_cuts_ptr = this->columns_ptr_.DeviceSpan(); + CopyTo(d_cuts_ptr, cuts_ptr); + } + } /* \brief Prune the quantile structure. * * \param to The maximum size of pruned quantile. If the size of quantile @@ -199,7 +337,6 @@ class SketchContainer { return n_uniques; } }; -} // namespace common -} // namespace xgboost +} // namespace xgboost::common #endif // XGBOOST_COMMON_QUANTILE_CUH_ diff --git a/src/common/stats.cuh b/src/common/stats.cuh index f31233461f6d..e8503520d4ef 100644 --- a/src/common/stats.cuh +++ b/src/common/stats.cuh @@ -146,7 +146,7 @@ auto MakeWQSegOp(SegIt seg_it, ValIt val_it, AlphaIt alpha_it, Span * std::distance(seg_begin, seg_end) should be equal to n_segments + 1 */ template ::value>* = nullptr> + std::enable_if_t>* = nullptr> void SegmentedQuantile(Context const* ctx, AlphaIt alpha_it, SegIt seg_begin, SegIt seg_end, ValIt val_begin, ValIt val_end, HostDeviceVector* quantiles) { dh::device_vector sorted_idx; diff --git a/src/cub_sort/agent/agent_radix_sort_histogram.cuh b/src/cub_sort/agent/agent_radix_sort_histogram.cuh new file mode 100644 index 000000000000..7b93f2ba035f --- /dev/null +++ b/src/cub_sort/agent/agent_radix_sort_histogram.cuh @@ -0,0 +1,245 @@ +/****************************************************************************** + * Copyright (c) 2011-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + ******************************************************************************/ + +/** + * \file + * agent_radix_sort_histogram.cuh implements a stateful abstraction of CUDA + * thread blocks for participating in the device histogram kernel used for + * one-sweep radix sorting. + */ + +#pragma once + +#include + +#include +#include +#include +#include +#include // for memcpy + +#include "../block/radix_rank_sort_operations.cuh" +#include "../util_type.cuh" + +// NOLINTBEGIN +namespace cub_argsort { +template < + int _BLOCK_THREADS, + int _ITEMS_PER_THREAD, + int NOMINAL_4B_NUM_PARTS, + typename ComputeT, + int _RADIX_BITS> +struct AgentRadixSortHistogramPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, + ITEMS_PER_THREAD = _ITEMS_PER_THREAD, + /** NUM_PARTS is the number of private histograms (parts) each histogram is split + * into. Each warp lane is assigned to a specific part based on the lane + * ID. However, lanes with the same ID in different warp use the same private + * histogram. This arrangement helps reduce the degree of conflicts in atomic + * operations. */ + NUM_PARTS = CUB_MAX(1, NOMINAL_4B_NUM_PARTS * 4 / CUB_MAX(sizeof(ComputeT), 4)), + RADIX_BITS = _RADIX_BITS, + }; +}; + +template < + int _BLOCK_THREADS, + int _RADIX_BITS> +struct AgentRadixSortExclusiveSumPolicy +{ + enum + { + BLOCK_THREADS = _BLOCK_THREADS, + RADIX_BITS = _RADIX_BITS, + }; +}; + +template +struct AgentRadixSortHistogram { + using KeyT = typename std::iterator_traits::value_type; + // constants + enum { + ITEMS_PER_THREAD = AgentRadixSortHistogramPolicy::ITEMS_PER_THREAD, + BLOCK_THREADS = AgentRadixSortHistogramPolicy::BLOCK_THREADS, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + RADIX_BITS = AgentRadixSortHistogramPolicy::RADIX_BITS, + RADIX_DIGITS = 1 << RADIX_BITS, + MAX_NUM_PASSES = (sizeof(KeyT) * 8 + RADIX_BITS - 1) / RADIX_BITS, + NUM_PARTS = AgentRadixSortHistogramPolicy::NUM_PARTS, + }; + + using Twiddle = cub_argsort::RadixSortTwiddle; + using ShmemCounterT = std::uint32_t; + using ShmemAtomicCounterT = ShmemCounterT; + using UnsignedBits = typename cub_argsort::MyTraits< + typename std::iterator_traits::value_type>::UnsignedBits; + + struct _TempStorage { + ShmemAtomicCounterT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS]; + }; + + struct TempStorage : ::cub::Uninitialized<_TempStorage> {}; + + // thread fields + // shared memory storage + _TempStorage& s; + + // bins for the histogram + OffsetT* d_bins_out; + + // data to compute the histogram + KeyIteratorT d_keys_in; + + // number of data items + OffsetT num_items; + + // begin and end bits for sorting + int begin_bit, end_bit; + + // number of sorting passes + int num_passes; + + __device__ __forceinline__ AgentRadixSortHistogram(TempStorage& temp_storage, + OffsetT* d_bins_out, KeyIteratorT d_keys_in, + OffsetT num_items, int begin_bit, + int end_bit) + : s(temp_storage.Alias()), + d_bins_out(d_bins_out), + d_keys_in{d_keys_in}, + num_items(num_items), + begin_bit(begin_bit), + end_bit(end_bit), + num_passes((end_bit - begin_bit + RADIX_BITS - 1) / RADIX_BITS) {} + + __device__ __forceinline__ void Init() { +// Initialize bins to 0. 
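Each pass keeps NUM_PARTS private copies of its shared-memory histogram, indexed by lane id modulo NUM_PARTS, so concurrent atomicAdd calls from different lanes usually land on different counters; the copies are summed again in AccumulateGlobalHistograms. A serial host sketch of that layout (illustrative; the sizes are hypothetical stand-ins for the policy constants):

#include <array>
#include <cstddef>
#include <cstdint>
#include <numeric>  // for accumulate
#include <vector>

constexpr int kRadixDigits = 256;  // 8 radix bits, as in the onesweep policies
constexpr int kNumParts = 4;       // hypothetical; derived from NOMINAL_4B_NUM_PARTS in the policy

// One pass worth of private histograms: counters indexed as [digit][part].
using PartedHistogram = std::array<std::array<std::uint32_t, kNumParts>, kRadixDigits>;

void CountDigits(std::vector<std::uint8_t> const& digits, std::vector<int> const& lane_ids,
                 PartedHistogram* hist) {
  for (std::size_t i = 0; i < digits.size(); ++i) {
    int part = lane_ids[i] % kNumParts;  // lanes with the same id share a private copy
    (*hist)[digits[i]][part] += 1;       // atomicAdd(&s.bins[pass][bin][part], 1) on device
  }
}

std::uint32_t BinCount(PartedHistogram const& hist, int digit) {
  // Equivalent of the ThreadReduce over parts before the global atomicAdd.
  return std::accumulate(hist[digit].begin(), hist[digit].end(), 0u);
}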
+#pragma unroll + for (int bin = threadIdx.x; bin < RADIX_DIGITS; bin += BLOCK_THREADS) { +#pragma unroll + for (int pass = 0; pass < num_passes; ++pass) { +#pragma unroll + for (int part = 0; part < NUM_PARTS; ++part) { + s.bins[pass][bin][part] = 0; + } + } + } + ::cub::CTA_SYNC(); + } + + __device__ __forceinline__ void LoadTileKeys(OffsetT tile_offset, + UnsignedBits (&keys)[ITEMS_PER_THREAD]) { + // tile_offset < num_items always, hence the line below works + bool full_tile = num_items - tile_offset >= TILE_ITEMS; + auto it = thrust::make_transform_iterator(d_keys_in, [] __device__(auto const& v) { + static_assert(sizeof(std::remove_reference_t) == sizeof(UnsignedBits)); + UnsignedBits dst; + std::memcpy(&dst, &v, sizeof(v)); + return dst; + }); + if (full_tile) { + ::cub::LoadDirectStriped(threadIdx.x, it + tile_offset, keys); + } else { + ::cub::LoadDirectStriped(threadIdx.x, it + tile_offset, keys, + num_items - tile_offset, Twiddle::DefaultKey()); + } + +#pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) { + keys[u] = Twiddle::In(keys[u]); + } + } + + __device__ __forceinline__ void AccumulateSharedHistograms( + OffsetT tile_offset, UnsignedBits (&keys)[ITEMS_PER_THREAD]) { + int part = ::cub::LaneId() % NUM_PARTS; +#pragma unroll + for (int current_bit = begin_bit, pass = 0; current_bit < end_bit; + current_bit += RADIX_BITS, ++pass) { + int num_bits = CUB_MIN(RADIX_BITS, end_bit - current_bit); + DigitExtractorT digit_extractor(current_bit, num_bits); +#pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) { + int bin = digit_extractor.Digit(keys[u]); + // Using cuda::atomic<> results in lower performance on GP100, + // so atomicAdd() is used instead. + atomicAdd(&s.bins[pass][bin][part], 1); + } + } + } + + __device__ __forceinline__ void AccumulateGlobalHistograms() { +#pragma unroll + for (int bin = threadIdx.x; bin < RADIX_DIGITS; bin += BLOCK_THREADS) { +#pragma unroll + for (int pass = 0; pass < num_passes; ++pass) { + OffsetT count = ::cub::internal::ThreadReduce(s.bins[pass][bin], ::cub::Sum()); + if (count > 0) { + // Using cuda::atomic<> here would also require using it in + // other kernels. However, other kernels of onesweep sorting + // (ExclusiveSum, Onesweep) don't need atomic + // access. Therefore, atomicAdd() is used, until + // cuda::atomic_ref<> becomes available. + atomicAdd(&d_bins_out[pass * RADIX_DIGITS + bin], count); + } + } + } + } + + __device__ __forceinline__ void Process() { + // Within a portion, avoid overflowing (u)int32 counters. + // Between portions, accumulate results in global memory. + const OffsetT MAX_PORTION_SIZE = 1 << 30; + OffsetT num_portions = ::cub::DivideAndRoundUp(num_items, MAX_PORTION_SIZE); + for (OffsetT portion = 0; portion < num_portions; ++portion) { + // Reset the counters. + Init(); + ::cub::CTA_SYNC(); + + // Process the tiles. + OffsetT portion_offset = portion * MAX_PORTION_SIZE; + OffsetT portion_size = CUB_MIN(MAX_PORTION_SIZE, num_items - portion_offset); + for (OffsetT offset = blockIdx.x * TILE_ITEMS; offset < portion_size; + offset += TILE_ITEMS * gridDim.x) { + OffsetT tile_offset = portion_offset + offset; + UnsignedBits keys[ITEMS_PER_THREAD]; + LoadTileKeys(tile_offset, keys); + AccumulateSharedHistograms(tile_offset, keys); + } + ::cub::CTA_SYNC(); + + // Accumulate the result in global memory. 
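Process() walks the input in portions of at most 2^30 items so the 32-bit shared counters cannot overflow, and inside each portion every thread block visits tiles with a grid stride. A serial sketch of the same partitioning arithmetic (illustrative only):

#include <algorithm>  // for min
#include <cstdint>

// Count how many tiles one "block" would process: portions keep per-portion item
// counts below 2^30, and inside a portion blocks stride over tiles of tile_items elements.
inline std::int64_t TilesVisitedByBlock(std::int64_t num_items, int tile_items, int grid_dim,
                                        int block_idx) {
  constexpr std::int64_t kMaxPortionSize = std::int64_t{1} << 30;  // MAX_PORTION_SIZE above
  std::int64_t visited = 0;
  for (std::int64_t portion = 0; portion < num_items; portion += kMaxPortionSize) {
    std::int64_t portion_size = std::min(kMaxPortionSize, num_items - portion);
    for (std::int64_t offset = std::int64_t{block_idx} * tile_items; offset < portion_size;
         offset += std::int64_t{tile_items} * grid_dim) {
      ++visited;  // LoadTileKeys + AccumulateSharedHistograms on the device
    }
  }
  return visited;
}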
+ AccumulateGlobalHistograms(); + ::cub::CTA_SYNC(); + } + } +}; +} // namespace cub_argsort +// NOLINTEND diff --git a/src/cub_sort/agent/agent_radix_sort_onesweep.cuh b/src/cub_sort/agent/agent_radix_sort_onesweep.cuh new file mode 100644 index 000000000000..f0a4762e2203 --- /dev/null +++ b/src/cub_sort/agent/agent_radix_sort_onesweep.cuh @@ -0,0 +1,521 @@ +/****************************************************************************** + * Copyright (c) 2011-2020, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * agent_radix_sort_onesweep.cuh implements a stateful abstraction of CUDA + * thread blocks for participating in the device one-sweep radix sort kernel. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "../block/radix_rank_sort_operations.cuh" +#include "../util_type.cuh" + +// NOLINTBEGIN +namespace cub_argsort { +template +XGBOOST_DEVICE thrust::transform_iterator MakeTransformIterator(IterT iter, + FuncT func) { + return thrust::transform_iterator(iter, func); +} + +struct EntryExtractor { + static_assert(sizeof(xgboost::Entry) == 8); + std::uint32_t bit_start{0}; + // Need special handling for floating point. + ::cub::ShiftDigitExtractor shift_fv; + ::cub::ShiftDigitExtractor shift_idx; + + using UnsignedBits = typename MyTraits::UnsignedBits; + + explicit XGBOOST_DEVICE EntryExtractor(std::uint32_t bit_start = 0, std::uint32_t num_bits = 0) + : bit_start{bit_start}, + shift_fv{bit_start >= 8 * sizeof(float) + ? bit_start - static_cast(8 * sizeof(float)) + : bit_start, + num_bits}, + shift_idx{bit_start >= 8 * sizeof(float) + ? 
bit_start - static_cast(8 * sizeof(float)) + : bit_start, + num_bits} {} + + __device__ __forceinline__ std::uint32_t Digit(UnsignedBits key) { + static_assert(sizeof(UnsignedBits) == sizeof(xgboost::Entry)); + auto* ptr = reinterpret_cast(&key); + std::uint32_t f; + + if (bit_start < sizeof(float) * 8) { + auto v = ptr[0]; // fvalue + std::memcpy(&f, &v, sizeof(f)); + static_assert(sizeof(f) == sizeof(v)); + return shift_fv.Digit(f); + } else { + auto v = ptr[1]; // findex + std::memcpy(&f, &v, sizeof(f)); + return shift_idx.Digit(f); + } + } +}; + +/** \brief cub::RadixSortStoreAlgorithm enumerates different algorithms to write + * partitioned elements (keys, values) stored in shared memory into global + * memory. Currently applies only to writing 4B keys in full tiles; in all other cases, + * RADIX_SORT_STORE_DIRECT is used. + */ +enum RadixSortStoreAlgorithm +{ + /** \brief Elements are statically distributed among block threads, which write them + * into the appropriate partition in global memory. This results in fewer instructions + * and more writes in flight at a given moment, but may generate more transactions. */ + RADIX_SORT_STORE_DIRECT, + /** \brief Elements are distributed among warps in a block distribution. Each warp + * goes through its elements and tries to write them while minimizing the number of + * memory transactions. This results in fewer memory transactions, but more + * instructions and less writes in flight at a given moment. */ + RADIX_SORT_STORE_ALIGNED +}; + +template < + int NOMINAL_BLOCK_THREADS_4B, + int NOMINAL_ITEMS_PER_THREAD_4B, + typename ComputeT, + /** \brief Number of private histograms to use in the ranker; + ignored if the ranking algorithm is not one of RADIX_RANK_MATCH_EARLY_COUNTS_* */ + int _RANK_NUM_PARTS, + /** \brief Ranking algorithm used in the onesweep kernel. Only algorithms that + support warp-strided key arrangement and count callbacks are supported. 
*/ + ::cub::RadixRankAlgorithm _RANK_ALGORITHM, + ::cub::BlockScanAlgorithm _SCAN_ALGORITHM, + RadixSortStoreAlgorithm _STORE_ALGORITHM, + int _RADIX_BITS, + typename ScalingType = ::cub::RegBoundScaling< + NOMINAL_BLOCK_THREADS_4B, NOMINAL_ITEMS_PER_THREAD_4B, ComputeT> > +struct AgentRadixSortOnesweepPolicy : ScalingType +{ + enum + { + RANK_NUM_PARTS = _RANK_NUM_PARTS, + RADIX_BITS = _RADIX_BITS, + }; + static const ::cub::RadixRankAlgorithm RANK_ALGORITHM = _RANK_ALGORITHM; + static const ::cub::BlockScanAlgorithm SCAN_ALGORITHM = _SCAN_ALGORITHM; + static const RadixSortStoreAlgorithm STORE_ALGORITHM = _STORE_ALGORITHM; +}; + +template +struct AgentRadixSortOnesweep { + // constants + enum { + ITEMS_PER_THREAD = AgentRadixSortOnesweepPolicy::ITEMS_PER_THREAD, + KEYS_ONLY = std::is_same::value, + BLOCK_THREADS = AgentRadixSortOnesweepPolicy::BLOCK_THREADS, + RANK_NUM_PARTS = AgentRadixSortOnesweepPolicy::RANK_NUM_PARTS, + TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD, + RADIX_BITS = AgentRadixSortOnesweepPolicy::RADIX_BITS, + RADIX_DIGITS = 1 << RADIX_BITS, + BINS_PER_THREAD = (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS, + FULL_BINS = BINS_PER_THREAD * BLOCK_THREADS == RADIX_DIGITS, + WARP_THREADS = CUB_PTX_WARP_THREADS, + BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS, + WARP_MASK = ~0, + LOOKBACK_PARTIAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 2), + LOOKBACK_GLOBAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 1), + LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK, + LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK, + }; + + using KeyT = typename std::iterator_traits::value_type; + + using UnsignedBits = typename MyTraits::UnsignedBits; + using AtomicOffsetT = PortionOffsetT; + + static const ::cub::RadixRankAlgorithm RANK_ALGORITHM = + AgentRadixSortOnesweepPolicy::RANK_ALGORITHM; + static const ::cub::BlockScanAlgorithm SCAN_ALGORITHM = + AgentRadixSortOnesweepPolicy::SCAN_ALGORITHM; + static const RadixSortStoreAlgorithm STORE_ALGORITHM = + sizeof(UnsignedBits) == sizeof(uint32_t) ? 
AgentRadixSortOnesweepPolicy::STORE_ALGORITHM + : RADIX_SORT_STORE_DIRECT; + using Twiddle = RadixSortTwiddle; + + static_assert(RANK_ALGORITHM == ::cub::RADIX_RANK_MATCH || + RANK_ALGORITHM == ::cub::RADIX_RANK_MATCH_EARLY_COUNTS_ANY || + RANK_ALGORITHM == ::cub::RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR, + "for onesweep agent, the ranking algorithm must warp-strided key arrangement"); + + using BlockRadixRankT = std::conditional_t< + RANK_ALGORITHM == ::cub::RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR, + ::cub::BlockRadixRankMatchEarlyCounts, + std::conditional_t< + RANK_ALGORITHM == ::cub::RADIX_RANK_MATCH, + ::cub::BlockRadixRankMatch, + ::cub::BlockRadixRankMatchEarlyCounts>>; + + // temporary storage + struct TempStorage_ { + union { + UnsignedBits keys_out[TILE_ITEMS]; + ValueT values_out[TILE_ITEMS]; + typename BlockRadixRankT::TempStorage rank_temp_storage; + }; + union { + OffsetT global_offsets[RADIX_DIGITS]; + PortionOffsetT block_idx; + }; + }; + + using TempStorage = ::cub::Uninitialized; + + // thread variables + TempStorage_& s; + + // kernel parameters + AtomicOffsetT* d_lookback; + AtomicOffsetT* d_ctrs; + OffsetT* d_bins_out; + const OffsetT* d_bins_in; + // const UnsignedBits + KeyIteratorT d_keys_in; + xgboost::common::Span d_values_out; + const ValueT* d_values_in; + // common::Span d_idx_out; + PortionOffsetT num_items; + DigitExtractor digit_extractor; + + // other thread variables + int warp; + int lane; + PortionOffsetT block_idx; + bool full_block; + + // helper methods + __device__ __forceinline__ int Digit(UnsignedBits key) + { + return digit_extractor.Digit(key); + } + + __device__ __forceinline__ int ThreadBin(int u) + { + return threadIdx.x * BINS_PER_THREAD + u; + } + + __device__ __forceinline__ void LookbackPartial(int (&bins)[BINS_PER_THREAD]) + { + #pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) + { + int bin = ThreadBin(u); + if (FULL_BINS || bin < RADIX_DIGITS) + { + // write the local sum into the bin + AtomicOffsetT& loc = d_lookback[block_idx * RADIX_DIGITS + bin]; + PortionOffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK; + ::cub::ThreadStore<::cub::STORE_VOLATILE>(&loc, value); + } + } + } + + struct CountsCallback { + using AgentT = + AgentRadixSortOnesweep; + AgentT& agent; + int (&bins)[BINS_PER_THREAD]; + UnsignedBits (&keys)[ITEMS_PER_THREAD]; + static const bool EMPTY = false; + + __device__ __forceinline__ CountsCallback(AgentT& agent, int (&bins)[BINS_PER_THREAD], + UnsignedBits (&keys)[ITEMS_PER_THREAD]) + : agent(agent), bins(bins), keys(keys) {} + __device__ __forceinline__ void operator()(int (&other_bins)[BINS_PER_THREAD]) { +#pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) { + bins[u] = other_bins[u]; + } + agent.LookbackPartial(bins); + + // agent.TryShortCircuit(keys, bins); + } + }; + + __device__ __forceinline__ void LookbackGlobal(int (&bins)[BINS_PER_THREAD]) + { + #pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) + { + int bin = ThreadBin(u); + if (FULL_BINS || bin < RADIX_DIGITS) + { + PortionOffsetT inc_sum = bins[u]; + int want_mask = ~0; + // backtrack as long as necessary + for (PortionOffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx) + { + // wait for some value to appear + PortionOffsetT value_j = 0; + AtomicOffsetT& loc_j = d_lookback[block_jdx * RADIX_DIGITS + bin]; + do { + __threadfence_block(); // prevent hoisting loads from loop + value_j = ::cub::ThreadLoad<::cub::LOAD_VOLATILE>(&loc_j); + } while (value_j == 0); + + inc_sum += value_j & LOOKBACK_VALUE_MASK; + want_mask = 
::cub::WARP_BALLOT((value_j & LOOKBACK_GLOBAL_MASK) == 0, want_mask); + if (value_j & LOOKBACK_GLOBAL_MASK) break; + } + AtomicOffsetT& loc_i = d_lookback[block_idx * RADIX_DIGITS + bin]; + PortionOffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK; + ::cub::ThreadStore<::cub::STORE_VOLATILE>(&loc_i, value_i); + s.global_offsets[bin] += inc_sum - bins[u]; + } + } + } + + __device__ __forceinline__ void LoadKeys(OffsetT tile_offset, + UnsignedBits (&keys)[ITEMS_PER_THREAD]) { + auto it = MakeTransformIterator(d_keys_in, [] __device__(auto const& v) { + static_assert(sizeof(std::remove_reference_t) == sizeof(UnsignedBits)); + UnsignedBits dst; + std::memcpy(&dst, &v, sizeof(v)); + return dst; + }); + + if (full_block) { + ::cub::LoadDirectWarpStriped(threadIdx.x, it + tile_offset, keys); + } else { + ::cub::LoadDirectWarpStriped(threadIdx.x, it + tile_offset, keys, + num_items - tile_offset, Twiddle::DefaultKey()); + } + +#pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) { + keys[u] = Twiddle::In(keys[u]); + } + } + + __device__ __forceinline__ void LoadValues(OffsetT tile_offset, + ValueT (&values)[ITEMS_PER_THREAD]) { + if (full_block) { + ::cub::LoadDirectWarpStriped(threadIdx.x, d_values_in + tile_offset, values); + } else { + int tile_items = num_items - tile_offset; + ::cub::LoadDirectWarpStriped(threadIdx.x, d_values_in + tile_offset, values, + tile_items); + } + } + + __device__ __forceinline__ + void ScatterKeysShared(UnsignedBits (&keys)[ITEMS_PER_THREAD], int (&ranks)[ITEMS_PER_THREAD]) + { + // write to shared memory + #pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) + { + s.keys_out[ranks[u]] = keys[u]; + } + } + + __device__ __forceinline__ + void ScatterValuesShared(ValueT (&values)[ITEMS_PER_THREAD], int (&ranks)[ITEMS_PER_THREAD]) + { + // write to shared memory + #pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) + { + s.values_out[ranks[u]] = values[u]; + } + } + + __device__ __forceinline__ void LoadBinsToOffsetsGlobal(int (&offsets)[BINS_PER_THREAD]) + { + // global offset - global part + #pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) + { + int bin = ThreadBin(u); + if (FULL_BINS || bin < RADIX_DIGITS) + { + s.global_offsets[bin] = d_bins_in[bin] - offsets[u]; + } + } + } + + __device__ __forceinline__ void UpdateBinsGlobal(int (&bins)[BINS_PER_THREAD], + int (&offsets)[BINS_PER_THREAD]) + { + bool last_block = (block_idx + 1) * TILE_ITEMS >= num_items; + if (d_bins_out != NULL && last_block) + { + #pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) + { + int bin = ThreadBin(u); + if (FULL_BINS || bin < RADIX_DIGITS) + { + d_bins_out[bin] = s.global_offsets[bin] + offsets[u] + bins[u]; + } + } + } + } + + template + __device__ __forceinline__ void ScatterValuesGlobalDirect(int (&digits)[ITEMS_PER_THREAD]) + { + int tile_items = FULL_TILE ? 
TILE_ITEMS : num_items - block_idx * TILE_ITEMS; +#pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) + { + int idx = threadIdx.x + u * BLOCK_THREADS; + ValueT value = s.values_out[idx]; + OffsetT global_idx = idx + s.global_offsets[digits[u]]; + if (FULL_TILE || idx < tile_items) { + d_values_out[global_idx] = value; + } + ::cub::WARP_SYNC(WARP_MASK); + } + } + + __device__ __forceinline__ void ScatterValuesGlobal(int (&digits)[ITEMS_PER_THREAD]) { + // write block data to global memory + if (full_block) { + ScatterValuesGlobalDirect(digits); + } else { + ScatterValuesGlobalDirect(digits); + } + } + + __device__ __forceinline__ void ComputeKeyDigits(int (&digits)[ITEMS_PER_THREAD]) + { + #pragma unroll + for (int u = 0; u < ITEMS_PER_THREAD; ++u) + { + int idx = threadIdx.x + u * BLOCK_THREADS; + digits[u] = Digit(s.keys_out[idx]); + } + } + + __device__ __forceinline__ void GatherScatterValues( + int (&ranks)[ITEMS_PER_THREAD], ::cub::Int2Type keys_only) + { + // compute digits corresponding to the keys + int digits[ITEMS_PER_THREAD]; + ComputeKeyDigits(digits); + + // load values + ValueT values[ITEMS_PER_THREAD]; + LoadValues(block_idx * TILE_ITEMS, values); + + // scatter values + cub::CTA_SYNC(); + ScatterValuesShared(values, ranks); + + cub::CTA_SYNC(); + ScatterValuesGlobal(digits); + } + + __device__ __forceinline__ void Process() + { + // load keys + // if warp1 < warp2, all elements of warp1 occur before those of warp2 + // in the source array + UnsignedBits keys[ITEMS_PER_THREAD]; + LoadKeys(block_idx * TILE_ITEMS, keys); + + // rank keys + int ranks[ITEMS_PER_THREAD]; + int exclusive_digit_prefix[BINS_PER_THREAD]; + int bins[BINS_PER_THREAD]; + BlockRadixRankT(s.rank_temp_storage).RankKeys( + keys, ranks, digit_extractor, exclusive_digit_prefix, + CountsCallback(*this, bins, keys)); + + // scatter keys in shared memory + ::cub::CTA_SYNC(); + ScatterKeysShared(keys, ranks); + + // compute global offsets + LoadBinsToOffsetsGlobal(exclusive_digit_prefix); + LookbackGlobal(bins); + UpdateBinsGlobal(bins, exclusive_digit_prefix); + + // scatter keys in global memory + ::cub::CTA_SYNC(); + + // scatter values if necessary + GatherScatterValues(ranks, ::cub::Int2Type()); + } + + __device__ __forceinline__ // + AgentRadixSortOnesweep(TempStorage &temp_storage, + AtomicOffsetT *d_lookback, + AtomicOffsetT *d_ctrs, + OffsetT *d_bins_out, + const OffsetT *d_bins_in, + KeyIteratorT d_keys_in, + xgboost::common::Span d_values_out, + const ValueT *d_values_in, + PortionOffsetT num_items, + DigitExtractor de) + : s(temp_storage.Alias()) + , d_lookback(d_lookback) + , d_ctrs(d_ctrs) + , d_bins_out(d_bins_out) + , d_bins_in(d_bins_in) + , d_keys_in{d_keys_in} + , d_values_out(d_values_out) + , d_values_in(d_values_in) + // , d_idx_out{d_idx_out} + , num_items(num_items) + , digit_extractor{de} + , warp(threadIdx.x / WARP_THREADS) + , lane(::cub::LaneId()) + { + // initialization + if (threadIdx.x == 0) + { + s.block_idx = atomicAdd(d_ctrs, 1); + } + ::cub::CTA_SYNC(); + block_idx = s.block_idx; + full_block = (block_idx + 1) * TILE_ITEMS <= num_items; + } +}; +} // namespace cub_argsort +// NOLINTEND diff --git a/src/cub_sort/block/radix_rank_sort_operations.cuh b/src/cub_sort/block/radix_rank_sort_operations.cuh new file mode 100644 index 000000000000..3e35048d8a3d --- /dev/null +++ b/src/cub_sort/block/radix_rank_sort_operations.cuh @@ -0,0 +1,31 @@ +#pragma once + +#include +#include +#include + +#include "../util_type.cuh" + +// NOLINTBEGIN +namespace cub_argsort { +/** \brief 
Twiddling keys for radix sort. */ +template +struct RadixSortTwiddle { + using TraitsT = MyTraits; + using UnsignedBits = typename TraitsT::UnsignedBits; + static __host__ __device__ __forceinline__ UnsignedBits In(UnsignedBits key) { + key = TraitsT::TwiddleIn(key); + if (IS_DESCENDING) key = ~key; + return key; + } + static __host__ __device__ __forceinline__ UnsignedBits Out(UnsignedBits key) { + if (IS_DESCENDING) key = ~key; + key = TraitsT::TwiddleOut(key); + return key; + } + static __host__ __device__ __forceinline__ UnsignedBits DefaultKey() { + return Out(~UnsignedBits(0)); + } +}; +} // namespace cub_argsort +// NOLINTEND diff --git a/src/cub_sort/device/device_radix_sort.cuh b/src/cub_sort/device/device_radix_sort.cuh new file mode 100644 index 000000000000..10af9107a42c --- /dev/null +++ b/src/cub_sort/device/device_radix_sort.cuh @@ -0,0 +1,88 @@ + +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2022, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE + * ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * @file cub::DeviceRadixSort provides device-wide, parallel operations for + * computing a radix sort across a sequence of data items residing within + * device-accessible memory. + */ + +#pragma once + +#include +#include + +#include "dispatch/dispatch_radix_sort.cuh" +#include "xgboost/span.h" + +// NOLINTBEGIN +namespace cub_argsort { +namespace detail { + +/** + * ChooseOffsetT checks NumItemsT, the type of the num_items parameter, and + * selects the offset type based on it. + */ +template +struct ChooseOffsetT { + // NumItemsT must be an integral type (but not bool). + static_assert(std::is_integral::value && + !std::is_same::type, bool>::value, + "NumItemsT must be an integral type, but not bool"); + + // Unsigned integer type for global offsets. 
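ChooseOffsetT only widens the caller's num_items type into a 32- or 64-bit unsigned offset. A stand-alone sketch of the selection, assuming the usual CUB threshold of four bytes (illustrative):

#include <cstdint>
#include <type_traits>

template <typename NumItemsT>
struct ChooseOffsetSketch {
  static_assert(std::is_integral<NumItemsT>::value &&
                    !std::is_same<typename std::remove_cv<NumItemsT>::type, bool>::value,
                "NumItemsT must be an integral type, but not bool");
  // Counts that fit in 4 bytes use 32-bit offsets; anything wider gets 64-bit offsets.
  using Type = typename std::conditional<sizeof(NumItemsT) <= 4, std::uint32_t,
                                         unsigned long long>::type;
};

static_assert(std::is_same<ChooseOffsetSketch<int>::Type, std::uint32_t>::value, "");
static_assert(std::is_same<ChooseOffsetSketch<std::uint64_t>::Type, unsigned long long>::value, "");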
+ using Type = + typename std::conditional::type; +}; +} // namespace detail + +template +struct DeviceRadixSort { + template + CUB_RUNTIME_FUNCTION static cudaError_t Argsort(void *d_temp_storage, + std::size_t &temp_storage_bytes, + KeyIteratorT d_keys_in, SortedIdxT *d_idx_out, + NumItemsT num_items, + cudaStream_t stream = nullptr) { + // Unsigned integer type for global offsets. + using OffsetT = typename detail::ChooseOffsetT::Type; + using KeyT = typename std::iterator_traits::value_type; + + int constexpr kBeginBit = 0; + int constexpr kEndBit = sizeof(KeyT) * 8; + + return DispatchRadixArgSort::Dispatch(d_temp_storage, temp_storage_bytes, + d_keys_in, + static_cast(num_items), + kBeginBit, kEndBit, stream, d_idx_out); + } +}; +} // namespace cub_argsort +// NOLINTEND diff --git a/src/cub_sort/device/dispatch/dispatch_radix_sort.cuh b/src/cub_sort/device/dispatch/dispatch_radix_sort.cuh new file mode 100644 index 000000000000..e8c214f0c207 --- /dev/null +++ b/src/cub_sort/device/dispatch/dispatch_radix_sort.cuh @@ -0,0 +1,430 @@ +/****************************************************************************** + * Copyright (c) 2011, Duane Merrill. All rights reserved. + * Copyright (c) 2011-2018, NVIDIA CORPORATION. All rights reserved. + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * * Redistributions of source code must retain the above copyright + * notice, this list of conditions and the following disclaimer. + * * Redistributions in binary form must reproduce the above copyright + * notice, this list of conditions and the following disclaimer in the + * documentation and/or other materials provided with the distribution. + * * Neither the name of the NVIDIA CORPORATION nor the + * names of its contributors may be used to endorse or promote products + * derived from this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY + * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + ******************************************************************************/ + +/** + * \file + * cub::DeviceRadixSort provides device-wide, parallel operations for computing a radix sort across + * a sequence of data items residing within device-accessible memory. 
+ */ + +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../../agent/agent_radix_sort_histogram.cuh" +#include "../../agent/agent_radix_sort_onesweep.cuh" +#include "../../util_type.cuh" +#include "xgboost/span.h" + +// NOLINTBEGIN +namespace cub_argsort { +namespace detail { +CUB_RUNTIME_FUNCTION inline cudaError_t HasUVA(bool &has_uva) { + has_uva = false; + cudaError_t error = cudaSuccess; + int device = -1; + if (CubDebug(error = cudaGetDevice(&device)) != cudaSuccess) return error; + int uva = 0; + if (CubDebug(error = cudaDeviceGetAttribute(&uva, cudaDevAttrUnifiedAddressing, device)) != + cudaSuccess) { + return error; + } + has_uva = uva == 1; + return error; +} + +CUB_RUNTIME_FUNCTION inline cudaError_t DebugSyncStream(cudaStream_t) { return cudaSuccess; } +} // namespace detail + +/****************************************************************************** + * Kernel entry points + *****************************************************************************/ +/** + * Kernel for computing multiple histograms + */ + +/** + * Histogram kernel + */ +template +__global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::HistogramPolicy::BLOCK_THREADS) + DeviceRadixArgSortHistogramKernel(OffsetT *d_bins_out, const KeyIteratorT d_keys_in, + OffsetT num_items, int start_bit, int end_bit) { + typedef typename ChainedPolicyT::ActivePolicy::HistogramPolicy HistogramPolicyT; + typedef AgentRadixSortHistogram + AgentT; + __shared__ typename AgentT::TempStorage temp_storage; + AgentT agent(temp_storage, d_bins_out, d_keys_in, num_items, start_bit, end_bit); + agent.Process(); +} + +template +__global__ void __launch_bounds__(ChainedPolicyT::ActivePolicy::OnesweepPolicy::BLOCK_THREADS) + DeviceRadixSortOnesweepKernel(AtomicOffsetT *d_lookback, AtomicOffsetT *d_ctrs, + OffsetT *d_bins_out, const OffsetT *d_bins_in, + KeyIteratorT d_keys_in, + xgboost::common::Span d_idx_out, + xgboost::common::Span d_idx_in, + PortionOffsetT num_items, int current_bit, int num_bits) { + typedef typename ChainedPolicyT::ActivePolicy::OnesweepPolicy OnesweepPolicyT; + typedef AgentRadixSortOnesweep + AgentT; + __shared__ typename AgentT::TempStorage s; + + DigitExtractor de(current_bit, num_bits); // fixme + AgentT agent(s, d_lookback, d_ctrs, d_bins_out, d_bins_in, d_keys_in, d_idx_out, d_idx_in.data(), + num_items, de); + agent.Process(); +} + +/** + * Exclusive sum kernel + */ +template +__global__ void DeviceRadixSortExclusiveSumKernel(OffsetT *d_bins) { + typedef typename ChainedPolicyT::ActivePolicy::ExclusiveSumPolicy ExclusiveSumPolicyT; + const int RADIX_BITS = ExclusiveSumPolicyT::RADIX_BITS; + const int RADIX_DIGITS = 1 << RADIX_BITS; + const int BLOCK_THREADS = ExclusiveSumPolicyT::BLOCK_THREADS; + const int BINS_PER_THREAD = (RADIX_DIGITS + BLOCK_THREADS - 1) / BLOCK_THREADS; + typedef ::cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + // load the bins + OffsetT bins[BINS_PER_THREAD]; + int bin_start = blockIdx.x * RADIX_DIGITS; +#pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) { + int bin = threadIdx.x * BINS_PER_THREAD + u; + if (bin >= RADIX_DIGITS) break; + bins[u] = d_bins[bin_start + bin]; + } + + // compute offsets + BlockScan(temp_storage).ExclusiveSum(bins, bins); + +// store the offsets +#pragma unroll + for (int u = 0; u < BINS_PER_THREAD; ++u) { + int bin = threadIdx.x * BINS_PER_THREAD + u; + if (bin >= 
RADIX_DIGITS) break; + d_bins[bin_start + bin] = bins[u]; + } +} + +template +struct SortedKeyOp { + using KeyT = std::remove_reference_t::value_type>; + + KeyIt d_keys; + std::uint32_t *s_idx_in; + + __device__ KeyT operator()(std::size_t i) const { + auto idx = s_idx_in[i]; + return d_keys[idx]; + } +}; + +/** + * Utility class for dispatching the appropriately-tuned kernels for device-wide radix sort + */ +template ::value_type, ValueT, OffsetT> > +struct DispatchRadixArgSort : SelectedPolicy { + //------------------------------------------------------------------------------ + // Problem state + //------------------------------------------------------------------------------ + + void *d_temp_storage; ///< [in] Device-accessible allocation of temporary storage. When NULL, + ///< the required allocation size is written to \p temp_storage_bytes and + ///< no work is done. + size_t + &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation + using KeyT = std::remove_reference_t::value_type>; + KeyIteratorT d_keys; + OffsetT num_items; ///< [in] Number of items to sort + int begin_bit; ///< [in] The beginning (least-significant) bit index needed for key comparison + int end_bit; ///< [in] The past-the-end (most-significant) bit index needed for key comparison + cudaStream_t + stream; ///< [in] CUDA stream to launch kernels within. Default is stream0. + int ptx_version; ///< [in] PTX version + SortedIdxT *d_idx; + + //------------------------------------------------------------------------------ + // Constructor + //------------------------------------------------------------------------------ + + CUB_RUNTIME_FUNCTION __forceinline__ DispatchRadixArgSort( + void *d_temp_storage, size_t &temp_storage_bytes, KeyIteratorT d_keys, OffsetT num_items, + int begin_bit, int end_bit, cudaStream_t stream, int ptx_version, SortedIdxT *d_idx_out) + : d_temp_storage(d_temp_storage), + temp_storage_bytes(temp_storage_bytes), + d_keys(d_keys), + num_items(num_items), + begin_bit(begin_bit), + end_bit(end_bit), + stream(stream), + ptx_version(ptx_version), + d_idx{d_idx_out} {} + + //------------------------------------------------------------------------------ + // Normal problem size invocation + //------------------------------------------------------------------------------ + template + CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t InvokeOnesweep() { + typedef typename DispatchRadixArgSort::MaxPolicy MaxPolicyT; + // PortionOffsetT is used for offsets within a portion, and must be signed. 
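As with the stock CUB dispatchers, callers drive this through the usual two-phase protocol: a first call with a null temporary buffer only reports the required size (the allocation sizes computed below are packed into one buffer by cub::AliasTemporaries), and a second call performs the argsort. A usage sketch against the Argsort entry point declared in device_radix_sort.cuh, assuming DeviceRadixSort is parameterised on the digit extractor and ignoring error handling (illustrative only; include paths are assumed):

#include <cuda_runtime.h>

#include <cstddef>
#include <cstdint>

#include "device_radix_sort.cuh"  // path to the new cub_argsort header is assumed
#include "xgboost/data.h"         // for xgboost::Entry

void ArgsortEntries(xgboost::Entry const* d_keys, std::uint32_t* d_sorted_idx, std::size_t n) {
  // Assumption: DeviceRadixSort takes the digit extractor as its template parameter.
  using Sorter = cub_argsort::DeviceRadixSort<cub_argsort::EntryExtractor>;
  std::size_t n_bytes = 0;
  // Phase 1: null temporary storage, only the required size is written to n_bytes.
  Sorter::Argsort(nullptr, n_bytes, d_keys, d_sorted_idx, n);
  void* d_temp = nullptr;
  cudaMalloc(&d_temp, n_bytes);
  // Phase 2: run the one-sweep argsort. d_sorted_idx receives indices into d_keys,
  // ordered by feature index first and feature value second (see EntryTrait below).
  Sorter::Argsort(d_temp, n_bytes, d_keys, d_sorted_idx, n);
  cudaFree(d_temp);
}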
+ typedef int PortionOffsetT; + typedef PortionOffsetT AtomicOffsetT; + + // compute temporary storage size + const int RADIX_BITS = ActivePolicyT::ONESWEEP_RADIX_BITS; + const int RADIX_DIGITS = 1 << RADIX_BITS; + const int ONESWEEP_ITEMS_PER_THREAD = ActivePolicyT::OnesweepPolicy::ITEMS_PER_THREAD; + const int ONESWEEP_BLOCK_THREADS = ActivePolicyT::OnesweepPolicy::BLOCK_THREADS; + const int ONESWEEP_TILE_ITEMS = ONESWEEP_ITEMS_PER_THREAD * ONESWEEP_BLOCK_THREADS; + // portions handle inputs with >=2**30 elements, due to the way lookback works + // for testing purposes, one portion is <= 2**28 elements + const PortionOffsetT PORTION_SIZE = ((1 << 28) - 1) / ONESWEEP_TILE_ITEMS * ONESWEEP_TILE_ITEMS; + int num_passes = ::cub::DivideAndRoundUp(end_bit - begin_bit, RADIX_BITS); + OffsetT num_portions = static_cast(::cub::DivideAndRoundUp(num_items, PORTION_SIZE)); + + PortionOffsetT max_num_blocks = ::cub::DivideAndRoundUp( + static_cast(CUB_MIN(num_items, static_cast(PORTION_SIZE))), + ONESWEEP_TILE_ITEMS); + + std::size_t allocation_sizes[] = { + // bins + num_portions * num_passes * RADIX_DIGITS * sizeof(OffsetT), + // lookback + max_num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), + // counters + num_portions * num_passes * sizeof(AtomicOffsetT), + // index + num_items * sizeof(SortedIdxT), + }; + const int NUM_ALLOCATIONS = sizeof(allocation_sizes) / sizeof(allocation_sizes[0]); + void *allocations[NUM_ALLOCATIONS] = {}; + ::cub::AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, + allocation_sizes); + + // just return if no temporary storage is provided + cudaError_t error = cudaSuccess; + if (d_temp_storage == nullptr) return error; + + OffsetT *d_bins = static_cast(allocations[0]); + AtomicOffsetT *d_lookback = static_cast(allocations[1]); + AtomicOffsetT *d_ctrs = static_cast(allocations[2]); + SortedIdxT *d_idx_tmp = static_cast(allocations[3]); + + thrust::sequence(thrust::cuda::par.on(stream), d_idx, d_idx + num_items); + ::cub::DoubleBuffer d_idx_out{d_idx, d_idx_tmp}; + + do { + // initialization + if (CubDebug(error = cudaMemsetAsync( + d_ctrs, 0, num_portions * num_passes * sizeof(AtomicOffsetT), stream))) { + break; + } + + // compute num_passes histograms with RADIX_DIGITS bins each + if (CubDebug(error = cudaMemsetAsync(d_bins, 0, num_passes * RADIX_DIGITS * sizeof(OffsetT), + stream))) { + break; + } + int device = -1; + int num_sms = 0; + if (CubDebug(error = cudaGetDevice(&device))) break; + if (CubDebug(error = + cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, device))) { + break; + } + + auto s_idx_out = xgboost::common::Span(d_idx_out.Current(), num_items); + const int HISTO_BLOCK_THREADS = ActivePolicyT::HistogramPolicy::BLOCK_THREADS; + int histo_blocks_per_sm = 1; + auto histogram_kernel = + DeviceRadixArgSortHistogramKernel; + if (CubDebug(error = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &histo_blocks_per_sm, histogram_kernel, HISTO_BLOCK_THREADS, 0))) { + break; + } + error = THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + histo_blocks_per_sm * num_sms, HISTO_BLOCK_THREADS, 0, stream) + .doit(histogram_kernel, d_bins, d_keys, num_items, begin_bit, end_bit); + if (CubDebug(error)) { + break; + } + + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) { + break; + } + + // exclusive sums to determine starts + const int SCAN_BLOCK_THREADS = ActivePolicyT::ExclusiveSumPolicy::BLOCK_THREADS; + error = THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(num_passes, + SCAN_BLOCK_THREADS, 0, stream) + 
.doit(DeviceRadixSortExclusiveSumKernel, d_bins); + if (CubDebug(error)) { + break; + } + + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) { + break; + } + + auto d_keys = this->d_keys; + static_assert(RADIX_BITS == 8); + for (int current_bit = begin_bit, pass = 0; current_bit < end_bit; + current_bit += RADIX_BITS, ++pass) { + int num_bits = CUB_MIN(end_bit - current_bit, RADIX_BITS); + for (OffsetT portion = 0; portion < num_portions; ++portion) { + PortionOffsetT portion_num_items = static_cast( + CUB_MIN(num_items - portion * PORTION_SIZE, static_cast(PORTION_SIZE))); + PortionOffsetT num_blocks = + ::cub::DivideAndRoundUp(portion_num_items, ONESWEEP_TILE_ITEMS); + if (CubDebug(error = cudaMemsetAsync(d_lookback, 0, + num_blocks * RADIX_DIGITS * sizeof(AtomicOffsetT), + stream))) { + break; + } + + auto s_idx_in = d_idx_out.Current(); + // fixme: this doesn't work well with idx when iterating through portion. + auto key_in = MakeTransformIterator(thrust::make_counting_iterator(0ul), + SortedKeyOp{d_keys, s_idx_in}); + + auto onesweep_kernel = + DeviceRadixSortOnesweepKernel; + error = + THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron( + num_blocks, ONESWEEP_BLOCK_THREADS, 0, stream) + .doit(onesweep_kernel, d_lookback, d_ctrs + portion * num_passes + pass, + portion < num_portions - 1 + ? d_bins + ((portion + 1) * num_passes + pass) * RADIX_DIGITS + : nullptr, + d_bins + (portion * num_passes + pass) * RADIX_DIGITS, + key_in + portion * PORTION_SIZE, + xgboost::common::Span{d_idx_out.Alternate(), num_items}, + xgboost::common::Span{d_idx_out.Current(), num_items}.subspan( + portion * PORTION_SIZE), + portion_num_items, current_bit, num_bits); + if (CubDebug(error)) { + break; + } + + error = detail::DebugSyncStream(stream); + if (CubDebug(error)) { + break; + } + } + + if (error != cudaSuccess) { + break; + } + + d_idx_out.selector ^= 1; + } + } while (false); + + return error; + } + + //------------------------------------------------------------------------------ + // Chained policy invocation + //------------------------------------------------------------------------------ + /// Invocation + template + CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke() { + typedef typename DispatchRadixArgSort::MaxPolicy MaxPolicyT; + // Return if empty problem, or if no bits to sort and double-buffering is used + if (num_items == 0 || (begin_bit == end_bit)) { + if (d_temp_storage == nullptr) { + temp_storage_bytes = 1; + } + return cudaSuccess; + } + return InvokeOnesweep(); + } + + CUB_RUNTIME_FUNCTION __forceinline__ static cudaError_t Dispatch( + void *d_temp_storage, size_t &temp_storage_bytes, KeyIteratorT d_keys, OffsetT num_items, + int begin_bit, int end_bit, cudaStream_t stream, SortedIdxT *d_idx_out) { + typedef typename DispatchRadixArgSort::MaxPolicy MaxPolicyT; + + cudaError_t error; + do { + // Get PTX version + int ptx_version = 0; + if (CubDebug(error = ::cub::PtxVersion(ptx_version))) break; + + // Create dispatch functor + DispatchRadixArgSort dispatch{d_temp_storage, temp_storage_bytes, d_keys, + num_items, begin_bit, end_bit, + stream, ptx_version, d_idx_out}; + + // Dispatch to chained policy + if (CubDebug(error = MaxPolicyT::Invoke(ptx_version, dispatch))) { + break; + } + } while (false); + + return error; + } +}; +} // namespace cub_argsort +// NOLINTEND diff --git a/src/cub_sort/util_type.cuh b/src/cub_sort/util_type.cuh new file mode 100644 index 000000000000..8c780f37f272 --- /dev/null +++ b/src/cub_sort/util_type.cuh @@ -0,0 +1,58 @@ 
+#pragma once +#include // for swap +#include + +#include + +// NOLINTBEGIN +namespace cub_argsort { +struct EntryTrait : public ::cub::BaseTraits<::cub::NOT_A_NUMBER, false, false, unsigned long long, + xgboost::Entry> { + using Entry = xgboost::Entry; + using UnsignedBits = unsigned long long; + + static constexpr ::cub::Category CATEGORY = ::cub::NOT_A_NUMBER; + // The calculation for bst_feature_t is not necessary as it's unsigned integer, only + // performed here for clarify. + static constexpr UnsignedBits LOWEST_KEY = + (UnsignedBits{::cub::NumericTraits::LOWEST_KEY} + << sizeof(decltype(Entry::index)) * 8) ^ + UnsignedBits { ::cub::NumericTraits::LOWEST_KEY }; + + static constexpr UnsignedBits MAX_KEY = + (UnsignedBits{::cub::NumericTraits::MAX_KEY} + << sizeof(decltype(Entry::index)) * 8) ^ + UnsignedBits { ::cub::NumericTraits::MAX_KEY }; + + static __device__ __forceinline__ UnsignedBits TwiddleIn(UnsignedBits key) { + using F32T = ::cub::NumericTraits; + + auto ptr = reinterpret_cast(&key); + // Make index the most significant element + // after swap, 0^th is favlue, 1^th is index + thrust::swap(ptr[0], ptr[1]); + + auto& fv_key = ptr[0]; + fv_key = F32T::TwiddleIn(fv_key); + + return key; + }; + + static __device__ __forceinline__ UnsignedBits TwiddleOut(UnsignedBits key) { + using F32T = ::cub::NumericTraits; + + auto ptr = reinterpret_cast(&key); + // after swap, 0^th is index, 1^th is fvalue + thrust::swap(ptr[0], ptr[1]); + auto& fv_key = ptr[1]; + fv_key = F32T::TwiddleOut(fv_key); + + return key; + } +}; + +template +struct MyTraits : public std::conditional_t, EntryTrait, + ::cub::NumericTraits> {}; +} // namespace cub_argsort +// NOLINTEND diff --git a/src/data/array_interface.h b/src/data/array_interface.h index fee22203c111..1b18f140aa67 100644 --- a/src/data/array_interface.h +++ b/src/data/array_interface.h @@ -26,6 +26,10 @@ #include "xgboost/logging.h" #include "xgboost/span.h" +#if defined(XGBOOST_USE_CUDA) +#include "cuda_fp16.h" +#endif + namespace xgboost { // Common errors in parsing columnar format. 
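The EntryTrait added in util_type.cuh above exists so that one 64-bit radix pass orders entries by feature index first and feature value second: TwiddleIn moves the index into the most-significant half and applies the usual sign-flip trick to the float half so its bit pattern compares like the value. A host-side sketch of that key construction (illustrative; it imitates rather than reuses the CUB float twiddle):

#include <cstdint>
#include <cstring>  // for memcpy

// Map a float's bits to an unsigned value that preserves numeric order:
// flip every bit for negatives, set the sign bit for non-negatives.
inline std::uint32_t TwiddleFloat(float v) {
  std::uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  return (bits & 0x80000000u) ? ~bits : (bits | 0x80000000u);
}

// Pack (feature index, feature value) so that plain unsigned comparison of the
// key matches ordering by index first, then by value.
inline std::uint64_t MakeEntryKey(std::uint32_t findex, float fvalue) {
  return (static_cast<std::uint64_t>(findex) << 32) | TwiddleFloat(fvalue);
}

// For example: MakeEntryKey(1, -3.f) < MakeEntryKey(1, 2.f) < MakeEntryKey(2, -10.f).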
struct ArrayInterfaceErrors { @@ -304,12 +308,12 @@ class ArrayInterfaceHandler { template struct ToDType; // float -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#if defined(XGBOOST_USE_CUDA) template <> struct ToDType<__half> { static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF2; }; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#endif // defined(XGBOOST_USE_CUDA) template <> struct ToDType { static constexpr ArrayInterfaceHandler::Type kType = ArrayInterfaceHandler::kF4; @@ -459,11 +463,11 @@ class ArrayInterface { CHECK(sizeof(long double) == 16) << error::NoF128(); type = T::kF16; } else if (typestr[1] == 'f' && typestr[2] == '2') { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#if defined(XGBOOST_USE_CUDA) type = T::kF2; #else LOG(FATAL) << "Half type is not supported."; -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#endif // defined(XGBOOST_USE_CUDA) } else if (typestr[1] == 'f' && typestr[2] == '4') { type = T::kF4; } else if (typestr[1] == 'f' && typestr[2] == '8') { @@ -490,20 +494,17 @@ class ArrayInterface { } } - XGBOOST_DEVICE size_t Shape(size_t i) const { return shape[i]; } - XGBOOST_DEVICE size_t Stride(size_t i) const { return strides[i]; } + [[nodiscard]] XGBOOST_DEVICE std::size_t Shape(size_t i) const { return shape[i]; } + [[nodiscard]] XGBOOST_DEVICE std::size_t Stride(size_t i) const { return strides[i]; } template XGBOOST_HOST_DEV_INLINE decltype(auto) DispatchCall(Fn func) const { using T = ArrayInterfaceHandler::Type; switch (type) { case T::kF2: { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#if defined(XGBOOST_USE_CUDA) return func(reinterpret_cast<__half const *>(data)); -#else - SPAN_CHECK(false); - return func(reinterpret_cast(data)); -#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#endif // defined(XGBOOST_USE_CUDA) } case T::kF4: return func(reinterpret_cast(data)); @@ -540,23 +541,23 @@ class ArrayInterface { return func(reinterpret_cast(data)); } - XGBOOST_DEVICE std::size_t ElementSize() const { + [[nodiscard]] XGBOOST_DEVICE std::size_t ElementSize() const { return this->DispatchCall([](auto *typed_data_ptr) { return sizeof(std::remove_pointer_t); }); } - XGBOOST_DEVICE std::size_t ElementAlignment() const { + [[nodiscard]] XGBOOST_DEVICE std::size_t ElementAlignment() const { return this->DispatchCall([](auto *typed_data_ptr) { return std::alignment_of>::value; }); } template - XGBOOST_DEVICE T operator()(Index &&...index) const { + XGBOOST_HOST_DEV_INLINE T operator()(Index &&...index) const { static_assert(sizeof...(index) <= D, "Invalid index."); return this->DispatchCall([=](auto const *p_values) -> T { std::size_t offset = linalg::detail::Offset<0ul>(strides, 0ul, index...); -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#if defined(XGBOOST_USE_CUDA) // No operator defined for half -> size_t using Type = std::conditional_t< std::is_same<__half, @@ -566,7 +567,7 @@ class ArrayInterface { return static_cast(static_cast(p_values[offset])); #else return static_cast(p_values[offset]); -#endif +#endif // defined(XGBOOST_USE_CUDA) }); } @@ -603,7 +604,7 @@ void DispatchDType(ArrayInterface const array, std::int32_t device, Fn fn) { }; switch (array.type) { case ArrayInterfaceHandler::kF2: { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 600 +#if defined(XGBOOST_USE_CUDA) dispatch(__half{}); #endif break; diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index 494fb7d1c438..136fbb743ec8 100644 --- a/src/data/device_adapter.cuh +++ 
b/src/data/device_adapter.cuh @@ -29,7 +29,7 @@ class CudfAdapterBatch : public detail::NoMetaInfo { : columns_(columns), num_rows_(num_rows) {} size_t Size() const { return num_rows_ * columns_.size(); } - __device__ COOTuple GetElement(size_t idx) const { + __device__ __forceinline__ COOTuple GetElement(size_t idx) const { size_t column_idx = idx % columns_.size(); size_t row_idx = idx / columns_.size(); auto const& column = columns_[column_idx]; @@ -221,13 +221,24 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, * \brief Check there's no inf in data. */ template -bool HasInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) { +bool NoInfInData(AdapterBatchT const& batch, IsValidFunctor is_valid) { auto counting = thrust::make_counting_iterator(0llu); - auto value_iter = dh::MakeTransformIterator( - counting, [=] XGBOOST_DEVICE(std::size_t idx) { return batch.GetElement(idx).value; }); - auto valid = - thrust::none_of(value_iter, value_iter + batch.Size(), - [is_valid] XGBOOST_DEVICE(float v) { return is_valid(v) && std::isinf(v); }); + auto value_iter = dh::MakeTransformIterator(counting, [=] XGBOOST_DEVICE(std::size_t idx) { + auto v = batch.GetElement(idx).value; + if (!is_valid(v)) { + // discard the invalid elements. + return true; + } + // check that there's no inf in data. + return !std::isinf(v); + }); + dh::XGBCachingDeviceAllocator alloc; + // The default implementation in thrust optimizes any_of/none_of/all_of by using small + // intervals to early stop. But we expect all data to be valid here, using small + // intervals only decreases performance due to excessive kernel launch and stream + // synchronization. + auto valid = dh::Reduce(thrust::cuda::par(alloc), value_iter, value_iter + batch.Size(), true, + thrust::logical_and<>{}); return valid; } }; // namespace data diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 4409a7ebbea7..5fa6af0f5d57 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -200,7 +200,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, common::Span( @@ -333,7 +333,7 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span // is dense, ifeature is the actual feature index. 
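The NoInfInData rewrite above folds the validity check into the transform, so a single full reduction replaces none_of's interval-based early-exit scans, which only added launch and synchronization overhead for data that is expected to be valid. A host-side analogue of the predicate, with a simplified stand-in for IsValidFunctor (illustrative only):

#include <cmath>       // for isinf, isnan
#include <functional>  // for logical_and
#include <numeric>     // for transform_reduce
#include <vector>

// Simplified stand-in for IsValidFunctor: a value is valid when it is neither
// NaN nor equal to the user supplied `missing` value.
struct IsValid {
  float missing;
  bool operator()(float v) const { return !std::isnan(v) && v != missing; }
};

bool NoInfInDataHost(std::vector<float> const& values, IsValid is_valid) {
  return std::transform_reduce(values.begin(), values.end(), true, std::logical_and<>{},
                               [&](float v) {
                                 if (!is_valid(v)) { return true; }  // invalid entries are discarded
                                 return !std::isinf(v);              // valid entries must be finite
                               });
}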
offset = d_csc_indptr[ifeature]; } - common::cuda::DispatchBinType(bin_type, [&](auto t) { + common::cuda_impl::DispatchBinType(bin_type, [&](auto t) { using T = decltype(t); auto ptr = reinterpret_cast(d_data.data()); auto bin_idx = ptr[r_begin + ifeature] + offset; diff --git a/src/data/iterative_dmatrix.cc b/src/data/iterative_dmatrix.cc index 8eb1c203432f..627606aa3741 100644 --- a/src/data/iterative_dmatrix.cc +++ b/src/data/iterative_dmatrix.cc @@ -366,8 +366,8 @@ inline void IterativeDMatrix::InitFromCUDA(Context const*, BatchParam const&, Da common::AssertGPUSupport(); } -inline BatchSet IterativeDMatrix::GetEllpackBatches(Context const* ctx, - BatchParam const& param) { +inline BatchSet IterativeDMatrix::GetEllpackBatches(Context const*, + BatchParam const&) { common::AssertGPUSupport(); auto begin_iter = BatchIterator(new SimpleBatchIteratorImpl(ellpack_)); return BatchSet(BatchIterator(begin_iter)); diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh index 63310a92984f..e2c0ae347686 100644 --- a/src/data/simple_dmatrix.cuh +++ b/src/data/simple_dmatrix.cuh @@ -64,7 +64,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, template size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) { - bool valid = HasInfInData(batch, IsValidFunctor{missing}); + bool valid = NoInfInData(batch, IsValidFunctor{missing}); CHECK(valid) << error::InfInData(); page->offset.SetDevice(device); diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index e907a9f72ee6..262c8032941d 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -21,8 +21,7 @@ #include "../helpers.h" #include "test_hist_util.h" -namespace xgboost { -namespace common { +namespace xgboost::common { template HistogramCuts GetHostCuts(Context const* ctx, AdapterT* adapter, int num_bins, float missing) { @@ -48,22 +47,6 @@ TEST(HistUtil, DeviceSketch) { EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); } -TEST(HistUtil, SketchBatchNumElements) { -#if defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 - LOG(WARNING) << "Test not runnable with RMM enabled."; - return; -#endif // defined(XGBOOST_USE_RMM) && XGBOOST_USE_RMM == 1 - size_t constexpr kCols = 10000; - int device; - dh::safe_cuda(cudaGetDevice(&device)); - auto avail = static_cast(dh::AvailableMemory(device) * 0.8); - auto per_elem = detail::BytesPerElement(false); - auto avail_elem = avail / per_elem; - size_t rows = avail_elem / kCols * 10; - auto batch = detail::SketchBatchNumElements(0, rows, kCols, rows * kCols, device, 256, false); - ASSERT_EQ(batch, avail_elem); -} - TEST(HistUtil, DeviceSketchMemory) { int num_columns = 100; int num_rows = 1000; @@ -71,15 +54,26 @@ TEST(HistUtil, DeviceSketchMemory) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); - dh::GlobalMemoryLogger().Clear(); - ConsoleLogger::Configure({{"verbosity", "3"}}); - auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); + SketchContainer sketch_container(dmat->Info().feature_types, num_bins, dmat->Info().num_col_, + dmat->Info().num_row_, 0); + auto verbosity = GlobalConfigThreadLocalStore::Get()->verbosity; - size_t bytes_required = detail::RequiredMemory( - num_rows, num_columns, num_rows * num_columns, num_bins, false); - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); - 
ConsoleLogger::Configure({{"verbosity", "0"}}); + for (auto const& page : dmat->GetBatches()) { + dh::GlobalMemoryLogger().Clear(); + ConsoleLogger::Configure({{"verbosity", "3"}}); + + size_t num_cuts_per_feature = + detail::RequiredSampleCutsPerColumn(num_bins, dmat->Info().num_row_); + ProcessBatch(0, dmat->Info(), page, 0, page.data.Size(), &sketch_container, + num_cuts_per_feature, dmat->Info().num_col_); + std::size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, false, page.data.DeviceCanRead()); + + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); + } + + ConsoleLogger::Configure({{"verbosity", std::to_string(verbosity)}}); } TEST(HistUtil, DeviceSketchWeightsMemory) { @@ -89,16 +83,27 @@ TEST(HistUtil, DeviceSketchWeightsMemory) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); dmat->Info().weights_.HostVector() = GenerateRandomWeights(num_rows); - - dh::GlobalMemoryLogger().Clear(); - ConsoleLogger::Configure({{"verbosity", "3"}}); - auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); - ConsoleLogger::Configure({{"verbosity", "0"}}); - - size_t bytes_required = detail::RequiredMemory( - num_rows, num_columns, num_rows * num_columns, num_bins, true); - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); + dmat->Info().weights_.SetDevice(0); + SketchContainer sketch_container(dmat->Info().feature_types, num_bins, dmat->Info().num_col_, + dmat->Info().num_row_, 0); + + for (auto const& page : dmat->GetBatches()) { + auto verbosity = GlobalConfigThreadLocalStore::Get()->verbosity; + dh::GlobalMemoryLogger().Clear(); + ConsoleLogger::Configure({{"verbosity", "3"}}); + + size_t num_cuts_per_feature = + detail::RequiredSampleCutsPerColumn(num_bins, dmat->Info().num_row_); + ProcessWeightedBatch(0, dmat->Info(), page, 0, page.data.Size(), &sketch_container, + num_cuts_per_feature, dmat->Info().num_col_, false, {}); + std::size_t bytes_required = detail::RequiredMemory( + num_rows, num_columns, num_rows * num_columns, num_bins, true, page.data.DeviceCanRead()); + + EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); + EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); + + ConsoleLogger::Configure({{"verbosity", std::to_string(verbosity)}}); + } } TEST(HistUtil, DeviceSketchDeterminism) { @@ -196,7 +201,7 @@ TEST(HistUtil, DeviceSketchMultipleColumnsWeights) { } } -TEST(HistUitl, DeviceSketchWeights) { +TEST(HistUtil, DeviceSketchWeights) { int bin_sizes[] = {2, 16, 256, 512}; int sizes[] = {100, 1000, 1500}; int num_columns = 5; @@ -318,24 +323,6 @@ TEST(HistUtil, AdapterDeviceSketch) { EXPECT_EQ(device_cuts.MinValues(), host_cuts.MinValues()); } -TEST(HistUtil, AdapterDeviceSketchMemory) { - int num_columns = 100; - int num_rows = 1000; - int num_bins = 256; - auto x = GenerateRandom(num_rows, num_columns); - auto x_device = thrust::device_vector(x); - auto adapter = AdapterFromData(x_device, num_rows, num_columns); - - dh::GlobalMemoryLogger().Clear(); - ConsoleLogger::Configure({{"verbosity", "3"}}); - auto cuts = MakeUnweightedCutsForTest(adapter, num_bins, std::numeric_limits::quiet_NaN()); - ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_required = detail::RequiredMemory( - num_rows, num_columns, num_rows * num_columns, 
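The reworked memory tests above assert that the logged peak allocation stays within a few percent of the analytic RequiredMemory estimate. A toy sketch of that assertion pattern, with a stand-in logger in place of dh::GlobalMemoryLogger (the PeakMemoryLogger and WithinBudget names are made up for illustration):

#include <algorithm>
#include <cassert>
#include <cstddef>

class PeakMemoryLogger {
  std::size_t current_{0}, peak_{0};

 public:
  void Allocate(std::size_t n) { current_ += n; peak_ = std::max(peak_, current_); }
  void Free(std::size_t n) { current_ -= n; }
  std::size_t PeakMemory() const { return peak_; }
  void Clear() { current_ = peak_ = 0; }
};

// Observed peak must stay within +/-5% of the analytic estimate.
bool WithinBudget(std::size_t peak, std::size_t required) {
  return peak <= required * 1.05 && peak >= required * 0.95;
}

int main() {
  PeakMemoryLogger logger;
  logger.Allocate(1000);  // e.g. the sketch container buffers
  logger.Allocate(20);    // small temporary
  logger.Free(20);
  assert(WithinBudget(logger.PeakMemory(), /*required=*/1000));
  return 0;
}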
num_bins, false); - EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); - EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); -} - TEST(HistUtil, AdapterSketchSlidingWindowMemory) { int num_columns = 100; int num_rows = 1000; @@ -343,22 +330,25 @@ TEST(HistUtil, AdapterSketchSlidingWindowMemory) { auto x = GenerateRandom(num_rows, num_columns); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); + MetaInfo info; + HostDeviceVector ft; + SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); dh::GlobalMemoryLogger().Clear(); + auto verbosity = GlobalConfigThreadLocalStore::Get()->verbosity; ConsoleLogger::Configure({{"verbosity", "3"}}); - common::HistogramCuts batched_cuts; - HostDeviceVector ft; - SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); - AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), - &sketch_container); - HistogramCuts cuts; - sketch_container.MakeCuts(&cuts); - size_t bytes_required = detail::RequiredMemory( - num_rows, num_columns, num_rows * num_columns, num_bins, false); + + auto nnz = std::numeric_limits::max(); + AdapterDeviceSketch(adapter.Value(), num_bins, info, nnz, &sketch_container); + + size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); + size_t bytes_required = + detail::RequiredMemory(num_rows, num_columns, nnz, num_cuts_per_feature, false, true); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 0.95); - ConsoleLogger::Configure({{"verbosity", "0"}}); + + ConsoleLogger::Configure({{"verbosity", std::to_string(verbosity)}}); } TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { @@ -368,31 +358,34 @@ TEST(HistUtil, AdapterSketchSlidingWindowWeightedMemory) { auto x = GenerateRandom(num_rows, num_columns); auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, num_rows, num_columns); + MetaInfo info; auto& h_weights = info.weights_.HostVector(); h_weights.resize(num_rows); std::fill(h_weights.begin(), h_weights.end(), 1.0f); + info.weights_.SetDevice(0); - dh::GlobalMemoryLogger().Clear(); - ConsoleLogger::Configure({{"verbosity", "3"}}); - common::HistogramCuts batched_cuts; HostDeviceVector ft; SketchContainer sketch_container(ft, num_bins, num_columns, num_rows, 0); - AdapterDeviceSketch(adapter.Value(), num_bins, info, - std::numeric_limits::quiet_NaN(), - &sketch_container); - HistogramCuts cuts; - sketch_container.MakeCuts(&cuts); - ConsoleLogger::Configure({{"verbosity", "0"}}); - size_t bytes_required = detail::RequiredMemory( - num_rows, num_columns, num_rows * num_columns, num_bins, true); + dh::GlobalMemoryLogger().Clear(); + auto verbosity = GlobalConfigThreadLocalStore::Get()->verbosity; + ConsoleLogger::Configure({{"verbosity", "3"}}); + + auto nnz = std::numeric_limits::max(); + AdapterDeviceSketch(adapter.Value(), num_bins, info, nnz, &sketch_container); + + size_t num_cuts_per_feature = detail::RequiredSampleCutsPerColumn(num_bins, num_rows); + size_t bytes_required = + detail::RequiredMemory(num_rows, num_columns, nnz, num_cuts_per_feature, true, true); EXPECT_LE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required * 1.05); EXPECT_GE(dh::GlobalMemoryLogger().PeakMemory(), bytes_required); + + ConsoleLogger::Configure({{"verbosity", std::to_string(verbosity)}}); } -void TestCategoricalSketchAdapter(size_t n, 
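These tests also save the user's verbosity, raise it to 3 so allocations are logged, and restore it afterwards. A hypothetical RAII guard could package that save/restore; VerbosityGuard below is not part of the patch, just a sketch of the pattern:

#include <functional>
#include <string>
#include <utility>

class VerbosityGuard {
  std::function<void(std::string const&)> set_;
  std::string saved_;

 public:
  VerbosityGuard(std::string saved, std::string const& temporary,
                 std::function<void(std::string const&)> set)
      : set_{std::move(set)}, saved_{std::move(saved)} {
    set_(temporary);  // e.g. "3" to enable memory logging
  }
  ~VerbosityGuard() { set_(saved_); }  // restore the caller's setting on scope exit
};

int main() {
  std::string verbosity = "1";  // stands in for the global configuration store
  {
    VerbosityGuard guard{verbosity, "3", [&](std::string const& v) { verbosity = v; }};
    // ... run the code that needs verbose memory logging ...
  }
  return verbosity == "1" ? 0 : 1;
}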
size_t num_categories, - int32_t num_bins, bool weighted) { +void TestCategoricalSketchAdapter(size_t n, size_t num_categories, int32_t num_bins, bool weighted, + std::size_t batch_size) { auto h_x = GenerateRandomCategoricalSingleColumn(n, num_categories); thrust::device_vector x(h_x); auto adapter = AdapterFromData(x, n, 1); @@ -413,8 +406,8 @@ void TestCategoricalSketchAdapter(size_t n, size_t num_categories, ASSERT_EQ(info.feature_types.Size(), 1); SketchContainer container(info.feature_types, num_bins, 1, n, 0); - AdapterDeviceSketch(adapter.Value(), num_bins, info, - std::numeric_limits::quiet_NaN(), &container); + AdapterDeviceSketch(adapter.Value(), num_bins, info, std::numeric_limits::quiet_NaN(), + &container, batch_size); HistogramCuts cuts; container.MakeCuts(&cuts); @@ -448,10 +441,22 @@ TEST(HistUtil, AdapterDeviceSketchCategorical) { auto x_device = thrust::device_vector(x); auto adapter = AdapterFromData(x_device, n, 1); ValidateBatchedCuts(adapter, num_bins, dmat.get()); - TestCategoricalSketchAdapter(n, num_categories, num_bins, true); - TestCategoricalSketchAdapter(n, num_categories, num_bins, false); + TestCategoricalSketchAdapter(n, num_categories, num_bins, true, /*batch_size=*/0); + TestCategoricalSketchAdapter(n, num_categories, num_bins, false, /*batch_size=*/0); } } + + for (std::size_t batch_size : {4ul, 32ul}) { + std::size_t n_samples = 4096; + std::int32_t n_categories = 64; + bst_bin_t n_bins = 13; + auto x = GenerateRandomCategoricalSingleColumn(n_samples, n_categories); + auto dmat = GetDMatrixFromData(x, n_samples, 1); + auto x_device = thrust::device_vector(x); + + auto adapter = AdapterFromData(x_device, n_samples, 1); + TestCategoricalSketchAdapter(n_samples, n_categories, n_bins, true, batch_size); + } } TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { @@ -471,8 +476,8 @@ TEST(HistUtil, AdapterDeviceSketchMultipleColumns) { TEST(HistUtil, AdapterDeviceSketchBatches) { int num_bins = 256; - int num_rows = 5000; - int batch_sizes[] = {0, 100, 1500, 6000}; + int num_rows = 512; + int batch_sizes[] = {0, 100}; int num_columns = 5; for (auto batch_size : batch_sizes) { auto x = GenerateRandom(num_rows, num_columns); @@ -483,6 +488,130 @@ TEST(HistUtil, AdapterDeviceSketchBatches) { } } +namespace { +auto MakeData(Context const* ctx, std::size_t n_samples, bst_feature_t n_features) { + dh::safe_cuda(cudaSetDevice(ctx->gpu_id)); + auto n = n_samples * n_features; + std::vector x; + x.resize(n); + + std::iota(x.begin(), x.end(), 0); + std::int32_t c{0}; + float missing = n_samples * n_features; + for (std::size_t i = 0; i < x.size(); ++i) { + if (i % 5 == 0) { + x[i] = missing; + c++; + } + } + thrust::device_vector d_x; + d_x = x; + + auto n_invalids = n / 10 * 2 + 1; + auto is_valid = data::IsValidFunctor{missing}; + return std::tuple{x, d_x, n_invalids, is_valid}; +} + +void TestGetColumnSize(std::size_t n_samples) { + auto ctx = MakeCUDACtx(0); + bst_feature_t n_features = 12; + [[maybe_unused]] auto [x, d_x, n_invalids, is_valid] = MakeData(&ctx, n_samples, n_features); + + auto adapter = AdapterFromData(d_x, n_samples, n_features); + auto batch = adapter.Value(); + + auto batch_iter = dh::MakeTransformIterator( + thrust::make_counting_iterator(0llu), + [=] __device__(std::size_t idx) { return batch.GetElement(idx); }); + + dh::caching_device_vector column_sizes_scan; + column_sizes_scan.resize(n_features + 1); + std::vector h_column_size(column_sizes_scan.size()); + std::vector h_column_size_1(column_sizes_scan.size()); + + 
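The new batch_size parameter threads through to AdapterDeviceSketch, and the existing cases pass 0, which I read as "process everything in one batch". A small sketch of that batched loop under this assumption (ForEachBatch is an illustrative helper, not an XGBoost function):

#include <algorithm>
#include <cstddef>
#include <iostream>

template <typename Fn>
void ForEachBatch(std::size_t n, std::size_t batch_size, Fn&& fn) {
  if (batch_size == 0) {
    batch_size = n;  // a single batch covering the whole input
  }
  for (std::size_t begin = 0; begin < n; begin += batch_size) {
    std::size_t end = std::min(begin + batch_size, n);
    fn(begin, end);  // half-open range [begin, end)
  }
}

int main() {
  ForEachBatch(10, 4, [](std::size_t b, std::size_t e) {
    std::cout << "[" << b << ", " << e << ")\n";  // [0, 4) [4, 8) [8, 10)
  });
  return 0;
}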
detail::LaunchGetColumnSizeKernel( + ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size.begin()); + + detail::LaunchGetColumnSizeKernel( + ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin()); + ASSERT_EQ(h_column_size, h_column_size_1); + + detail::LaunchGetColumnSizeKernel( + ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin()); + ASSERT_EQ(h_column_size, h_column_size_1); + + detail::LaunchGetColumnSizeKernel( + ctx.gpu_id, IterSpan{batch_iter, batch.Size()}, is_valid, dh::ToSpan(column_sizes_scan)); + thrust::copy(column_sizes_scan.begin(), column_sizes_scan.end(), h_column_size_1.begin()); + ASSERT_EQ(h_column_size, h_column_size_1); +} + +void TestMakeEntries(bst_row_t n_samples) { + bst_bin_t n_bins_per_feat = 256; + bst_feature_t n_features = 12; + + auto ctx = MakeCUDACtx(0); + auto [x, d_x, n_invalids, is_valid] = MakeData(&ctx, n_samples, n_features); + + HostDeviceVector cuts_ptr; + cuts_ptr.SetDevice(ctx.gpu_id); + auto n = n_samples * n_features; + + auto adapter = AdapterFromData(d_x, n_samples, n_features); + auto batch = adapter.Value(); + + auto batch_iter = dh::MakeTransformIterator( + thrust::make_counting_iterator(0llu), + [=] __device__(std::size_t idx) { return batch.GetElement(idx); }); + dh::caching_device_vector column_sizes_scan; + dh::device_vector sorted_idx; + std::size_t begin = 0, end = batch.Size(); + + detail::MakeEntriesFromAdapter(batch_iter, {begin, end}, is_valid, n_features, n_bins_per_feat, + ctx.gpu_id, &cuts_ptr, &column_sizes_scan, &sorted_idx); + ASSERT_EQ(sorted_idx.size(), n - n_invalids); + + std::vector h_sorted_idx(sorted_idx.size()); + thrust::copy(sorted_idx.cbegin(), sorted_idx.cend(), h_sorted_idx.begin()); + + for (auto idx : h_sorted_idx) { + ASSERT_TRUE(is_valid(x[idx])); + } + + std::vector h_column_ptr(column_sizes_scan.size()); + thrust::copy(column_sizes_scan.cbegin(), column_sizes_scan.cend(), h_column_ptr.begin()); + + for (std::size_t i = 1; i < column_sizes_scan.size(); ++i) { + auto beg = column_sizes_scan[i - 1]; + auto end = column_sizes_scan[i]; + for (std::size_t j = beg + 1; j < end; ++j) { + ASSERT_LE(h_sorted_idx.at(j - 1), h_sorted_idx.at(j)); + } + } +} +} // namespace + +TEST(HistUtil, MakeEntries) { + bst_row_t n_samples = 4096; + TestMakeEntries(n_samples); +} + +TEST(HistUtil, DISABLED_MakeEntriesLarge) { + // Disabled by default due to memory limit. 
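LaunchGetColumnSizeKernel and MakeEntriesFromAdapter both rely on a per-feature count of valid entries followed by a prefix scan that becomes the column pointer array. A CPU sketch of that computation, assuming the row-major element layout used by the adapters here and a simple sentinel for missing values:

#include <cstddef>
#include <numeric>
#include <vector>

std::vector<std::size_t> ColumnSizesScan(std::vector<float> const& data,
                                         std::size_t n_features, float missing) {
  std::vector<std::size_t> sizes(n_features + 1, 0);
  for (std::size_t i = 0; i < data.size(); ++i) {
    std::size_t f = i % n_features;  // row-major layout, as in the adapters here
    if (data[i] != missing) {
      sizes[f + 1] += 1;             // count valid entries per feature
    }
  }
  // Prefix scan: sizes[f] becomes the offset of feature f, sizes.back() the total.
  std::partial_sum(sizes.begin(), sizes.end(), sizes.begin());
  return sizes;
}

int main() {
  auto scan = ColumnSizesScan({1, -1, 2, -1, 3, 4}, 2, /*missing=*/-1);
  return (scan[0] == 0 && scan[1] == 3 && scan[2] == 4) ? 0 : 1;
}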
+ bst_row_t n_samples = + static_cast(std::numeric_limits::max()) / 12 + 256; + ASSERT_GT(n_samples * 12, std::numeric_limits::max()); + TestMakeEntries(n_samples); +} + +TEST(HistUtil, GetColumnSize) { + bst_row_t n_samples = 4096; + TestGetColumnSize(n_samples); +} + // Check sketching from adapter or DMatrix results in the same answer // Consistency here is useful for testing and user experience TEST(HistUtil, SketchingEquivalent) { @@ -529,7 +658,7 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { ASSERT_EQ(cuts.Ptrs().size(), weighted_cuts.Ptrs().size()); for (size_t i = 0; i < cuts.Values().size(); ++i) { - EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]) << "i:"<< i; + EXPECT_EQ(cuts.Values()[i], weighted_cuts.Values()[i]); } for (size_t i = 0; i < cuts.MinValues().size(); ++i) { ASSERT_EQ(cuts.MinValues()[i], weighted_cuts.MinValues()[i]); @@ -540,92 +669,94 @@ TEST(HistUtil, DeviceSketchFromGroupWeights) { ValidateCuts(weighted_cuts, m.get(), kBins); } -void TestAdapterSketchFromWeights(bool with_group) { - size_t constexpr kRows = 300, kCols = 20, kBins = 256; - size_t constexpr kGroups = 10; - HostDeviceVector storage; - std::string m = - RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface( - &storage); - MetaInfo info; - Context ctx; - auto& h_weights = info.weights_.HostVector(); - if (with_group) { - h_weights.resize(kGroups); - } else { - h_weights.resize(kRows); - } - std::fill(h_weights.begin(), h_weights.end(), 1.0f); +class HistUtilSketch : public ::testing::TestWithParam { + public: + void TestAdapterSketchFromWeights(bool with_group) { + size_t constexpr kRows = 300, kCols = 20, kBins = 256; + size_t constexpr kGroups = 10; + HostDeviceVector storage; + std::string m = RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateArrayInterface(&storage); + MetaInfo info; + Context ctx; + auto& h_weights = info.weights_.HostVector(); + if (with_group) { + h_weights.resize(kGroups); + } else { + h_weights.resize(kRows); + } + std::fill(h_weights.begin(), h_weights.end(), 1.0f); - std::vector groups(kGroups); - if (with_group) { - for (size_t i = 0; i < kGroups; ++i) { - groups[i] = kRows / kGroups; + std::vector groups(kGroups); + if (with_group) { + for (size_t i = 0; i < kGroups; ++i) { + groups[i] = kRows / kGroups; + } + info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); } - info.SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); - } - info.weights_.SetDevice(0); - info.num_row_ = kRows; - info.num_col_ = kCols; + info.weights_.SetDevice(0); + info.num_row_ = kRows; + info.num_col_ = kCols; - data::CupyAdapter adapter(m); - auto const& batch = adapter.Value(); - HostDeviceVector ft; - SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); - AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), - &sketch_container); + data::CupyAdapter adapter(m); + auto const& batch = adapter.Value(); + HostDeviceVector ft; + SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); + AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + &sketch_container); - common::HistogramCuts cuts; - sketch_container.MakeCuts(&cuts); + common::HistogramCuts cuts; + sketch_container.MakeCuts(&cuts); - auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); - if (with_group) { - dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); - } - - dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size()); - 
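DISABLED_MakeEntriesLarge deliberately picks n_samples so that n_samples * n_features overflows a 32-bit index, which is why it is disabled by default (it also needs a lot of memory). A short sketch of that overflow condition:

#include <cstdint>
#include <iostream>
#include <limits>

bool NeedsSixtyFourBitIndexing(std::uint64_t n_samples, std::uint64_t n_features) {
  return n_samples * n_features > std::numeric_limits<std::uint32_t>::max();
}

int main() {
  std::uint64_t n_features = 12;
  // Same construction as the test: just past the 32-bit boundary.
  std::uint64_t n_samples =
      static_cast<std::uint64_t>(std::numeric_limits<std::uint32_t>::max()) / n_features + 256;
  std::cout << std::boolalpha << NeedsSixtyFourBitIndexing(n_samples, n_features) << "\n";  // true
  return 0;
}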
dmat->Info().num_col_ = kCols; - dmat->Info().num_row_ = kRows; - ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); - ValidateCuts(cuts, dmat.get(), kBins); - - if (with_group) { - dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight - HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0); - for (size_t i = 0; i < cuts.Values().size(); ++i) { - ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]); - } - for (size_t i = 0; i < cuts.MinValues().size(); ++i) { - ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]); + auto dmat = GetDMatrixFromData(storage.HostVector(), kRows, kCols); + if (with_group) { + dmat->Info().SetInfo(ctx, "group", groups.data(), DataType::kUInt32, kGroups); } - for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { - ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); + + dmat->Info().SetInfo(ctx, "weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + dmat->Info().num_col_ = kCols; + dmat->Info().num_row_ = kRows; + ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); + ValidateCuts(cuts, dmat.get(), kBins); + + if (with_group) { + dmat->Info().weights_ = decltype(dmat->Info().weights_)(); // remove weight + HistogramCuts non_weighted = DeviceSketch(0, dmat.get(), kBins, 0); + for (size_t i = 0; i < cuts.Values().size(); ++i) { + ASSERT_EQ(cuts.Values()[i], non_weighted.Values()[i]); + } + for (size_t i = 0; i < cuts.MinValues().size(); ++i) { + ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]); + } + for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { + ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); + } } - } - if (with_group) { - common::HistogramCuts weighted; - auto& h_weights = info.weights_.HostVector(); - h_weights.resize(kGroups); - // Generate different weight. - for (size_t i = 0; i < h_weights.size(); ++i) { - // FIXME(jiamingy): Some entries generated GPU test cannot pass the validate cuts if - // we use more diverse weights, partially caused by - // https://github.com/dmlc/xgboost/issues/7946 - h_weights[i] = (i % 2 == 0 ? 1 : 2) / static_cast(kGroups); + if (with_group) { + common::HistogramCuts weighted; + auto& h_weights = info.weights_.HostVector(); + h_weights.resize(kGroups); + // Generate different weight. + for (size_t i = 0; i < h_weights.size(); ++i) { + // FIXME(jiamingy): Some entries generated GPU test cannot pass the validate cuts if + // we use more diverse weights, partially caused by + // https://github.com/dmlc/xgboost/issues/7946 + h_weights[i] = (i % 2 == 0 ? 
1 : 2) / static_cast(kGroups); + } + SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); + AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), + &sketch_container); + sketch_container.MakeCuts(&weighted); + ValidateCuts(weighted, dmat.get(), kBins); } - SketchContainer sketch_container(ft, kBins, kCols, kRows, 0); - AdapterDeviceSketch(adapter.Value(), kBins, info, std::numeric_limits::quiet_NaN(), - &sketch_container); - sketch_container.MakeCuts(&weighted); - ValidateCuts(weighted, dmat.get(), kBins); } -} +}; -TEST(HistUtil, AdapterSketchFromWeights) { - TestAdapterSketchFromWeights(false); - TestAdapterSketchFromWeights(true); +TEST_P(HistUtilSketch, AdapterSketchFromWeights) { + bool with_group = GetParam(); + this->TestAdapterSketchFromWeights(with_group); } -} // namespace common -} // namespace xgboost + +INSTANTIATE_TEST_SUITE_P(HistUtil, HistUtilSketch, ::testing::Values(true, false)); +} // namespace xgboost::common diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index b8de641ffd7b..a3b12621b9bd 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -120,12 +120,10 @@ inline void TestBinDistribution(const HistogramCuts& cuts, int column_idx, // Test sketch quantiles against the real quantiles Not a very strict // test -inline void TestRank(const std::vector &column_cuts, - const std::vector &sorted_x, - const std::vector &sorted_weights) { +inline void TestRank(const std::vector& column_cuts, const std::vector& sorted_x, + const std::vector& sorted_weights) { double eps = 0.05; - auto total_weight = - std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0); + auto total_weight = std::accumulate(sorted_weights.begin(), sorted_weights.end(), 0.0); // Ignore the last cut, its special double sum_weight = 0.0; size_t j = 0; @@ -136,28 +134,26 @@ inline void TestRank(const std::vector &column_cuts, } double expected_rank = ((i + 1) * total_weight) / column_cuts.size(); double acceptable_error = std::max(2.9, total_weight * eps); - EXPECT_LE(std::abs(expected_rank - sum_weight), acceptable_error); + ASSERT_LE(std::abs(expected_rank - sum_weight), acceptable_error); } } inline void ValidateColumn(const HistogramCuts& cuts, int column_idx, const std::vector& sorted_column, - const std::vector& sorted_weights, - size_t num_bins) { - + const std::vector& sorted_weights, size_t num_bins) { // Check the endpoints are correct CHECK_GT(sorted_column.size(), 0); - EXPECT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); - EXPECT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); - EXPECT_GE(cuts.Values()[cuts.Ptrs()[column_idx+1]-1], sorted_column.back()); + ASSERT_LT(cuts.MinValues().at(column_idx), sorted_column.front()); + ASSERT_GT(cuts.Values()[cuts.Ptrs()[column_idx]], sorted_column.front()); + ASSERT_GE(cuts.Values()[cuts.Ptrs()[column_idx + 1] - 1], sorted_column.back()); // Check the cuts are sorted auto cuts_begin = cuts.Values().begin() + cuts.Ptrs()[column_idx]; auto cuts_end = cuts.Values().begin() + cuts.Ptrs()[column_idx + 1]; - EXPECT_TRUE(std::is_sorted(cuts_begin, cuts_end)); + ASSERT_TRUE(std::is_sorted(cuts_begin, cuts_end)); // Check all cut points are unique - EXPECT_EQ(std::set(cuts_begin, cuts_end).size(), + ASSERT_EQ(std::set(cuts_begin, cuts_end).size(), static_cast(cuts_end - cuts_begin)); auto unique = std::set(sorted_column.begin(), sorted_column.end()); @@ -173,8 +169,7 @@ inline void ValidateColumn(const 
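The free function TestAdapterSketchFromWeights becomes a value-parameterized GoogleTest fixture so the with/without-group cases show up as separate test instances. A minimal, self-contained version of that pattern (assuming GoogleTest linked with gtest_main; the fixture body here is a placeholder):

#include <gtest/gtest.h>

class WithGroupParam : public ::testing::TestWithParam<bool> {
 protected:
  void RunSketchCheck(bool with_group) {
    (void)with_group;  // placeholder for building the sketch and validating cuts
  }
};

TEST_P(WithGroupParam, AdapterSketchFromWeights) { this->RunSketchCheck(GetParam()); }

// Instantiates the test twice, once per parameter value.
INSTANTIATE_TEST_SUITE_P(Example, WithGroupParam, ::testing::Values(true, false));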
HistogramCuts& cuts, int column_idx, int num_cuts_column = cuts.Ptrs()[column_idx + 1] - cuts.Ptrs()[column_idx]; std::vector column_cuts(num_cuts_column); std::copy(cuts.Values().begin() + cuts.Ptrs()[column_idx], - cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], - column_cuts.begin()); + cuts.Values().begin() + cuts.Ptrs()[column_idx + 1], column_cuts.begin()); TestBinDistribution(cuts, column_idx, sorted_column, sorted_weights); TestRank(column_cuts, sorted_column, sorted_weights); } diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index f36334bcc794..44523bd971c2 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -22,7 +22,7 @@ TEST(GPUQuantile, Basic) { dh::device_vector cuts_ptr(kCols+1); thrust::fill(cuts_ptr.begin(), cuts_ptr.end(), 0); // Push empty - sketch.Push(dh::ToSpan(entries), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0); + sketch.Push(entries.data().get(), dh::ToSpan(cuts_ptr), dh::ToSpan(cuts_ptr), 0); ASSERT_EQ(sketch.Data().size(), 0); } @@ -50,7 +50,7 @@ void TestSketchUnique(float sparsity) { thrust::make_counting_iterator(0llu), [=] __device__(size_t idx) { return batch.GetElement(idx); }); auto end = kCols * kRows; - detail::GetColumnSizesScan(0, kCols, n_cuts, batch_iter, is_valid, 0, end, + detail::GetColumnSizesScan(0, kCols, n_cuts, IterSpan{batch_iter, end}, is_valid, &cut_sizes_scan, &column_sizes_scan); auto const& cut_sizes = cut_sizes_scan.HostVector(); ASSERT_LE(sketch.Data().size(), cut_sizes.back()); @@ -518,7 +518,7 @@ TEST(GPUQuantile, Push) { HostDeviceVector ft; SketchContainer sketch(ft, n_bins, kCols, kRows, 0); - sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {}); + sketch.Push(d_entries.data().get(), dh::ToSpan(columns_ptr), dh::ToSpan(columns_ptr), kRows, {}); auto sketch_data = sketch.Data(); @@ -568,8 +568,8 @@ TEST(GPUQuantile, MultiColPush) { columns_ptr.begin()); dh::device_vector cuts_ptr(columns_ptr); - sketch.Push(dh::ToSpan(d_entries), dh::ToSpan(columns_ptr), - dh::ToSpan(cuts_ptr), kRows * kCols, {}); + sketch.Push(d_entries.data().get(), dh::ToSpan(columns_ptr), dh::ToSpan(cuts_ptr), kRows * kCols, + {}); auto sketch_data = sketch.Data(); ASSERT_EQ(sketch_data.size(), kCols * 2); diff --git a/tests/cpp/common/test_span.cc b/tests/cpp/common/test_span.cc index 133fae9fdd0a..b29c562bc3eb 100644 --- a/tests/cpp/common/test_span.cc +++ b/tests/cpp/common/test_span.cc @@ -1,15 +1,16 @@ -/*! - * Copyright 2018 XGBoost contributors +/** + * Copyright 2018-2023, XGBoost contributors */ -#include -#include +#include "test_span.h" +#include #include -#include "test_span.h" -namespace xgboost { -namespace common { +#include + +#include "../../../src/common/transform_iterator.h" // for MakeIndexTransformIter +namespace xgboost::common { TEST(Span, TestStatus) { int status = 1; TestTestStatus {&status}(); @@ -526,5 +527,17 @@ TEST(SpanDeathTest, Empty) { Span s{data.data(), static_cast::index_type>(0)}; EXPECT_DEATH(s[0], ""); // not ok to use it. 
} -} // namespace common -} // namespace xgboost + +TEST(IterSpan, Basic) { + auto iter = common::MakeIndexTransformIter([](std::size_t i) { return i; }); + std::size_t n = 13; + auto span = IterSpan{iter, n}; + ASSERT_EQ(span.size(), n); + for (std::size_t i = 0; i < n; ++i) { + ASSERT_EQ(span[i], i); + } + ASSERT_EQ(span.subspan(1).size(), n - 1); + ASSERT_EQ(span.subspan(1)[0], 1); + ASSERT_EQ(span.subspan(1, 2)[1], 2); +} +} // namespace xgboost::common diff --git a/tests/cpp/cub_sort/test_dispatch_radix_sort.cu b/tests/cpp/cub_sort/test_dispatch_radix_sort.cu new file mode 100644 index 000000000000..76f9ca719d60 --- /dev/null +++ b/tests/cpp/cub_sort/test_dispatch_radix_sort.cu @@ -0,0 +1,222 @@ +/** + * Copyright 2023, XGBoost contributors + */ +#include +#include // for device_vector +#include +#include // for make_reverse_iterator +#include // for sequence +#include // for bst_feature_t +#include // for Entry +#include // for HostDeviceVector + +#include // for bitset +#include // for uint32_t +#include // for size_t +#include // for default_random_engine, uniform_int_distribution, uniform_real_distribution +#include // for stringstream +#include // for tuple +#include // for vector + +#include "../../../src/cub_sort/device/device_radix_sort.cuh" + +namespace cub_argsort { +using xgboost::Entry; +using xgboost::HostDeviceVector; + +namespace { +void TestBitCast() { + Entry e; + e.index = 3; + e.fvalue = -512.0f; + + static_assert(sizeof(Entry) == sizeof(std::uint64_t)); + std::uint64_t bits; + std::memcpy(&bits, &e, sizeof(e)); + + std::bitset<64> set{bits}; + + std::uint32_t* ptr = reinterpret_cast(&bits); + std::bitset<32> lhs{ptr[0]}; + std::bitset<32> rhs{ptr[1]}; + // The first 32-bit segment contains the feature index + ASSERT_EQ(lhs, std::bitset<32>{e.index}); + + std::swap(ptr[0], ptr[1]); + set = bits; + // after swap, the second segment contains the feature index + ASSERT_EQ(std::bitset<32>{ptr[1]}, std::bitset<32>{e.index}); + + bits = EntryTrait::LOWEST_KEY; + auto pptr = reinterpret_cast(&bits); + ASSERT_EQ(pptr[0], []() { return ::cub::NumericTraits::LOWEST_KEY; }()); + ASSERT_EQ(pptr[1], []() { return ::cub::NumericTraits::LOWEST_KEY; }()); + + bits = EntryTrait::MAX_KEY; + pptr = reinterpret_cast(&bits); + ASSERT_EQ(pptr[0], []() { return ::cub::NumericTraits::MAX_KEY; }()); + ASSERT_EQ(pptr[1], []() { return ::cub::NumericTraits::MAX_KEY; }()); +} + +enum TestType { kf32, kf64, ki32 }; + +class RadixArgSortNumeric : public ::testing::TestWithParam> { + public: + template + void TestArgSort(std::size_t n) { + HostDeviceVector data; + data.SetDevice(0); + data.Resize(n, 0.0f); + auto d_data = data.DeviceSpan(); + auto beg = thrust::make_reverse_iterator(d_data.data() + d_data.size()); + thrust::sequence(thrust::device, beg, beg + d_data.size(), -static_cast(n / 2.0)); + auto const& h_in = data.ConstHostSpan(); + + HostDeviceVector idx_out(data.Size(), 0u); + idx_out.SetDevice(0); + auto d_idx_out = idx_out.DeviceSpan(); + + std::size_t bytes{0}; + DeviceRadixSort>::Argsort(nullptr, bytes, d_data.data(), + d_idx_out.data(), d_data.size()); + thrust::device_vector temp(bytes); + DeviceRadixSort>::Argsort(temp.data().get(), bytes, d_data.data(), + d_idx_out.data(), d_data.size()); + ASSERT_GT(bytes, n * sizeof(std::uint32_t)); + + auto const& h_idx_out = idx_out.ConstHostSpan(); + for (std::size_t i = 1; i < h_idx_out.size(); ++i) { + ASSERT_EQ(h_idx_out[i] + 1, h_idx_out[i - 1]); + ASSERT_EQ(h_in[h_idx_out[i]], h_in[h_idx_out[i - 1]] + 1); + } + } + + template + void 
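TestBitCast checks the 64-bit layout of Entry and where the feature index ends up, which matters because the radix argsort keys entries by feature first and value second. The sketch below shows one common way to build such a key on the host: index in the high 32 bits, float mapped to order-preserving bits in the low 32. This illustrates the idea only and is an assumption about, not a copy of, the EntryTrait extractor.

#include <algorithm>
#include <cstdint>
#include <cstring>
#include <vector>

// Map IEEE-754 float bits to an unsigned integer with the same ordering.
std::uint32_t OrderedBits(float v) {
  std::uint32_t u;
  std::memcpy(&u, &v, sizeof(v));
  return (u & 0x80000000u) ? ~u : (u | 0x80000000u);
}

std::uint64_t PackKey(std::uint32_t findex, float fvalue) {
  return (static_cast<std::uint64_t>(findex) << 32) | OrderedBits(fvalue);
}

int main() {
  // Keys compare first by feature index, then by value.
  std::vector<std::uint64_t> keys{PackKey(1, -0.5f), PackKey(0, 3.0f), PackKey(1, 2.0f)};
  std::sort(keys.begin(), keys.end());
  return (keys[0] >> 32) == 0 && (keys[1] >> 32) == 1 ? 0 : 1;
}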
TestSameValue(std::size_t n) { + HostDeviceVector data(n, static_cast(1.0), 0); + + auto d_data = data.ConstDeviceSpan(); + HostDeviceVector idx_out(n); + idx_out.SetDevice(0); + auto d_idx_out = idx_out.DeviceSpan(); + + std::size_t bytes{0}; + DeviceRadixSort::Argsort(nullptr, bytes, d_data.data(), d_idx_out.data(), + d_data.size()); + thrust::device_vector temp(bytes); + DeviceRadixSort::Argsort(temp.data().get(), bytes, d_data.data(), + d_idx_out.data(), d_data.size()); + + auto const& h_idx = idx_out.ConstHostVector(); + std::vector expected(n); + std::iota(expected.begin(), expected.end(), 0); + ASSERT_EQ(h_idx, expected); + } +}; + +class RadixArgSortEntry : public ::testing::TestWithParam { + public: + void TestCustomExtractor(std::size_t n) { + HostDeviceVector data(n, Entry{0, 0}, 0); + + auto& h_data = data.HostVector(); + + std::default_random_engine rng; + rng.seed(1); + + std::uniform_int_distribution fdist(0, 27); + std::uniform_real_distribution vdist(-8.0f, 8.0f); + + for (auto it = h_data.rbegin(); it != h_data.rend(); ++it) { + auto d = std::distance(h_data.rbegin(), it); + it->fvalue = vdist(rng); + it->index = fdist(rng); + } + + HostDeviceVector out_idx(n, 0u, 0); + + auto d_data = data.ConstDeviceSpan(); + auto d_idx_out = out_idx.DeviceSpan(); + std::size_t bytes{0}; + + DeviceRadixSort::Argsort(nullptr, bytes, d_data.data(), d_idx_out.data(), + d_data.size()); + thrust::device_vector temp(bytes); + DeviceRadixSort::Argsort(temp.data().get(), bytes, d_data.data(), + d_idx_out.data(), d_data.size()); + + auto const& h_idx = out_idx.ConstHostVector(); + + for (std::size_t i = 1; i < h_idx.size(); ++i) { + ASSERT_GE(h_data[h_idx[i]].index, h_data[h_idx[i - 1]].index); + if (h_data[h_idx[i]].index == h_data[h_idx[i - 1]].index) { + // within the same feature, value should be increasing. 
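TestSameValue expects the identity permutation for an all-equal input, i.e. the argsort must be stable. A CPU analogue using std::stable_sort over an index array shows the property being tested:

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <vector>

std::vector<std::uint32_t> ArgSort(std::vector<float> const& values) {
  std::vector<std::uint32_t> idx(values.size());
  std::iota(idx.begin(), idx.end(), 0u);
  // Stable sort: equal values keep their original relative order.
  std::stable_sort(idx.begin(), idx.end(),
                   [&](std::uint32_t l, std::uint32_t r) { return values[l] < values[r]; });
  return idx;
}

int main() {
  auto idx = ArgSort({1.0f, 1.0f, 1.0f});
  std::vector<std::uint32_t> expected{0, 1, 2};
  return idx == expected ? 0 : 1;
}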
+ ASSERT_GE(h_data[h_idx[i]].fvalue, h_data[h_idx[i - 1]].fvalue); + } + } + } +}; +} // namespace + +TEST_P(RadixArgSortNumeric, Basic) { + auto [t, n] = GetParam(); + switch (t) { + case kf32: { + TestArgSort(n); + break; + } + case kf64: { + TestArgSort(n); + break; + } + case ki32: { + TestArgSort(n); + break; + } + }; +} + +TEST_P(RadixArgSortNumeric, SameValue) { + auto [t, n] = GetParam(); + switch (t) { + case kf32: { + TestSameValue(n); + break; + } + case kf64: { + TestSameValue(n); + break; + } + case ki32: { + TestSameValue(n); + break; + } + }; +} + +INSTANTIATE_TEST_SUITE_P(RadixArgSort, RadixArgSortNumeric, + testing::Values(std::tuple{kf32, 128}, std::tuple{kf64, 128}, + std::tuple{ki32, 128}, std::tuple{kf32, 8192}, + std::tuple{kf64, 8192}, std::tuple{ki32, 8192}), + ([](::testing::TestParamInfo const& info) { + auto [t, n] = info.param; + std::stringstream ss; + ss << static_cast(t) << "_" << n; + return ss.str(); + })); + +TEST_P(RadixArgSortEntry, Basic) { + std::size_t n = GetParam(); + TestCustomExtractor(n); +} + +INSTANTIATE_TEST_SUITE_P(RadixArgSort, RadixArgSortEntry, testing::Values(128, 8192), + ([](::testing::TestParamInfo const& info) { + auto n = info.param; + std::stringstream ss; + ss << n; + return ss.str(); + })); + +TEST(RadixArgSort, BitCast) { TestBitCast(); } +} // namespace cub_argsort diff --git a/tests/python-gpu/test_from_cudf.py b/tests/python-gpu/test_from_cudf.py index 523dbf9312a4..c3b27c363217 100644 --- a/tests/python-gpu/test_from_cudf.py +++ b/tests/python-gpu/test_from_cudf.py @@ -12,7 +12,7 @@ def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN): - '''Test constructing DMatrix from cudf''' + """Test constructing DMatrix from cudf""" import cudf import pandas as pd @@ -25,24 +25,27 @@ def dmatrix_from_cudf(input_type, DMatrixT, missing=np.NAN): na[5, 0] = missing na[3, 1] = missing - pa = pd.DataFrame({'0': na[:, 0], - '1': na[:, 1], - '2': na[:, 2].astype(np.int32)}) + pa = pd.DataFrame({"0": na[:, 0], "1": na[:, 1], "2": na[:, 2].astype(np.int32)}) np_label = np.random.randn(kRows).astype(input_type) pa_label = pd.DataFrame(np_label) - cd = cudf.from_pandas(pa) + cudf_df = cudf.from_pandas(pa) + cudf_df[["0", "1"]] = cudf_df[["0", "1"]].astype(input_type) cd_label = cudf.from_pandas(pa_label).iloc[:, 0] - dtrain = DMatrixT(cd, missing=missing, label=cd_label) + dtrain = DMatrixT(cudf_df, missing=missing, label=cd_label) assert dtrain.num_col() == kCols assert dtrain.num_row() == kRows + dtrain_from_pd = DMatrixT(pa, missing=missing, label=pa_label) + tm.predictor_equal(dtrain_from_pd, dtrain) + def _test_from_cudf(DMatrixT): - '''Test constructing DMatrix from cudf''' + """Test constructing DMatrix from cudf""" import cudf + dmatrix_from_cudf(np.float32, DMatrixT, np.NAN) dmatrix_from_cudf(np.float64, DMatrixT, np.NAN) @@ -50,37 +53,35 @@ def _test_from_cudf(DMatrixT): dmatrix_from_cudf(np.int32, DMatrixT, -2) dmatrix_from_cudf(np.int64, DMatrixT, -3) - cd = cudf.DataFrame({'x': [1, 2, 3], 'y': [0.1, 0.2, 0.3]}) + cd = cudf.DataFrame({"x": [1, 2, 3], "y": [0.1, 0.2, 0.3]}) dtrain = DMatrixT(cd) - assert dtrain.feature_names == ['x', 'y'] - assert dtrain.feature_types == ['int', 'float'] + assert dtrain.feature_names == ["x", "y"] + assert dtrain.feature_types == ["int", "float"] - series = cudf.DataFrame({'x': [1, 2, 3]}).iloc[:, 0] + series = cudf.DataFrame({"x": [1, 2, 3]}).iloc[:, 0] assert isinstance(series, cudf.Series) dtrain = DMatrixT(series) - assert dtrain.feature_names == ['x'] - assert dtrain.feature_types == ['int'] 
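The verification loop here asserts that the argsorted entries are non-decreasing lexicographically by (feature index, value). A host-side checker for the same invariant, with a stand-in Entry struct holding the same two fields:

#include <cstddef>
#include <cstdint>
#include <vector>

struct Entry {
  std::uint32_t index;
  float fvalue;
};

bool SortedByFeatureThenValue(std::vector<Entry> const& data,
                              std::vector<std::uint32_t> const& idx) {
  for (std::size_t i = 1; i < idx.size(); ++i) {
    auto const& prev = data[idx[i - 1]];
    auto const& cur = data[idx[i]];
    if (cur.index < prev.index) {
      return false;  // features must be grouped in non-decreasing order
    }
    if (cur.index == prev.index && cur.fvalue < prev.fvalue) {
      return false;  // within the same feature, values must be non-decreasing
    }
  }
  return true;
}

int main() {
  std::vector<Entry> data{{1, 0.5f}, {0, 2.0f}, {1, 1.5f}};
  return SortedByFeatureThenValue(data, {1, 0, 2}) ? 0 : 1;
}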
+ assert dtrain.feature_names == ["x"] + assert dtrain.feature_types == ["int"] with pytest.raises(ValueError, match=r".*multi.*"): dtrain = DMatrixT(cd, label=cd) xgb.train({"tree_method": "gpu_hist", "objective": "multi:softprob"}, dtrain) # Test when number of elements is less than 8 - X = cudf.DataFrame({'x': cudf.Series([0, 1, 2, np.NAN, 4], - dtype=np.int32)}) + X = cudf.DataFrame({"x": cudf.Series([0, 1, 2, np.NAN, 4], dtype=np.int32)}) dtrain = DMatrixT(X) assert dtrain.num_col() == 1 assert dtrain.num_row() == 5 # Boolean is not supported. - X_boolean = cudf.DataFrame({'x': cudf.Series([True, False])}) + X_boolean = cudf.DataFrame({"x": cudf.Series([True, False])}) with pytest.raises(Exception): dtrain = DMatrixT(X_boolean) - y_boolean = cudf.DataFrame({ - 'x': cudf.Series([True, False, True, True, True])}) + y_boolean = cudf.DataFrame({"x": cudf.Series([True, False, True, True, True])}) with pytest.raises(Exception): dtrain = DMatrixT(X_boolean, label=y_boolean)