From 69fea724289e21a6ae976d15022c264cba832c60 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 21 Jul 2022 04:27:44 -0700 Subject: [PATCH 01/13] Loop unrolling in GPU histogram kernel. --- src/common/device_helpers.cuh | 2 +- src/tree/gpu_hist/histogram.cu | 286 ++++++++++++++++++++------------- 2 files changed, 178 insertions(+), 110 deletions(-) diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index ccec859a286c..f3d387983b61 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -1949,7 +1949,7 @@ class LDGIterator { const T *ptr_; public: - explicit LDGIterator(const T *ptr) : ptr_(ptr) {} + XGBOOST_DEVICE explicit LDGIterator(const T *ptr) : ptr_(ptr) {} __device__ T operator[](std::size_t idx) const { DeviceWordT tmp[kNumWords]; static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal."); diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index a63c6c9548ed..ff6089ec5935 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -1,19 +1,18 @@ /*! * Copyright 2020-2021 by XGBoost Contributors */ -#include #include +#include + #include #include #include -#include "xgboost/base.h" -#include "row_partitioner.cuh" - -#include "histogram.cuh" - -#include "../../data/ellpack_page.cuh" #include "../../common/device_helpers.cuh" +#include "../../data/ellpack_page.cuh" +#include "histogram.cuh" +#include "row_partitioner.cuh" +#include "xgboost/base.h" namespace xgboost { namespace tree { @@ -59,12 +58,8 @@ __host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) { } // anonymous namespace struct Clip : public thrust::unary_function { - static XGBOOST_DEV_INLINE float Pclip(float v) { - return v > 0 ? v : 0; - } - static XGBOOST_DEV_INLINE float Nclip(float v) { - return v < 0 ? abs(v) : 0; - } + static XGBOOST_DEV_INLINE float Pclip(float v) { return v > 0 ? v : 0; } + static XGBOOST_DEV_INLINE float Nclip(float v) { return v < 0 ? 
abs(v) : 0; } XGBOOST_DEV_INLINE Pair operator()(GradientPair x) const { auto pg = Pclip(x.GetGrad()); @@ -73,7 +68,7 @@ struct Clip : public thrust::unary_function { auto ng = Nclip(x.GetGrad()); auto nh = Nclip(x.GetHess()); - return { GradientPair{ pg, ph }, GradientPair{ ng, nh } }; + return {GradientPair{pg, ph}, GradientPair{ng, nh}}; } }; @@ -82,18 +77,18 @@ HistRounding CreateRoundingFactor(common::Span using T = typename GradientSumT::ValueT; dh::XGBCachingDeviceAllocator alloc; - thrust::device_ptr gpair_beg {gpair.data()}; - thrust::device_ptr gpair_end {gpair.data() + gpair.size()}; + thrust::device_ptr gpair_beg{gpair.data()}; + thrust::device_ptr gpair_end{gpair.data() + gpair.size()}; auto beg = thrust::make_transform_iterator(gpair_beg, Clip()); auto end = thrust::make_transform_iterator(gpair_end, Clip()); Pair p = dh::Reduce(thrust::cuda::par(alloc), beg, end, Pair{}, thrust::plus{}); - GradientPair positive_sum {p.first}, negative_sum {p.second}; + GradientPair positive_sum{p.first}, negative_sum{p.second}; - auto histogram_rounding = GradientSumT { - CreateRoundingFactor(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), - gpair.size()), - CreateRoundingFactor(std::max(positive_sum.GetHess(), negative_sum.GetHess()), - gpair.size()) }; + auto histogram_rounding = + GradientSumT{CreateRoundingFactor(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()), + gpair.size()), + CreateRoundingFactor(std::max(positive_sum.GetHess(), negative_sum.GetHess()), + gpair.size())}; using IntT = typename HistRounding::SharedSumT::ValueT; @@ -102,8 +97,7 @@ HistRounding CreateRoundingFactor(common::Span */ GradientSumT to_floating_point = histogram_rounding / - T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 - - 2)); // keep 1 for sign bit + T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 - 2)); // keep 1 for sign bit /** * Factor for converting gradients from floating-point to fixed-point. For * f64: @@ -113,60 +107,98 @@ HistRounding CreateRoundingFactor(common::Span * rounding is calcuated as exp(m), see the rounding factor calcuation for * details. */ - GradientSumT to_fixed_point = GradientSumT( - T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess()); + GradientSumT to_fixed_point = + GradientSumT(T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess()); return {histogram_rounding, to_fixed_point, to_floating_point}; } -template HistRounding -CreateRoundingFactor(common::Span gpair); -template HistRounding -CreateRoundingFactor(common::Span gpair); - -template -__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, - FeatureGroupsAccessor feature_groups, - common::Span d_ridx, - GradientSumT* __restrict__ d_node_hist, - const GradientPair* __restrict__ d_gpair, - HistRounding const rounding) { - using SharedSumT = typename HistRounding::SharedSumT; - using T = typename GradientSumT::ValueT; +template HistRounding CreateRoundingFactor( + common::Span gpair); +template HistRounding CreateRoundingFactor(common::Span gpair); - extern __shared__ char smem[]; - FeatureGroup group = feature_groups[blockIdx.y]; - SharedSumT *smem_arr = reinterpret_cast(smem); - if (use_shared_memory_histograms) { - dh::BlockFill(smem_arr, group.num_bins, SharedSumT()); - __syncthreads(); - } - int feature_stride = matrix.is_dense ? 
group.num_features : matrix.row_stride; - size_t n_elements = feature_stride * d_ridx.size(); - for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { - int ridx = d_ridx[idx / feature_stride]; - int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + - idx % feature_stride]; - if (gidx != matrix.NumBins()) { - // If we are not using shared memory, accumulate the values directly into - // global memory - gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx; - if (use_shared_memory_histograms) { +template +class HistogramAgent { + using SharedSumT = typename HistRounding::SharedSumT; + SharedSumT* smem_arr; + GradientSumT* d_node_hist; + dh::LDGIterator d_ridx; + const GradientPair* d_gpair; + const FeatureGroup group; + const EllpackDeviceAccessor& matrix; + const int feature_stride; + const std::size_t n_elements; + const HistRounding& rounding; + + public: + __device__ HistogramAgent(SharedSumT* smem_arr, GradientSumT* __restrict__ d_node_hist, + const FeatureGroup& group, const EllpackDeviceAccessor& matrix, + common::Span d_ridx, + const HistRounding& rounding, const GradientPair* d_gpair) + : smem_arr(smem_arr), + d_node_hist(d_node_hist), + d_ridx(d_ridx.data()), + group(group), + matrix(matrix), + feature_stride(matrix.is_dense ? group.num_features : matrix.row_stride), + n_elements(feature_stride * d_ridx.size()), + rounding(rounding), + d_gpair(d_gpair) {} + __device__ void ProcessPartialTileShared(std::size_t offset) { + for (std::size_t idx = offset + threadIdx.x; + idx < min(offset + kBlockThreads * kItemsPerTile, n_elements); idx += kBlockThreads) { + int ridx = d_ridx[idx / feature_stride]; + int gidx = + matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride] - + group.start_bin; + if (matrix.is_dense || gidx != matrix.NumBins()) { auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]); dh::AtomicAddGpair(smem_arr + gidx, adjusted); - } else { - GradientSumT truncated{ - TruncateWithRoundingFactor(rounding.rounding.GetGrad(), - d_gpair[ridx].GetGrad()), - TruncateWithRoundingFactor(rounding.rounding.GetHess(), - d_gpair[ridx].GetHess()), - }; - dh::AtomicAddGpair(d_node_hist + gidx, truncated); } } } + // Instruction level parallelism by loop unrolling + // Allows the kernel to pipeline many operations while waiting for global memory + // Increases the throughput of this kernel significantly + __device__ void ProcessFullTileShared(std::size_t offset) { + std::size_t idx[kItemsPerThread]; + int ridx[kItemsPerThread]; + int gidx[kItemsPerThread]; + GradientPair gpair[kItemsPerThread]; +#pragma unroll + for (int i = 0; i < kItemsPerThread; i++) { + idx[i] = offset + i * kBlockThreads + threadIdx.x; + } +#pragma unroll + for (int i = 0; i < kItemsPerThread; i++) { + ridx[i] = d_ridx[idx[i] / feature_stride]; + } +#pragma unroll + for (int i = 0; i < kItemsPerThread; i++) { + gpair[i] = d_gpair[ridx[i]]; + gidx[i] = matrix.gidx_iter[ridx[i] * matrix.row_stride + group.start_feature + + idx[i] % feature_stride]; + } +#pragma unroll + for (int i = 0; i < kItemsPerThread; i++) { + if ((matrix.is_dense || gidx[i] != matrix.NumBins())) { + auto adjusted = rounding.ToFixedPoint(gpair[i]); + dh::AtomicAddGpair(smem_arr + gidx[i] - group.start_bin, adjusted); + } + } + } + __device__ void BuildHistogramWithShared() { + dh::BlockFill(smem_arr, group.num_bins, SharedSumT()); + __syncthreads(); + + std::size_t offset = blockIdx.x * kItemsPerTile; + while (offset + kItemsPerTile <= n_elements) { + 
ProcessFullTileShared(offset); + offset += kItemsPerTile * gridDim.x; + } + ProcessPartialTileShared(offset); - if (use_shared_memory_histograms) { // Write shared memory back to global memory __syncthreads(); for (auto i : dh::BlockStrideRange(0, group.num_bins)) { @@ -174,6 +206,49 @@ __global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix, dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated); } } + + __device__ void BuildHistogramWithGlobal() { + for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { + int ridx = d_ridx[idx / feature_stride]; + int gidx = + matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride]; + if (matrix.is_dense || gidx != matrix.NumBins()) { + // If we are not using shared memory, accumulate the values directly into + // global memory + GradientSumT truncated{ + TruncateWithRoundingFactor(rounding.rounding.GetGrad(), + d_gpair[ridx].GetGrad()), + TruncateWithRoundingFactor(rounding.rounding.GetHess(), + d_gpair[ridx].GetHess()), + }; + dh::AtomicAddGpair(d_node_hist + gidx, truncated); + } + } + } +}; + +template +__global__ void __launch_bounds__(kBlockThreads) + SharedMemHistKernel(const EllpackDeviceAccessor matrix, + const FeatureGroupsAccessor feature_groups, + common::Span d_ridx, + GradientSumT* __restrict__ d_node_hist, + const GradientPair* __restrict__ d_gpair, + HistRounding const rounding) { + using SharedSumT = typename HistRounding::SharedSumT; + using T = typename GradientSumT::ValueT; + + extern __shared__ char smem[]; + const FeatureGroup group = feature_groups[blockIdx.y]; + SharedSumT* smem_arr = reinterpret_cast(smem); + auto agent = HistogramAgent( + smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair); + if (use_shared_memory_histograms) { + agent.BuildHistogramWithShared(); + } else { + agent.BuildHistogramWithGlobal(); + } } template @@ -182,78 +257,71 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, common::Span gpair, common::Span d_ridx, common::Span histogram, - HistRounding rounding, - bool force_global_memory) { + HistRounding rounding, bool force_global_memory) { // decide whether to use shared memory int device = 0; dh::safe_cuda(cudaGetDevice(&device)); // opt into maximum shared memory for the kernel if necessary size_t max_shared_memory = dh::MaxSharedMemoryOptin(device); - size_t smem_size = sizeof(typename HistRounding::SharedSumT) * - feature_groups.max_group_bins; + size_t smem_size = + sizeof(typename HistRounding::SharedSumT) * feature_groups.max_group_bins; bool shared = !force_global_memory && smem_size <= max_shared_memory; smem_size = shared ? 
smem_size : 0; + constexpr int kBlockThreads = 1024; + constexpr int kItemsPerThread = 8; + constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread; + auto runit = [&](auto kernel) { if (shared) { - dh::safe_cuda(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - max_shared_memory)); + dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_memory)); } // determine the launch configuration - int min_grid_size; - int block_threads = 1024; - dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize( - &min_grid_size, &block_threads, kernel, smem_size, 0)); - int num_groups = feature_groups.NumGroups(); int n_mps = 0; - dh::safe_cuda( - cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); + dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device)); int n_blocks_per_mp = 0; - dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &n_blocks_per_mp, kernel, block_threads, smem_size)); + dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel, + kBlockThreads, smem_size)); + // This gives the number of blocks to keep the device occupied + // Use this as the maximum number of blocks unsigned grid_size = n_blocks_per_mp * n_mps; - // TODO(canonizer): This is really a hack, find a better way to distribute - // the data among thread blocks. The intention is to generate enough thread - // blocks to fill the GPU, but avoid having too many thread blocks, as this - // is less efficient when the number of rows is low. At least one thread - // block per feature group is required. The number of thread blocks: - // - for num_groups <= num_groups_threshold, around grid_size * num_groups - // - for num_groups_threshold <= num_groups <= num_groups_threshold * - // grid_size, - // around grid_size * num_groups_threshold - // - for num_groups_threshold * grid_size <= num_groups, around num_groups - int num_groups_threshold = 4; - grid_size = common::DivRoundUp( - grid_size, common::DivRoundUp(num_groups, num_groups_threshold)); - - using T = typename GradientSumT::ValueT; + // Otherwise launch blocks such that each block has a minimum amount of work to do + // There are fixed costs to launching each block, e.g. 
zeroing shared memory + // The below amount of minimum work was found by experimentation + constexpr int kMinItemsPerBlock = kItemsPerTile * 16; + int columns_per_group = common::DivRoundUp(matrix.row_stride, feature_groups.NumGroups()); + // Average number of matrix elements processed by each group + std::size_t items_per_group = d_ridx.size() * columns_per_group; + + // Allocate number of blocks such that each block has about kMinItemsPerBlock work + // Up to a maximum where the device is saturated + grid_size = + min(grid_size, + unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock))); dh::LaunchKernel {dim3(grid_size, num_groups), - static_cast(block_threads), - smem_size} (kernel, matrix, feature_groups, d_ridx, - histogram.data(), gpair.data(), rounding); + static_cast(kBlockThreads), smem_size}( + kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding); }; if (shared) { - runit(SharedMemHistKernel); + runit(SharedMemHistKernel); } else { - runit(SharedMemHistKernel); + runit(SharedMemHistKernel); } + dh::safe_cuda(cudaDeviceSynchronize()); dh::safe_cuda(cudaGetLastError()); } template void BuildGradientHistogram( - EllpackDeviceAccessor const& matrix, - FeatureGroupsAccessor const& feature_groups, - common::Span gpair, - common::Span ridx, - common::Span histogram, - HistRounding rounding, + EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, + common::Span gpair, common::Span ridx, + common::Span histogram, HistRounding rounding, bool force_global_memory); } // namespace tree From 8abd10358ffa3f30a0edcaa63fd6609458f02f2f Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 21 Jul 2022 06:36:26 -0700 Subject: [PATCH 02/13] Tune for higgs. --- src/tree/gpu_hist/histogram.cu | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index ff6089ec5935..8e4adf22ac32 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -270,7 +270,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, smem_size = shared ? smem_size : 0; constexpr int kBlockThreads = 1024; - constexpr int kItemsPerThread = 8; + constexpr int kItemsPerThread = 4; constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread; auto runit = [&](auto kernel) { @@ -293,7 +293,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, // Otherwise launch blocks such that each block has a minimum amount of work to do // There are fixed costs to launching each block, e.g. zeroing shared memory // The below amount of minimum work was found by experimentation - constexpr int kMinItemsPerBlock = kItemsPerTile * 16; + constexpr int kMinItemsPerBlock = kItemsPerTile; int columns_per_group = common::DivRoundUp(matrix.row_stride, feature_groups.NumGroups()); // Average number of matrix elements processed by each group std::size_t items_per_group = d_ridx.size() * columns_per_group; @@ -303,6 +303,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, grid_size = min(grid_size, unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock))); + dh::LaunchKernel {dim3(grid_size, num_groups), static_cast(kBlockThreads), smem_size}( kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding); From 2cc8b16819896e32df9197510f1c734bd46d9222 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Fri, 22 Jul 2022 05:41:19 -0700 Subject: [PATCH 03/13] Use 8 items per thread. 
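Note (commit body, not part of the diff): kItemsPerThread controls how many elements each thread loads per tile in the unrolled shared-memory path introduced in patch 01; larger tiles expose more independent global-memory loads per thread, which improves latency hiding at the cost of extra registers. A minimal sketch of how the constants combine, reusing the loop and names from BuildHistogramWithShared() in patch 01:

    constexpr int kBlockThreads   = 1024;
    constexpr int kItemsPerThread = 8;                               // this patch: 4 -> 8
    constexpr int kItemsPerTile   = kBlockThreads * kItemsPerThread; // 8192 elements per block tile

    // Each block strides over tiles; full tiles take the unrolled path,
    // the remainder falls back to the element-wise path.
    std::size_t offset = blockIdx.x * kItemsPerTile;
    while (offset + kItemsPerTile <= n_elements) {
      ProcessFullTileShared(offset);   // #pragma unroll over kItemsPerThread loads
      offset += kItemsPerTile * gridDim.x;
    }
    ProcessPartialTileShared(offset);
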
--- src/tree/gpu_hist/histogram.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 8e4adf22ac32..255c4588c8c3 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -270,7 +270,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, smem_size = shared ? smem_size : 0; constexpr int kBlockThreads = 1024; - constexpr int kItemsPerThread = 4; + constexpr int kItemsPerThread = 8; constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread; auto runit = [&](auto kernel) { From 4058e788649dd054065f492be3f5810b1dbfac32 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Sat, 16 Jul 2022 11:54:47 -0700 Subject: [PATCH 04/13] Use aligned loads in compressed iterator. --- src/common/compressed_iterator.h | 204 ++++-------------- src/data/ellpack_page.cu | 110 +++++----- src/data/ellpack_page.cuh | 9 +- tests/cpp/common/test_compressed_iterator.cc | 26 +-- .../common/test_gpu_compressed_iterator.cu | 47 ++-- tests/cpp/data/test_ellpack_page.cu | 6 +- tests/cpp/data/test_iterative_dmatrix.cu | 8 +- .../gpu_hist/test_gradient_based_sampler.cu | 4 +- 8 files changed, 122 insertions(+), 292 deletions(-) diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index 9f60722fb982..1aff259cb145 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -20,19 +20,6 @@ namespace common { using CompressedByteT = unsigned char; namespace detail { -inline void SetBit(CompressedByteT *byte, int bit_idx) { - *byte |= 1 << bit_idx; -} -template -inline T CheckBit(const T &byte, int bit_idx) { - return byte & (1 << bit_idx); -} -inline void ClearBit(CompressedByteT *byte, int bit_idx) { - *byte &= ~(1 << bit_idx); -} -static const int kPadding = 4; // Assign padding so we can read slightly off - // the beginning of the array - // The number of bits required to represent a given unsigned range inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) { auto bits = std::ceil(log2(static_cast(num_symbols))); @@ -40,183 +27,70 @@ inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) { } } // namespace detail -/** - * \class CompressedBufferWriter - * - * \brief Writes bit compressed symbols to a memory buffer. Use - * CompressedIterator to read symbols back from buffer. Currently limited to a - * maximum symbol size of 28 bits. - * - * \author Rory - * \date 7/9/2017 - */ - -class CompressedBufferWriter { - size_t symbol_bits_; - - public: - XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols) - : symbol_bits_(detail::SymbolBits(num_symbols)) {} - - /** - * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int - * num_elements, int num_symbols) - * - * \brief Calculates number of bytes required for a given number of elements - * and a symbol range. - * - * \author Rory - * \date 7/9/2017 - * - * \param num_elements Number of elements. - * \param num_symbols Max number of symbols (alphabet size) - * - * \return The calculated buffer size. - */ - static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) { - constexpr int kBitsPerByte = 8; - size_t compressed_size = static_cast(std::ceil( - static_cast(detail::SymbolBits(num_symbols) * num_elements) / - kBitsPerByte)); - // Handle atomicOr where input must be unsigned int, hence 4 bytes aligned. 
- size_t ret = - std::ceil(static_cast(compressed_size + detail::kPadding) / - static_cast(sizeof(unsigned int))) * - sizeof(unsigned int); - return ret; +inline XGBOOST_DEVICE int SmallestWordSize(size_t num_symbols) { + int word_size = 32; + int bits = detail::SymbolBits(num_symbols); + if (bits <= 16) { + word_size = 16; } - - template - void WriteSymbol(CompressedByteT *buffer, T symbol, size_t offset) { - const int bits_per_byte = 8; - - for (size_t i = 0; i < symbol_bits_; i++) { - size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte; - byte_idx += detail::kPadding; - size_t bit_idx = - ((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte; - - if (detail::CheckBit(symbol, i)) { - detail::SetBit(&buffer[byte_idx], bit_idx); - } else { - detail::ClearBit(&buffer[byte_idx], bit_idx); - } - } + if (bits <= 8) { + word_size = 8; } + return word_size; +} -#ifdef __CUDACC__ - __device__ void AtomicWriteSymbol - (CompressedByteT* buffer, uint64_t symbol, size_t offset) { - size_t ibit_start = offset * symbol_bits_; - size_t ibit_end = (offset + 1) * symbol_bits_ - 1; - size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8; +class CompressedWriter { + CompressedByteT *buffer_ {nullptr}; + int symbol_bits_ {0}; - symbol <<= 7 - ibit_end % 8; - for (ptrdiff_t ibyte = ibyte_end; ibyte >= static_cast(ibyte_start); --ibyte) { - dh::AtomicOrByte(reinterpret_cast(buffer + detail::kPadding), - ibyte, symbol & 0xff); - symbol >>= 8; - } + public: + CompressedWriter (CompressedByteT *buffer, size_t num_symbols) + : buffer_(buffer) { + symbol_bits_ = SmallestWordSize(num_symbols); } -#endif // __CUDACC__ - template - void Write(CompressedByteT *buffer, IterT input_begin, IterT input_end) { - uint64_t tmp = 0; - size_t stored_bits = 0; - const size_t max_stored_bits = 64 - symbol_bits_; - size_t buffer_position = detail::kPadding; - const size_t num_symbols = input_end - input_begin; - for (size_t i = 0; i < num_symbols; i++) { - typename std::iterator_traits::value_type symbol = input_begin[i]; - if (stored_bits > max_stored_bits) { - // Eject only full bytes - size_t tmp_bytes = stored_bits / 8; - for (size_t j = 0; j < tmp_bytes; j++) { - buffer[buffer_position] = static_cast( - tmp >> (stored_bits - (j + 1) * 8)); - buffer_position++; - } - stored_bits -= tmp_bytes * 8; - tmp &= (1 << stored_bits) - 1; - } - // Store symbol - tmp <<= symbol_bits_; - tmp |= symbol; - stored_bits += symbol_bits_; - } + static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) { + return (SmallestWordSize(num_symbols)/8)*num_elements; + } - // Eject all bytes - int tmp_bytes = - static_cast(std::ceil(static_cast(stored_bits) / 8)); - for (int j = 0; j < tmp_bytes; j++) { - int shift_bits = static_cast(stored_bits) - (j + 1) * 8; - if (shift_bits >= 0) { - buffer[buffer_position] = - static_cast(tmp >> shift_bits); - } else { - buffer[buffer_position] = - static_cast(tmp << std::abs(shift_bits)); - } - buffer_position++; + XGBOOST_DEVICE void Write(size_t idx, uint32_t x) { + if (symbol_bits_ == 8) { + buffer_[idx] = x; + } else if (symbol_bits_ == 16) { + reinterpret_cast(buffer_)[idx] = x; + } else if (symbol_bits_ == 32) { + reinterpret_cast(buffer_)[idx] = x; } } }; -/** - * \brief Read symbols from a bit compressed memory buffer. Usable on device and host. - * - * \author Rory - * \date 7/9/2017 - * - * \tparam T Generic type parameter. 
- */ -template class CompressedIterator { public: // Type definitions for thrust - typedef CompressedIterator self_type; // NOLINT + typedef CompressedIterator self_type; // NOLINT typedef ptrdiff_t difference_type; // NOLINT - typedef T value_type; // NOLINT + typedef uint32_t value_type; // NOLINT typedef value_type *pointer; // NOLINT - typedef value_type reference; // NOLINT + typedef value_type& reference; // NOLINT private: const CompressedByteT *buffer_ {nullptr}; - size_t symbol_bits_ {0}; - size_t offset_ {0}; + int symbol_bits_ {0}; public: CompressedIterator() = default; CompressedIterator(const CompressedByteT *buffer, size_t num_symbols) : buffer_(buffer) { - symbol_bits_ = detail::SymbolBits(num_symbols); + symbol_bits_ = SmallestWordSize(num_symbols); } - - XGBOOST_DEVICE reference operator*() const { - const int bits_per_byte = 8; - size_t start_bit_idx = ((offset_ + 1) * symbol_bits_ - 1); - size_t start_byte_idx = start_bit_idx / bits_per_byte; - start_byte_idx += detail::kPadding; - - // Read 5 bytes - the maximum we will need - uint64_t tmp = static_cast(buffer_[start_byte_idx - 4]) << 32 | - static_cast(buffer_[start_byte_idx - 3]) << 24 | - static_cast(buffer_[start_byte_idx - 2]) << 16 | - static_cast(buffer_[start_byte_idx - 1]) << 8 | - buffer_[start_byte_idx]; - int bit_shift = - (bits_per_byte - ((offset_ + 1) * symbol_bits_)) % bits_per_byte; - tmp >>= bit_shift; - // Mask off unneeded bits - uint64_t mask = (static_cast(1) << symbol_bits_) - 1; - return static_cast(tmp & mask); - } - - XGBOOST_DEVICE reference operator[](size_t idx) const { - self_type offset = (*this); - offset.offset_ += idx; - return *offset; + XGBOOST_DEVICE value_type operator[](size_t idx) const { + if (symbol_bits_ == 8) { + return buffer_[idx]; + } else if (symbol_bits_ == 16) { + return reinterpret_cast(buffer_)[idx]; + } else { + return reinterpret_cast(buffer_)[idx]; + } } }; } // namespace common diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 14a1b2bbf172..df5fe1bd66b4 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -27,8 +27,7 @@ void EllpackPage::SetBaseRowId(size_t row_id) { impl_->SetBaseRowId(row_id); } // Bin each input data entry, store the bin indices in compressed form. __global__ void CompressBinEllpackKernel( - common::CompressedBufferWriter wr, - common::CompressedByteT* __restrict__ buffer, // gidx_buffer + common::CompressedWriter writer, const size_t* __restrict__ row_ptrs, // row offset of input data const Entry* __restrict__ entries, // One batch of input data const float* __restrict__ cuts, // HistogramCuts::cut_values_ @@ -72,7 +71,7 @@ __global__ void CompressBinEllpackKernel( bin += cut_rows[feature]; } // Write to gidx buffer. - wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); + writer.Write((irow + base_row) * row_stride + ifeature, bin); } // Construct an ELLPACK matrix with the given number of empty rows. 
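Aside (not part of this diff): the new CompressedWriter/CompressedIterator pair stores each symbol in a whole 8-, 16- or 32-bit word chosen by SmallestWordSize(), so writes become plain aligned stores and no longer need AtomicWriteSymbol. A minimal host-side round trip, written against the definitions added above and mirroring the updated test_compressed_iterator.cc later in this patch (include path abbreviated):

    #include <cassert>
    #include <cstddef>
    #include <vector>
    #include "compressed_iterator.h"  // xgboost: src/common/compressed_iterator.h

    int main() {
      using xgboost::common::CompressedByteT;
      using xgboost::common::CompressedIterator;
      using xgboost::common::CompressedWriter;

      std::size_t n = 1000, num_symbols = 256;  // 256 symbols -> 8-bit storage words
      std::vector<CompressedByteT> buffer(
          CompressedWriter::CalculateBufferSize(n, num_symbols));
      CompressedWriter writer(buffer.data(), num_symbols);
      CompressedIterator iter(buffer.data(), num_symbols);
      for (std::size_t i = 0; i < n; ++i) writer.Write(i, i % num_symbols);
      for (std::size_t i = 0; i < n; ++i) assert(iter[i] == i % num_symbols);
      return 0;
    }
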
@@ -131,21 +130,17 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) template struct WriteCompressedEllpackFunctor { - WriteCompressedEllpackFunctor(common::CompressedByteT* buffer, - const common::CompressedBufferWriter& writer, + WriteCompressedEllpackFunctor(common::CompressedWriter writer, AdapterBatchT batch, EllpackDeviceAccessor accessor, common::Span feature_types, const data::IsValidFunctor& is_valid) - : d_buffer(buffer), - writer(writer), + : writer(writer), batch(std::move(batch)), accessor(std::move(accessor)), feature_types(std::move(feature_types)), is_valid(is_valid) {} - - common::CompressedByteT* d_buffer; - common::CompressedBufferWriter writer; + common::CompressedWriter writer; AdapterBatchT batch; EllpackDeviceAccessor accessor; common::Span feature_types; @@ -164,7 +159,7 @@ struct WriteCompressedEllpackFunctor { } else { bin_idx = accessor.SearchBin(e.value, e.column_idx); } - writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position); + writer.Write(output_position, bin_idx); } return 0; } @@ -218,12 +213,11 @@ void CopyDataToEllpack(const AdapterBatchT &batch, using Tuple = thrust::tuple; auto device_accessor = dst->GetDeviceAccessor(device_idx); - common::CompressedBufferWriter writer(device_accessor.NumSymbols()); - auto d_compressed_buffer = dst->gidx_buffer.DevicePointer(); + common::CompressedWriter writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()); // We redirect the scan output into this functor to do the actual writing - WriteCompressedEllpackFunctor functor( - d_compressed_buffer, writer, batch, device_accessor, feature_types, + WriteCompressedEllpackFunctor functor(writer, + batch, device_accessor, feature_types, is_valid); dh::TypedDiscard discard; thrust::transform_output_iterator< @@ -247,18 +241,15 @@ void CopyDataToEllpack(const AdapterBatchT &batch, void WriteNullValues(EllpackPageImpl* dst, int device_idx, common::Span row_counts) { // Write the null values + common::CompressedWriter writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()); auto device_accessor = dst->GetDeviceAccessor(device_idx); - common::CompressedBufferWriter writer(device_accessor.NumSymbols()); - auto d_compressed_buffer = dst->gidx_buffer.DevicePointer(); auto row_stride = dst->row_stride; - dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) { + dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) mutable { // For some reason this variable got captured as const - auto writer_non_const = writer; size_t row_idx = idx / row_stride; size_t row_offset = idx % row_stride; if (row_offset >= row_counts[row_idx]) { - writer_non_const.AtomicWriteSymbol(d_compressed_buffer, - device_accessor.NullValue(), idx); + writer.Write(idx,device_accessor.NullValue()); } }); } @@ -284,24 +275,6 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device, ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch) ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch) -// A functor that copies the data from one EllpackPage to another. -struct CopyPage { - common::CompressedBufferWriter cbw; - common::CompressedByteT* dst_data_d; - common::CompressedIterator src_iterator_d; - // The number of elements to skip. 
- size_t offset; - - CopyPage(EllpackPageImpl *dst, EllpackPageImpl const *src, size_t offset) - : cbw{dst->NumSymbols()}, dst_data_d{dst->gidx_buffer.DevicePointer()}, - src_iterator_d{src->gidx_buffer.DevicePointer(), src->NumSymbols()}, - offset(offset) {} - - __device__ void operator()(size_t element_id) { - cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], - element_id + offset); - } -}; // Copy the data from the given EllpackPage to the current page. size_t EllpackPageImpl::Copy(int device, EllpackPageImpl const *page, @@ -315,18 +288,19 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl const *page, LOG(FATAL) << "Concatenating the same Ellpack."; return this->n_rows * this->row_stride; } - gidx_buffer.SetDevice(device); - page->gidx_buffer.SetDevice(device); - dh::LaunchN(num_elements, CopyPage(this, page, offset)); + auto src = page->GetDeviceAccessor(device).gidx_iter; + common::CompressedWriter writer(this->gidx_buffer.DevicePointer(), this->NumSymbols()); + dh::LaunchN(num_elements, [=]__device__ (std::size_t idx) mutable { + writer.Write(offset + idx, src[idx]); + }); monitor_.Stop("Copy"); return num_elements; } // A functor that compacts the rows from one EllpackPage into another. struct CompactPage { - common::CompressedBufferWriter cbw; - common::CompressedByteT* dst_data_d; - common::CompressedIterator src_iterator_d; + common::CompressedWriter writer; + common::CompressedIterator src_iterator_d; /*! \brief An array that maps the rows from the full DMatrix to the compacted * page. * @@ -343,10 +317,10 @@ struct CompactPage { size_t row_stride; CompactPage(EllpackPageImpl* dst, EllpackPageImpl const* src, - common::Span row_indexes) - : cbw{dst->NumSymbols()}, - dst_data_d{dst->gidx_buffer.DevicePointer()}, - src_iterator_d{src->gidx_buffer.DevicePointer(), src->NumSymbols()}, + common::Span row_indexes,int device) + : + writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()), + src_iterator_d{src->GetDeviceAccessor(device).gidx_iter}, row_indexes(row_indexes), base_rowid{src->base_rowid}, row_stride{src->row_stride} {} @@ -358,8 +332,7 @@ struct CompactPage { size_t dst_offset = dst_row * row_stride; size_t src_offset = row_id * row_stride; for (size_t j = 0; j < row_stride; j++) { - cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset + j], - dst_offset + j); + writer.Write(dst_offset + j, src_iterator_d[src_offset + j]); } } }; @@ -373,7 +346,7 @@ void EllpackPageImpl::Compact(int device, EllpackPageImpl const* page, CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size()); gidx_buffer.SetDevice(device); page->gidx_buffer.SetDevice(device); - dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes)); + dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes, device)); monitor_.Stop("Compact"); } @@ -383,7 +356,7 @@ void EllpackPageImpl::InitCompressedData(int device) { // Required buffer size for storing data matrix in ELLPack format. 
size_t compressed_size_bytes = - common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows, + common::CompressedWriter::CalculateBufferSize(row_stride * n_rows, num_symbols); gidx_buffer.SetDevice(device); // Don't call fill unnecessarily @@ -445,13 +418,12 @@ void EllpackPageImpl::CreateHistIndices(int device, const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), common::DivRoundUp(row_stride, block3.y), 1); auto device_accessor = GetDeviceAccessor(device); - dh::LaunchKernel {grid3, block3}( - CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), - gidx_buffer.DevicePointer(), row_ptrs.data().get(), - entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), - device_accessor.feature_segments.data(), feature_types, - batch_row_begin, batch_nrows, row_stride, - null_gidx_value); + dh::LaunchKernel{grid3, block3}( + CompressBinEllpackKernel, + common::CompressedWriter(gidx_buffer.DevicePointer(), NumSymbols()), + row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), + device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows, + row_stride, null_gidx_value); } } @@ -463,12 +435,26 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts& cuts) { // Required buffer size for storing data matrix in EtoLLPack format. size_t compressed_size_bytes = - common::CompressedBufferWriter::CalculateBufferSize(row_stride * num_rows, + common::CompressedWriter::CalculateBufferSize(row_stride * num_rows, cuts.TotalBins() + 1); return compressed_size_bytes; } EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( + int device, common::Span feature_types) { + gidx_buffer.SetDevice(device); + return {device, + cuts_, + is_dense, + row_stride, + base_rowid, + n_rows, + common::CompressedIterator(gidx_buffer.DevicePointer(), + NumSymbols()), + feature_types}; + } + +const EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( int device, common::Span feature_types) const { gidx_buffer.SetDevice(device); return {device, @@ -477,7 +463,7 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( row_stride, base_rowid, n_rows, - common::CompressedIterator(gidx_buffer.ConstDevicePointer(), + common::CompressedIterator(gidx_buffer.ConstDevicePointer(), NumSymbols()), feature_types}; } diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 7a2020c8b0b4..6c1dfa3c0517 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -24,7 +24,7 @@ struct EllpackDeviceAccessor { size_t row_stride; size_t base_rowid{}; size_t n_rows{}; - common::CompressedIterator gidx_iter; + common::CompressedIterator gidx_iter; /*! \brief Minimum value for each feature. Size equals to number of features. */ common::Span min_fvalue; /*! \brief Histogram cut pointers. Size equals to (number of features + 1). */ @@ -36,7 +36,7 @@ struct EllpackDeviceAccessor { EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts, bool is_dense, size_t row_stride, size_t base_rowid, - size_t n_rows,common::CompressedIterator gidx_iter, + size_t n_rows,common::CompressedIterator gidx_iter, common::Span feature_types) : is_dense(is_dense), row_stride(row_stride), @@ -194,7 +194,10 @@ class EllpackPageImpl { EllpackDeviceAccessor GetDeviceAccessor(int device, - common::Span feature_types = {}) const; + common::Span feature_types = {}) ; + const EllpackDeviceAccessor + GetDeviceAccessor(int device, + common::Span feature_types = {}) const ; private: /*! 
diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc index 93243c0b336e..8c7eb3d74130 100644 --- a/tests/cpp/common/test_compressed_iterator.cc +++ b/tests/cpp/common/test_compressed_iterator.cc @@ -4,6 +4,7 @@ namespace xgboost { namespace common { + TEST(CompressedIterator, Test) { ASSERT_TRUE(detail::SymbolBits(256) == 8); ASSERT_TRUE(detail::SymbolBits(150) == 8); @@ -17,36 +18,25 @@ TEST(CompressedIterator, Test) { std::vector input(num_elements); std::generate(input.begin(), input.end(), [=]() { return rand() % alphabet_size; }); - CompressedBufferWriter cbw(alphabet_size); // Test write entire array std::vector buffer( - CompressedBufferWriter::CalculateBufferSize(input.size(), + CompressedWriter::CalculateBufferSize(input.size(), alphabet_size)); + CompressedWriter writer(buffer.data(), alphabet_size); + CompressedIterator iter(buffer.data(), alphabet_size); - cbw.Write(buffer.data(), input.begin(), input.end()); + for (size_t i = 0; i < input.size(); i++) { + writer.Write(i,input[i]); + } - CompressedIterator ci(buffer.data(), alphabet_size); std::vector output(input.size()); for (size_t i = 0; i < input.size(); i++) { - output[i] = ci[i]; + output[i] = iter[i]; } ASSERT_TRUE(input == output); - // Test write Symbol - std::vector buffer2( - CompressedBufferWriter::CalculateBufferSize(input.size(), - alphabet_size)); - for (size_t i = 0; i < input.size(); i++) { - cbw.WriteSymbol(buffer2.data(), input[i], i); - } - CompressedIterator ci2(buffer.data(), alphabet_size); - std::vector output2(input.size()); - for (size_t i = 0; i < input.size(); i++) { - output2[i] = ci2[i]; - } - ASSERT_TRUE(input == output2); } } } diff --git a/tests/cpp/common/test_gpu_compressed_iterator.cu b/tests/cpp/common/test_gpu_compressed_iterator.cu index 779202a62002..4716b7788e90 100644 --- a/tests/cpp/common/test_gpu_compressed_iterator.cu +++ b/tests/cpp/common/test_gpu_compressed_iterator.cu @@ -7,31 +7,7 @@ namespace xgboost { namespace common { -struct WriteSymbolFunction { - CompressedBufferWriter cbw; - unsigned char* buffer_data_d; - int* input_data_d; - WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d, - int* input_data_d) - : cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {} - - __device__ void operator()(size_t i) { - cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i); - } -}; - -struct ReadSymbolFunction { - CompressedIterator ci; - int* output_data_d; - ReadSymbolFunction(CompressedIterator ci, int* output_data_d) - : ci(ci), output_data_d(output_data_d) {} - - __device__ void operator()(size_t i) { - output_data_d[i] = ci[i]; - } -}; - -TEST(CompressedIterator, TestGPU) { +void TestCompressedIterator(){ dh::safe_cuda(cudaSetDevice(0)); std::vector test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX}; int num_elements = 1000; @@ -41,26 +17,23 @@ TEST(CompressedIterator, TestGPU) { for (auto alphabet_size : test_cases) { for (int i = 0; i < repetitions; i++) { std::vector input(num_elements); - std::generate(input.begin(), input.end(), - [=]() { return rand() % alphabet_size; }); - CompressedBufferWriter cbw(alphabet_size); + std::generate(input.begin(), input.end(), [=]() { return rand() % alphabet_size; }); thrust::device_vector input_d(input); thrust::device_vector buffer_d( - CompressedBufferWriter::CalculateBufferSize(input.size(), - alphabet_size)); + CompressedWriter::CalculateBufferSize(input.size(), alphabet_size)); + CompressedWriter writer(buffer_d.data().get(), 
alphabet_size); // write the data on device auto input_data_d = input_d.data().get(); - auto buffer_data_d = buffer_d.data().get(); dh::LaunchN(input_d.size(), - WriteSymbolFunction(cbw, buffer_data_d, input_data_d)); + [=] __device__(std::size_t idx) mutable { writer.Write(idx, input_data_d[idx]); }); - // read the data on device - CompressedIterator ci(buffer_d.data().get(), alphabet_size); + CompressedIterator iter(buffer_d.data().get(), alphabet_size); thrust::device_vector output_d(input.size()); auto output_data_d = output_d.data().get(); - dh::LaunchN(output_d.size(), ReadSymbolFunction(ci, output_data_d)); + dh::LaunchN(output_d.size(), + [=] __device__(std::size_t idx) { output_data_d[idx] = iter[idx]; }); std::vector output(output_d.size()); thrust::copy(output_d.begin(), output_d.end(), output.begin()); @@ -70,5 +43,9 @@ TEST(CompressedIterator, TestGPU) { } } +TEST(CompressedIterator, TestGPU) { +TestCompressedIterator(); +} + } // namespace common } // namespace xgboost diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index a67ab1d59f02..26c7dfc9f925 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -31,7 +31,7 @@ TEST(EllpackPage, BuildGidxDense) { auto page = BuildEllpackPage(kNRows, kNCols); std::vector h_gidx_buffer(page->gidx_buffer.HostVector()); - common::CompressedIterator gidx(h_gidx_buffer.data(), page->NumSymbols()); + common::CompressedIterator gidx(h_gidx_buffer.data(), page->NumSymbols()); ASSERT_EQ(page->row_stride, kNCols); @@ -63,7 +63,7 @@ TEST(EllpackPage, BuildGidxSparse) { auto page = BuildEllpackPage(kNRows, kNCols, 0.9f); std::vector h_gidx_buffer(page->gidx_buffer.HostVector()); - common::CompressedIterator gidx(h_gidx_buffer.data(), 25); + common::CompressedIterator gidx(h_gidx_buffer.data(), 25); ASSERT_LE(page->row_stride, 3); @@ -107,7 +107,7 @@ TEST(EllpackPage, FromCategoricalBasic) { std::vector const &h_gidx_buffer = ellpack.Impl()->gidx_buffer.HostVector(); - auto h_gidx_iter = common::CompressedIterator( + auto h_gidx_iter = common::CompressedIterator( h_gidx_buffer.data(), accessor.NumSymbols()); for (size_t i = 0; i < x.size(); ++i) { diff --git a/tests/cpp/data/test_iterative_dmatrix.cu b/tests/cpp/data/test_iterative_dmatrix.cu index 0a83f7e8c54b..8124cb456984 100644 --- a/tests/cpp/data/test_iterative_dmatrix.cu +++ b/tests/cpp/data/test_iterative_dmatrix.cu @@ -69,9 +69,9 @@ void TestEquivalent(float sparsity) { auto const& buffer_from_data = ellpack.Impl()->gidx_buffer; ASSERT_NE(buffer_from_data.Size(), 0); - common::CompressedIterator data_buf{ + common::CompressedIterator data_buf{ buffer_from_data.ConstHostPointer(), from_data.NumSymbols()}; - common::CompressedIterator data_iter{ + common::CompressedIterator data_iter{ buffer_from_iter.ConstHostPointer(), from_iter.NumSymbols()}; CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols()); CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride); @@ -96,7 +96,7 @@ TEST(IterativeDeviceDMatrix, RowMajor) { for (auto& ellpack : m.GetBatches({})) { n_batches ++; auto impl = ellpack.Impl(); - common::CompressedIterator iterator( + common::CompressedIterator iterator( impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); auto cols = CudaArrayIterForTest::Cols(); auto rows = CudaArrayIterForTest::Rows(); @@ -144,7 +144,7 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) { 0, 256); auto &ellpack = *m.GetBatches({0, 256}).begin(); auto impl = ellpack.Impl(); - 
common::CompressedIterator iterator( + common::CompressedIterator iterator( impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue()); EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue()); diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index 9e8cd19bec74..a21ee93497ab 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -99,14 +99,14 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) { EXPECT_EQ(sampled_page->n_rows, kRows); std::vector buffer(sampled_page->gidx_buffer.HostVector()); - common::CompressedIterator + common::CompressedIterator ci(buffer.data(), sampled_page->NumSymbols()); size_t offset = 0; for (auto& batch : dmat->GetBatches(param)) { auto page = batch.Impl(); std::vector page_buffer(page->gidx_buffer.HostVector()); - common::CompressedIterator + common::CompressedIterator page_ci(page_buffer.data(), page->NumSymbols()); size_t num_elements = page->n_rows * page->row_stride; for (size_t i = 0; i < num_elements; i++) { From c93641a476efa8ed28103bf7b4d4dc8bd55ec16d Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Mon, 25 Jul 2022 06:27:41 -0700 Subject: [PATCH 05/13] Benchmarks and 32 bit shmem addition. --- src/tree/gpu_hist/histogram.cu | 12 +-- src/tree/gpu_hist/histogram.cuh | 27 +++++-- src/tree/updater_gpu_hist.cu | 2 +- tests/cpp/tree/gpu_hist/test_histogram.cu | 95 +++++++++++++++++++++++ 4 files changed, 123 insertions(+), 13 deletions(-) diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 255c4588c8c3..514926878167 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -90,7 +90,7 @@ HistRounding CreateRoundingFactor(common::Span CreateRoundingFactor(std::max(positive_sum.GetHess(), negative_sum.GetHess()), gpair.size())}; - using IntT = typename HistRounding::SharedSumT::ValueT; + using IntT = typename HistRounding::GlobalSumT::ValueT; /** * Factor for converting gradients from fixed-point to floating-point. 
@@ -150,11 +150,11 @@ class HistogramAgent { idx < min(offset + kBlockThreads * kItemsPerTile, n_elements); idx += kBlockThreads) { int ridx = d_ridx[idx / feature_stride]; int gidx = - matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride] - - group.start_bin; + matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride]; if (matrix.is_dense || gidx != matrix.NumBins()) { auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]); - dh::AtomicAddGpair(smem_arr + gidx, adjusted); + //dh::AtomicAddGpair(smem_arr + gidx, adjusted); + AtomicAddGpairWithOverflow(smem_arr + gidx - group.start_bin, adjusted, d_node_hist + gidx, rounding); } } } @@ -184,7 +184,7 @@ class HistogramAgent { for (int i = 0; i < kItemsPerThread; i++) { if ((matrix.is_dense || gidx[i] != matrix.NumBins())) { auto adjusted = rounding.ToFixedPoint(gpair[i]); - dh::AtomicAddGpair(smem_arr + gidx[i] - group.start_bin, adjusted); + AtomicAddGpairWithOverflow(smem_arr + gidx[i] - group.start_bin, adjusted, d_node_hist + gidx[i], rounding); } } } @@ -202,7 +202,7 @@ class HistogramAgent { // Write shared memory back to global memory __syncthreads(); for (auto i : dh::BlockStrideRange(0, group.num_bins)) { - auto truncated = rounding.ToFloatingPoint(smem_arr[i]); + auto truncated = rounding.ToFloatingPoint({smem_arr[i].GetGrad(),smem_arr[i].GetHess()}); dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated); } } diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index a45083f76875..4c03f09607c7 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -31,17 +31,16 @@ struct HistRounding { GradientSumT to_floating_point; /* Type used in shared memory. */ - using SharedSumT = std::conditional_t< - std::is_same::value, - GradientPairInt32, GradientPairInt64>; + using SharedSumT = GradientPairInt32; + using GlobalSumT = GradientPairInt64; using T = typename GradientSumT::ValueT; - XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const { - auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()), + XGBOOST_DEV_INLINE GlobalSumT ToFixedPoint(GradientPair const& gpair) const { + auto adjusted = GlobalSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()), T(gpair.GetHess() * to_fixed_point.GetHess())); return adjusted; } - XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const { + XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(GlobalSumT const &gpair) const { auto g = gpair.GetGrad() * to_floating_point.GetGrad(); auto h = gpair.GetHess() * to_floating_point.GetHess(); GradientSumT truncated{ @@ -55,6 +54,22 @@ struct HistRounding { template HistRounding CreateRoundingFactor(common::Span gpair); +XGBOOST_DEV_INLINE void AtomicAddGpairWithOverflow( + xgboost::GradientPairInt32* dst_shared, xgboost::GradientPairInt64 const& gpair, + xgboost::GradientPairPrecise* dst_global, const HistRounding& rounding) { + auto dst_ptr = reinterpret_cast(dst_shared); + int old_grad = atomicAdd(dst_ptr, static_cast(gpair.GetGrad())); + int64_t grad_diff = (old_grad + gpair.GetGrad()) - (old_grad + static_cast(gpair.GetGrad())); + + int old_hess = atomicAdd(dst_ptr + 1, static_cast(gpair.GetHess())); + int64_t hess_diff = (old_hess + gpair.GetHess()) - (old_hess + static_cast(gpair.GetHess())); + + if (grad_diff != 0 || hess_diff != 0) { + auto truncated = rounding.ToFloatingPoint({grad_diff, hess_diff}); + dh::AtomicAddGpair(dst_global, truncated); + } +} + template void 
BuildGradientHistogram(EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 3f3137c58adc..dd35468c6e00 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -232,7 +232,7 @@ struct GPUHistMakerDevice { monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(ctx_->gpu_id), - sizeof(GradientSumT))); + sizeof(typename HistRounding::SharedSumT))); } ~GPUHistMakerDevice() { // NOLINT diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index a24bc1e1fef7..7f7457e9bc1b 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -1,5 +1,7 @@ #include #include +#include +#include #include "../../../../src/common/categorical.h" #include "../../../../src/tree/gpu_hist/histogram.cuh" @@ -156,5 +158,98 @@ TEST(Histogram, GPUHistCategorical) { TestGPUHistogramCategorical(num_categories); } } + +void RunBenchmark(std::string branch, std::string name, std::size_t kCols, std::size_t kRows, float sparsity, int max_depth){ + size_t constexpr kBins = 256; + float constexpr kLower = -1e-2, kUpper = 1e2; + auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(); + BatchParam batch_param{0, static_cast(kBins)}; + int num_bins = kBins * kCols; + dh::device_vector histogram(num_bins); + auto d_histogram = dh::ToSpan(histogram); + auto gpair = GenerateRandomGradients(kRows, kLower, kUpper); + gpair.SetDevice(0); + auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); + thrust::device_vector ridx(kRows); + for (auto const& batch : matrix->GetBatches(batch_param)) { + auto* page = batch.Impl(); + FeatureGroups feature_groups(page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(0), + sizeof(GradientPairPrecise)); + for (int depth = 1; depth <= max_depth; depth++) { + std::size_t depth_rows = kRows / (1 << (depth - 1)); + std::cout << branch << ',' << name << ',' << depth << ',' << depth_rows << ',' << kCols + << ','; + thrust::shuffle(ridx.begin(), ridx.end(), thrust::default_random_engine(depth)); + auto d_ridx = dh::ToSpan(ridx).subspan(0, depth_rows); + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + BuildGradientHistogram(page->GetDeviceAccessor(0), feature_groups.DeviceAccessor(0), + gpair.DeviceSpan(), d_ridx, d_histogram, rounding); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + + std::cout << milliseconds << ',' << ((depth_rows * kCols) / milliseconds)/1000 << '\n'; + } + } +} + +TEST(Histogram, Benchmark) { + float sparsity = 0.0f; + std::string branch("overflow"); + std::cout << "branch,dataset,depth,rows,cols,time(ms),Million elements/s\n"; + int max_depth = 8; + RunBenchmark(branch, "epsilon", 2000, 500000, sparsity, max_depth); + RunBenchmark(branch, "higgs", 32, 10000000, sparsity, max_depth); + RunBenchmark(branch, "airline", 13, 115000000, sparsity, max_depth); + RunBenchmark(branch, "year", 90, 515345, sparsity, max_depth); +} + + +void TestAtomicAddWithOverflow() { + thrust::device_vector histogram(2); + thrust::device_vector gpair = std::vector{{1.0, 1.0}, {-0.01, 0.1}, {0.02, 0.1}, {-2.0, 1.0}}; + auto d_gpair = dh::ToSpan(gpair); + auto rounding = CreateRoundingFactor(d_gpair); + auto 
d_histogram = histogram.data().get(); + dh::LaunchN(gpair.size(), [=] __device__(int idx) { + __shared__ char shared[sizeof(GradientPairInt32)]; + auto shared_histogram = reinterpret_cast(shared); + if (idx == 0) { + shared_histogram[0] = GradientPairInt32(); + } + + // Global memory version + GradientPairPrecise truncated{ + TruncateWithRoundingFactor(rounding.rounding.GetGrad(), + d_gpair[idx].GetGrad()), + TruncateWithRoundingFactor(rounding.rounding.GetHess(), + d_gpair[idx].GetHess()), + }; + dh::AtomicAddGpair(d_histogram, truncated); + + // Reduced precision shared memory version + auto adjusted = rounding.ToFixedPoint(d_gpair[idx]); + AtomicAddGpairWithOverflow(shared_histogram, adjusted, d_histogram + 1, rounding); + // First thread copies shared back to global + if (idx == 0) { + dh::AtomicAddGpair(d_histogram + 1, rounding.ToFloatingPoint(GradientPairInt64{shared_histogram[idx].GetGrad(),shared_histogram[idx].GetHess()})); + } + }); + + GradientPairPrecise global = histogram[0]; + GradientPairPrecise shared = histogram[1]; + ASSERT_EQ(global.GetGrad(), shared.GetGrad()); + ASSERT_EQ(global.GetHess(), shared.GetHess()); +} + +TEST(Histogram, AtomicAddWithOverflow) { +TestAtomicAddWithOverflow(); +} + } // namespace tree } // namespace xgboost From 7d34f7896dd06cbd1338948e0047093695e454bb Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 26 Jul 2022 07:28:00 -0700 Subject: [PATCH 06/13] Revert "Benchmarks and 32 bit shmem addition." This reverts commit c93641a476efa8ed28103bf7b4d4dc8bd55ec16d. --- src/tree/gpu_hist/histogram.cu | 12 +-- src/tree/gpu_hist/histogram.cuh | 27 ++----- src/tree/updater_gpu_hist.cu | 2 +- tests/cpp/tree/gpu_hist/test_histogram.cu | 95 ----------------------- 4 files changed, 13 insertions(+), 123 deletions(-) diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 514926878167..255c4588c8c3 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -90,7 +90,7 @@ HistRounding CreateRoundingFactor(common::Span CreateRoundingFactor(std::max(positive_sum.GetHess(), negative_sum.GetHess()), gpair.size())}; - using IntT = typename HistRounding::GlobalSumT::ValueT; + using IntT = typename HistRounding::SharedSumT::ValueT; /** * Factor for converting gradients from fixed-point to floating-point. 
@@ -150,11 +150,11 @@ class HistogramAgent { idx < min(offset + kBlockThreads * kItemsPerTile, n_elements); idx += kBlockThreads) { int ridx = d_ridx[idx / feature_stride]; int gidx = - matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride]; + matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride] - + group.start_bin; if (matrix.is_dense || gidx != matrix.NumBins()) { auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]); - //dh::AtomicAddGpair(smem_arr + gidx, adjusted); - AtomicAddGpairWithOverflow(smem_arr + gidx - group.start_bin, adjusted, d_node_hist + gidx, rounding); + dh::AtomicAddGpair(smem_arr + gidx, adjusted); } } } @@ -184,7 +184,7 @@ class HistogramAgent { for (int i = 0; i < kItemsPerThread; i++) { if ((matrix.is_dense || gidx[i] != matrix.NumBins())) { auto adjusted = rounding.ToFixedPoint(gpair[i]); - AtomicAddGpairWithOverflow(smem_arr + gidx[i] - group.start_bin, adjusted, d_node_hist + gidx[i], rounding); + dh::AtomicAddGpair(smem_arr + gidx[i] - group.start_bin, adjusted); } } } @@ -202,7 +202,7 @@ class HistogramAgent { // Write shared memory back to global memory __syncthreads(); for (auto i : dh::BlockStrideRange(0, group.num_bins)) { - auto truncated = rounding.ToFloatingPoint({smem_arr[i].GetGrad(),smem_arr[i].GetHess()}); + auto truncated = rounding.ToFloatingPoint(smem_arr[i]); dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated); } } diff --git a/src/tree/gpu_hist/histogram.cuh b/src/tree/gpu_hist/histogram.cuh index 4c03f09607c7..a45083f76875 100644 --- a/src/tree/gpu_hist/histogram.cuh +++ b/src/tree/gpu_hist/histogram.cuh @@ -31,16 +31,17 @@ struct HistRounding { GradientSumT to_floating_point; /* Type used in shared memory. */ - using SharedSumT = GradientPairInt32; - using GlobalSumT = GradientPairInt64; + using SharedSumT = std::conditional_t< + std::is_same::value, + GradientPairInt32, GradientPairInt64>; using T = typename GradientSumT::ValueT; - XGBOOST_DEV_INLINE GlobalSumT ToFixedPoint(GradientPair const& gpair) const { - auto adjusted = GlobalSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()), + XGBOOST_DEV_INLINE SharedSumT ToFixedPoint(GradientPair const& gpair) const { + auto adjusted = SharedSumT(T(gpair.GetGrad() * to_fixed_point.GetGrad()), T(gpair.GetHess() * to_fixed_point.GetHess())); return adjusted; } - XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(GlobalSumT const &gpair) const { + XGBOOST_DEV_INLINE GradientSumT ToFloatingPoint(SharedSumT const &gpair) const { auto g = gpair.GetGrad() * to_floating_point.GetGrad(); auto h = gpair.GetHess() * to_floating_point.GetHess(); GradientSumT truncated{ @@ -54,22 +55,6 @@ struct HistRounding { template HistRounding CreateRoundingFactor(common::Span gpair); -XGBOOST_DEV_INLINE void AtomicAddGpairWithOverflow( - xgboost::GradientPairInt32* dst_shared, xgboost::GradientPairInt64 const& gpair, - xgboost::GradientPairPrecise* dst_global, const HistRounding& rounding) { - auto dst_ptr = reinterpret_cast(dst_shared); - int old_grad = atomicAdd(dst_ptr, static_cast(gpair.GetGrad())); - int64_t grad_diff = (old_grad + gpair.GetGrad()) - (old_grad + static_cast(gpair.GetGrad())); - - int old_hess = atomicAdd(dst_ptr + 1, static_cast(gpair.GetHess())); - int64_t hess_diff = (old_hess + gpair.GetHess()) - (old_hess + static_cast(gpair.GetHess())); - - if (grad_diff != 0 || hess_diff != 0) { - auto truncated = rounding.ToFloatingPoint({grad_diff, hess_diff}); - dh::AtomicAddGpair(dst_global, truncated); - } -} - template void 
BuildGradientHistogram(EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups, diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index dd35468c6e00..3f3137c58adc 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -232,7 +232,7 @@ struct GPUHistMakerDevice { monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(ctx_->gpu_id)); feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(ctx_->gpu_id), - sizeof(typename HistRounding::SharedSumT))); + sizeof(GradientSumT))); } ~GPUHistMakerDevice() { // NOLINT diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu index 7f7457e9bc1b..a24bc1e1fef7 100644 --- a/tests/cpp/tree/gpu_hist/test_histogram.cu +++ b/tests/cpp/tree/gpu_hist/test_histogram.cu @@ -1,7 +1,5 @@ #include #include -#include -#include #include "../../../../src/common/categorical.h" #include "../../../../src/tree/gpu_hist/histogram.cuh" @@ -158,98 +156,5 @@ TEST(Histogram, GPUHistCategorical) { TestGPUHistogramCategorical(num_categories); } } - -void RunBenchmark(std::string branch, std::string name, std::size_t kCols, std::size_t kRows, float sparsity, int max_depth){ - size_t constexpr kBins = 256; - float constexpr kLower = -1e-2, kUpper = 1e2; - auto matrix = RandomDataGenerator(kRows, kCols, sparsity).GenerateDMatrix(); - BatchParam batch_param{0, static_cast(kBins)}; - int num_bins = kBins * kCols; - dh::device_vector histogram(num_bins); - auto d_histogram = dh::ToSpan(histogram); - auto gpair = GenerateRandomGradients(kRows, kLower, kUpper); - gpair.SetDevice(0); - auto rounding = CreateRoundingFactor(gpair.DeviceSpan()); - thrust::device_vector ridx(kRows); - for (auto const& batch : matrix->GetBatches(batch_param)) { - auto* page = batch.Impl(); - FeatureGroups feature_groups(page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(0), - sizeof(GradientPairPrecise)); - for (int depth = 1; depth <= max_depth; depth++) { - std::size_t depth_rows = kRows / (1 << (depth - 1)); - std::cout << branch << ',' << name << ',' << depth << ',' << depth_rows << ',' << kCols - << ','; - thrust::shuffle(ridx.begin(), ridx.end(), thrust::default_random_engine(depth)); - auto d_ridx = dh::ToSpan(ridx).subspan(0, depth_rows); - cudaEvent_t start, stop; - cudaEventCreate(&start); - cudaEventCreate(&stop); - - cudaEventRecord(start); - BuildGradientHistogram(page->GetDeviceAccessor(0), feature_groups.DeviceAccessor(0), - gpair.DeviceSpan(), d_ridx, d_histogram, rounding); - cudaEventRecord(stop); - cudaEventSynchronize(stop); - float milliseconds = 0; - cudaEventElapsedTime(&milliseconds, start, stop); - - std::cout << milliseconds << ',' << ((depth_rows * kCols) / milliseconds)/1000 << '\n'; - } - } -} - -TEST(Histogram, Benchmark) { - float sparsity = 0.0f; - std::string branch("overflow"); - std::cout << "branch,dataset,depth,rows,cols,time(ms),Million elements/s\n"; - int max_depth = 8; - RunBenchmark(branch, "epsilon", 2000, 500000, sparsity, max_depth); - RunBenchmark(branch, "higgs", 32, 10000000, sparsity, max_depth); - RunBenchmark(branch, "airline", 13, 115000000, sparsity, max_depth); - RunBenchmark(branch, "year", 90, 515345, sparsity, max_depth); -} - - -void TestAtomicAddWithOverflow() { - thrust::device_vector histogram(2); - thrust::device_vector gpair = std::vector{{1.0, 1.0}, {-0.01, 0.1}, {0.02, 0.1}, {-2.0, 1.0}}; - auto d_gpair = dh::ToSpan(gpair); - auto rounding = CreateRoundingFactor(d_gpair); - auto 
d_histogram = histogram.data().get(); - dh::LaunchN(gpair.size(), [=] __device__(int idx) { - __shared__ char shared[sizeof(GradientPairInt32)]; - auto shared_histogram = reinterpret_cast(shared); - if (idx == 0) { - shared_histogram[0] = GradientPairInt32(); - } - - // Global memory version - GradientPairPrecise truncated{ - TruncateWithRoundingFactor(rounding.rounding.GetGrad(), - d_gpair[idx].GetGrad()), - TruncateWithRoundingFactor(rounding.rounding.GetHess(), - d_gpair[idx].GetHess()), - }; - dh::AtomicAddGpair(d_histogram, truncated); - - // Reduced precision shared memory version - auto adjusted = rounding.ToFixedPoint(d_gpair[idx]); - AtomicAddGpairWithOverflow(shared_histogram, adjusted, d_histogram + 1, rounding); - // First thread copies shared back to global - if (idx == 0) { - dh::AtomicAddGpair(d_histogram + 1, rounding.ToFloatingPoint(GradientPairInt64{shared_histogram[idx].GetGrad(),shared_histogram[idx].GetHess()})); - } - }); - - GradientPairPrecise global = histogram[0]; - GradientPairPrecise shared = histogram[1]; - ASSERT_EQ(global.GetGrad(), shared.GetGrad()); - ASSERT_EQ(global.GetHess(), shared.GetHess()); -} - -TEST(Histogram, AtomicAddWithOverflow) { -TestAtomicAddWithOverflow(); -} - } // namespace tree } // namespace xgboost From ab4894ab3efa8b1ca48778094facc663afe7c619 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 26 Jul 2022 08:59:35 -0700 Subject: [PATCH 07/13] Lint --- src/common/compressed_iterator.h | 3 +-- src/data/ellpack_page.cu | 20 +++++++++----------- src/tree/gpu_hist/histogram.cu | 1 - 3 files changed, 10 insertions(+), 14 deletions(-) diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index 1aff259cb145..b48941687b7b 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -44,8 +44,7 @@ class CompressedWriter { int symbol_bits_ {0}; public: - CompressedWriter (CompressedByteT *buffer, size_t num_symbols) - : buffer_(buffer) { + CompressedWriter(CompressedByteT *buffer, size_t num_symbols) : buffer_(buffer) { symbol_bits_ = SmallestWordSize(num_symbols); } diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index df5fe1bd66b4..355dad5a9455 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -249,7 +249,7 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx, size_t row_idx = idx / row_stride; size_t row_offset = idx % row_stride; if (row_offset >= row_counts[row_idx]) { - writer.Write(idx,device_accessor.NullValue()); + writer.Write(idx, device_accessor.NullValue()); } }); } @@ -290,9 +290,8 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl const *page, } auto src = page->GetDeviceAccessor(device).gidx_iter; common::CompressedWriter writer(this->gidx_buffer.DevicePointer(), this->NumSymbols()); - dh::LaunchN(num_elements, [=]__device__ (std::size_t idx) mutable { - writer.Write(offset + idx, src[idx]); - }); + dh::LaunchN(num_elements, + [=] __device__(std::size_t idx) mutable { writer.Write(offset + idx, src[idx]); }); monitor_.Stop("Copy"); return num_elements; } @@ -316,10 +315,9 @@ struct CompactPage { size_t base_rowid; size_t row_stride; - CompactPage(EllpackPageImpl* dst, EllpackPageImpl const* src, - common::Span row_indexes,int device) - : - writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()), + CompactPage(EllpackPageImpl* dst, EllpackPageImpl const* src, common::Span row_indexes, + int device) + : writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()), 
src_iterator_d{src->GetDeviceAccessor(device).gidx_iter}, row_indexes(row_indexes), base_rowid{src->base_rowid}, @@ -418,10 +416,10 @@ void EllpackPageImpl::CreateHistIndices(int device, const dim3 grid3(common::DivRoundUp(batch_nrows, block3.x), common::DivRoundUp(row_stride, block3.y), 1); auto device_accessor = GetDeviceAccessor(device); - dh::LaunchKernel{grid3, block3}( + dh::LaunchKernel {grid3, block3}( CompressBinEllpackKernel, - common::CompressedWriter(gidx_buffer.DevicePointer(), NumSymbols()), - row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), + common::CompressedWriter(gidx_buffer.DevicePointer(), NumSymbols()), row_ptrs.data().get(), + entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows, row_stride, null_gidx_value); } diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index 255c4588c8c3..bd0523f42a48 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -315,7 +315,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix, runit(SharedMemHistKernel); } - dh::safe_cuda(cudaDeviceSynchronize()); dh::safe_cuda(cudaGetLastError()); } From 0d3480ed56682ba1d369c30c5e82c0c1a100ba17 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 2 Aug 2022 02:24:05 -0700 Subject: [PATCH 08/13] Clang-tidy --- src/tree/gpu_hist/histogram.cu | 94 +++++++++++++++++----------------- 1 file changed, 48 insertions(+), 46 deletions(-) diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu index bd0523f42a48..2093634051d9 100644 --- a/src/tree/gpu_hist/histogram.cu +++ b/src/tree/gpu_hist/histogram.cu @@ -121,40 +121,41 @@ template class HistogramAgent { using SharedSumT = typename HistRounding::SharedSumT; - SharedSumT* smem_arr; - GradientSumT* d_node_hist; - dh::LDGIterator d_ridx; - const GradientPair* d_gpair; - const FeatureGroup group; - const EllpackDeviceAccessor& matrix; - const int feature_stride; - const std::size_t n_elements; - const HistRounding& rounding; + SharedSumT* smem_arr_; + GradientSumT* d_node_hist_; + dh::LDGIterator d_ridx_; + const GradientPair* d_gpair_; + const FeatureGroup group_; + const EllpackDeviceAccessor& matrix_; + const int feature_stride_; + const std::size_t n_elements_; + const HistRounding& rounding_; public: __device__ HistogramAgent(SharedSumT* smem_arr, GradientSumT* __restrict__ d_node_hist, const FeatureGroup& group, const EllpackDeviceAccessor& matrix, common::Span d_ridx, const HistRounding& rounding, const GradientPair* d_gpair) - : smem_arr(smem_arr), - d_node_hist(d_node_hist), - d_ridx(d_ridx.data()), - group(group), - matrix(matrix), - feature_stride(matrix.is_dense ? group.num_features : matrix.row_stride), - n_elements(feature_stride * d_ridx.size()), - rounding(rounding), - d_gpair(d_gpair) {} + : smem_arr_(smem_arr), + d_node_hist_(d_node_hist), + d_ridx_(d_ridx.data()), + group_(group), + matrix_(matrix), + feature_stride_(matrix.is_dense ? 
group.num_features : matrix.row_stride), + n_elements_(feature_stride_ * d_ridx.size()), + rounding_(rounding), + d_gpair_(d_gpair) {} __device__ void ProcessPartialTileShared(std::size_t offset) { for (std::size_t idx = offset + threadIdx.x; - idx < min(offset + kBlockThreads * kItemsPerTile, n_elements); idx += kBlockThreads) { - int ridx = d_ridx[idx / feature_stride]; + idx < min(offset + kBlockThreads * kItemsPerTile, n_elements_); idx += kBlockThreads) { + int ridx = d_ridx_[idx / feature_stride_]; int gidx = - matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride] - - group.start_bin; - if (matrix.is_dense || gidx != matrix.NumBins()) { - auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]); - dh::AtomicAddGpair(smem_arr + gidx, adjusted); + matrix_ + .gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_] - + group_.start_bin; + if (matrix_.is_dense || gidx != matrix_.NumBins()) { + auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]); + dh::AtomicAddGpair(smem_arr_ + gidx, adjusted); } } } @@ -172,28 +173,28 @@ class HistogramAgent { } #pragma unroll for (int i = 0; i < kItemsPerThread; i++) { - ridx[i] = d_ridx[idx[i] / feature_stride]; + ridx[i] = d_ridx_[idx[i] / feature_stride_]; } #pragma unroll for (int i = 0; i < kItemsPerThread; i++) { - gpair[i] = d_gpair[ridx[i]]; - gidx[i] = matrix.gidx_iter[ridx[i] * matrix.row_stride + group.start_feature + - idx[i] % feature_stride]; + gpair[i] = d_gpair_[ridx[i]]; + gidx[i] = matrix_.gidx_iter[ridx[i] * matrix_.row_stride + group_.start_feature + + idx[i] % feature_stride_]; } #pragma unroll for (int i = 0; i < kItemsPerThread; i++) { - if ((matrix.is_dense || gidx[i] != matrix.NumBins())) { - auto adjusted = rounding.ToFixedPoint(gpair[i]); - dh::AtomicAddGpair(smem_arr + gidx[i] - group.start_bin, adjusted); + if ((matrix_.is_dense || gidx[i] != matrix_.NumBins())) { + auto adjusted = rounding_.ToFixedPoint(gpair[i]); + dh::AtomicAddGpair(smem_arr_ + gidx[i] - group_.start_bin, adjusted); } } } __device__ void BuildHistogramWithShared() { - dh::BlockFill(smem_arr, group.num_bins, SharedSumT()); + dh::BlockFill(smem_arr_, group_.num_bins, SharedSumT()); __syncthreads(); std::size_t offset = blockIdx.x * kItemsPerTile; - while (offset + kItemsPerTile <= n_elements) { + while (offset + kItemsPerTile <= n_elements_) { ProcessFullTileShared(offset); offset += kItemsPerTile * gridDim.x; } @@ -201,27 +202,28 @@ class HistogramAgent { // Write shared memory back to global memory __syncthreads(); - for (auto i : dh::BlockStrideRange(0, group.num_bins)) { - auto truncated = rounding.ToFloatingPoint(smem_arr[i]); - dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated); + for (auto i : dh::BlockStrideRange(0, group_.num_bins)) { + auto truncated = rounding_.ToFloatingPoint(smem_arr_[i]); + dh::AtomicAddGpair(d_node_hist_ + group_.start_bin + i, truncated); } } __device__ void BuildHistogramWithGlobal() { - for (auto idx : dh::GridStrideRange(static_cast(0), n_elements)) { - int ridx = d_ridx[idx / feature_stride]; + for (auto idx : dh::GridStrideRange(static_cast(0), n_elements_)) { + int ridx = d_ridx_[idx / feature_stride_]; int gidx = - matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature + idx % feature_stride]; - if (matrix.is_dense || gidx != matrix.NumBins()) { + matrix_ + .gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_]; + if (matrix_.is_dense || gidx != matrix_.NumBins()) { // If we are not using shared memory, 
accumulate the values directly into // global memory GradientSumT truncated{ - TruncateWithRoundingFactor(rounding.rounding.GetGrad(), - d_gpair[ridx].GetGrad()), - TruncateWithRoundingFactor(rounding.rounding.GetHess(), - d_gpair[ridx].GetHess()), + TruncateWithRoundingFactor(rounding_.rounding.GetGrad(), + d_gpair_[ridx].GetGrad()), + TruncateWithRoundingFactor(rounding_.rounding.GetHess(), + d_gpair_[ridx].GetHess()), }; - dh::AtomicAddGpair(d_node_hist + gidx, truncated); + dh::AtomicAddGpair(d_node_hist_ + gidx, truncated); } } } From e6220268858e3fbeba91b6b8046f2de90652de52 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Tue, 2 Aug 2022 07:56:23 -0700 Subject: [PATCH 09/13] Fix test --- tests/cpp/data/test_ellpack_page.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 26c7dfc9f925..f7e4f0997e0b 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -23,7 +23,7 @@ TEST(EllpackPage, EmptyDMatrix) { auto impl = page.Impl(); ASSERT_EQ(impl->row_stride, 0); ASSERT_EQ(impl->Cuts().TotalBins(), 0); - ASSERT_EQ(impl->gidx_buffer.Size(), 4); + ASSERT_EQ(impl->gidx_buffer.Size(), 0); } TEST(EllpackPage, BuildGidxDense) { From 08cc82d3bb5cc0eadf6228ec18cd760c51552a7d Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Wed, 3 Aug 2022 04:01:08 -0700 Subject: [PATCH 10/13] More test fixes. --- src/data/ellpack_page_raw_format.cu | 1 - tests/python-gpu/test_large_input.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 2f54b91c9bbc..9cc4ecb76c5b 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -41,7 +41,6 @@ class EllpackPageRawFormat : public SparsePageFormat { bytes += sizeof(impl->is_dense); fo->Write(impl->row_stride); bytes += sizeof(impl->row_stride); - CHECK(!impl->gidx_buffer.ConstHostVector().empty()); fo->Write(impl->gidx_buffer.HostVector()); bytes += impl->gidx_buffer.ConstHostSpan().size_bytes() + sizeof(uint64_t); fo->Write(impl->base_rowid); diff --git a/tests/python-gpu/test_large_input.py b/tests/python-gpu/test_large_input.py index 4c8e06a6f6a5..d21e364f4e65 100644 --- a/tests/python-gpu/test_large_input.py +++ b/tests/python-gpu/test_large_input.py @@ -9,7 +9,7 @@ def test_large_input(): available_bytes, _ = cp.cuda.runtime.memGetInfo() # 15 GB - required_bytes = 1.5e+10 + required_bytes = 1.6e+10 if available_bytes < required_bytes: pytest.skip("Not enough memory on this device") n = 1000 From f3568684a8f33da1c30251e2cd2f8b08889c099a Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 11 Aug 2022 03:45:26 -0700 Subject: [PATCH 11/13] Revert "Use aligned loads in compressed iterator." This reverts commit 4058e788649dd054065f492be3f5810b1dbfac32. 
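For context on what this revert restores: the reverted commit stored each gidx symbol in the smallest whole machine word (8, 16 or 32 bits) so that reads become aligned loads, while the original CompressedBufferWriter packs symbols at ceil(log2(num_symbols)) bits each, adds four bytes of padding, and rounds the buffer up to whole unsigned ints. The host-side sketch below is editorial and only mirrors the two CalculateBufferSize formulas visible in the diff that follows (helper names are illustrative, and sizeof(unsigned int) == 4 is assumed); for one million entries with 257 symbols the packed layout needs roughly 1.1 MB where the word-aligned layout needs 2 MB.

#include <cmath>
#include <cstddef>
#include <cstdio>

// Bits needed to represent num_symbols distinct values (mirrors detail::SymbolBits).
std::size_t SymbolBits(std::size_t num_symbols) {
  return static_cast<std::size_t>(std::ceil(std::log2(static_cast<double>(num_symbols))));
}

// Bit-packed layout restored by this revert: ceil(bits * n / 8) bytes, plus
// 4 bytes of padding, rounded up to a multiple of sizeof(unsigned int).
std::size_t PackedBytes(std::size_t n, std::size_t num_symbols) {
  std::size_t bytes = (SymbolBits(num_symbols) * n + 7) / 8 + 4;
  return (bytes + sizeof(unsigned int) - 1) / sizeof(unsigned int) * sizeof(unsigned int);
}

// Word-aligned layout from the reverted commit: 8, 16 or 32 bits per symbol.
std::size_t WordAlignedBytes(std::size_t n, std::size_t num_symbols) {
  std::size_t bits = SymbolBits(num_symbols);
  std::size_t word_bits = bits <= 8 ? 8 : bits <= 16 ? 16 : 32;
  return (word_bits / 8) * n;
}

int main() {
  std::size_t n = 1000000;  // one million gidx entries
  std::size_t cases[] = {256, 257, 100000};
  for (std::size_t symbols : cases) {
    std::printf("%6zu symbols: packed %8zu B, word-aligned %8zu B\n",
                symbols, PackedBytes(n, symbols), WordAlignedBytes(n, symbols));
  }
  return 0;
}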
--- src/common/compressed_iterator.h | 203 ++++++++++++++---- src/data/ellpack_page.cu | 106 +++++---- src/data/ellpack_page.cuh | 9 +- tests/cpp/common/test_compressed_iterator.cc | 26 ++- .../common/test_gpu_compressed_iterator.cu | 47 ++-- tests/cpp/data/test_ellpack_page.cu | 6 +- tests/cpp/data/test_iterative_dmatrix.cu | 8 +- .../gpu_hist/test_gradient_based_sampler.cu | 4 +- 8 files changed, 291 insertions(+), 118 deletions(-) diff --git a/src/common/compressed_iterator.h b/src/common/compressed_iterator.h index b48941687b7b..9f60722fb982 100644 --- a/src/common/compressed_iterator.h +++ b/src/common/compressed_iterator.h @@ -20,6 +20,19 @@ namespace common { using CompressedByteT = unsigned char; namespace detail { +inline void SetBit(CompressedByteT *byte, int bit_idx) { + *byte |= 1 << bit_idx; +} +template +inline T CheckBit(const T &byte, int bit_idx) { + return byte & (1 << bit_idx); +} +inline void ClearBit(CompressedByteT *byte, int bit_idx) { + *byte &= ~(1 << bit_idx); +} +static const int kPadding = 4; // Assign padding so we can read slightly off + // the beginning of the array + // The number of bits required to represent a given unsigned range inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) { auto bits = std::ceil(log2(static_cast(num_symbols))); @@ -27,69 +40,183 @@ inline XGBOOST_DEVICE size_t SymbolBits(size_t num_symbols) { } } // namespace detail -inline XGBOOST_DEVICE int SmallestWordSize(size_t num_symbols) { - int word_size = 32; - int bits = detail::SymbolBits(num_symbols); - if (bits <= 16) { - word_size = 16; - } - if (bits <= 8) { - word_size = 8; - } - return word_size; -} +/** + * \class CompressedBufferWriter + * + * \brief Writes bit compressed symbols to a memory buffer. Use + * CompressedIterator to read symbols back from buffer. Currently limited to a + * maximum symbol size of 28 bits. + * + * \author Rory + * \date 7/9/2017 + */ -class CompressedWriter { - CompressedByteT *buffer_ {nullptr}; - int symbol_bits_ {0}; +class CompressedBufferWriter { + size_t symbol_bits_; public: - CompressedWriter(CompressedByteT *buffer, size_t num_symbols) : buffer_(buffer) { - symbol_bits_ = SmallestWordSize(num_symbols); - } + XGBOOST_DEVICE explicit CompressedBufferWriter(size_t num_symbols) + : symbol_bits_(detail::SymbolBits(num_symbols)) {} + /** + * \fn static size_t CompressedBufferWriter::CalculateBufferSize(int + * num_elements, int num_symbols) + * + * \brief Calculates number of bytes required for a given number of elements + * and a symbol range. + * + * \author Rory + * \date 7/9/2017 + * + * \param num_elements Number of elements. + * \param num_symbols Max number of symbols (alphabet size) + * + * \return The calculated buffer size. + */ static size_t CalculateBufferSize(size_t num_elements, size_t num_symbols) { - return (SmallestWordSize(num_symbols)/8)*num_elements; + constexpr int kBitsPerByte = 8; + size_t compressed_size = static_cast(std::ceil( + static_cast(detail::SymbolBits(num_symbols) * num_elements) / + kBitsPerByte)); + // Handle atomicOr where input must be unsigned int, hence 4 bytes aligned. 
+ size_t ret = + std::ceil(static_cast(compressed_size + detail::kPadding) / + static_cast(sizeof(unsigned int))) * + sizeof(unsigned int); + return ret; + } + + template + void WriteSymbol(CompressedByteT *buffer, T symbol, size_t offset) { + const int bits_per_byte = 8; + + for (size_t i = 0; i < symbol_bits_; i++) { + size_t byte_idx = ((offset + 1) * symbol_bits_ - (i + 1)) / bits_per_byte; + byte_idx += detail::kPadding; + size_t bit_idx = + ((bits_per_byte + i) - ((offset + 1) * symbol_bits_)) % bits_per_byte; + + if (detail::CheckBit(symbol, i)) { + detail::SetBit(&buffer[byte_idx], bit_idx); + } else { + detail::ClearBit(&buffer[byte_idx], bit_idx); + } + } } - XGBOOST_DEVICE void Write(size_t idx, uint32_t x) { - if (symbol_bits_ == 8) { - buffer_[idx] = x; - } else if (symbol_bits_ == 16) { - reinterpret_cast(buffer_)[idx] = x; - } else if (symbol_bits_ == 32) { - reinterpret_cast(buffer_)[idx] = x; +#ifdef __CUDACC__ + __device__ void AtomicWriteSymbol + (CompressedByteT* buffer, uint64_t symbol, size_t offset) { + size_t ibit_start = offset * symbol_bits_; + size_t ibit_end = (offset + 1) * symbol_bits_ - 1; + size_t ibyte_start = ibit_start / 8, ibyte_end = ibit_end / 8; + + symbol <<= 7 - ibit_end % 8; + for (ptrdiff_t ibyte = ibyte_end; ibyte >= static_cast(ibyte_start); --ibyte) { + dh::AtomicOrByte(reinterpret_cast(buffer + detail::kPadding), + ibyte, symbol & 0xff); + symbol >>= 8; + } + } +#endif // __CUDACC__ + + template + void Write(CompressedByteT *buffer, IterT input_begin, IterT input_end) { + uint64_t tmp = 0; + size_t stored_bits = 0; + const size_t max_stored_bits = 64 - symbol_bits_; + size_t buffer_position = detail::kPadding; + const size_t num_symbols = input_end - input_begin; + for (size_t i = 0; i < num_symbols; i++) { + typename std::iterator_traits::value_type symbol = input_begin[i]; + if (stored_bits > max_stored_bits) { + // Eject only full bytes + size_t tmp_bytes = stored_bits / 8; + for (size_t j = 0; j < tmp_bytes; j++) { + buffer[buffer_position] = static_cast( + tmp >> (stored_bits - (j + 1) * 8)); + buffer_position++; + } + stored_bits -= tmp_bytes * 8; + tmp &= (1 << stored_bits) - 1; + } + // Store symbol + tmp <<= symbol_bits_; + tmp |= symbol; + stored_bits += symbol_bits_; + } + + // Eject all bytes + int tmp_bytes = + static_cast(std::ceil(static_cast(stored_bits) / 8)); + for (int j = 0; j < tmp_bytes; j++) { + int shift_bits = static_cast(stored_bits) - (j + 1) * 8; + if (shift_bits >= 0) { + buffer[buffer_position] = + static_cast(tmp >> shift_bits); + } else { + buffer[buffer_position] = + static_cast(tmp << std::abs(shift_bits)); + } + buffer_position++; } } }; +/** + * \brief Read symbols from a bit compressed memory buffer. Usable on device and host. + * + * \author Rory + * \date 7/9/2017 + * + * \tparam T Generic type parameter. 
+ */ +template class CompressedIterator { public: // Type definitions for thrust - typedef CompressedIterator self_type; // NOLINT + typedef CompressedIterator self_type; // NOLINT typedef ptrdiff_t difference_type; // NOLINT - typedef uint32_t value_type; // NOLINT + typedef T value_type; // NOLINT typedef value_type *pointer; // NOLINT - typedef value_type& reference; // NOLINT + typedef value_type reference; // NOLINT private: const CompressedByteT *buffer_ {nullptr}; - int symbol_bits_ {0}; + size_t symbol_bits_ {0}; + size_t offset_ {0}; public: CompressedIterator() = default; CompressedIterator(const CompressedByteT *buffer, size_t num_symbols) : buffer_(buffer) { - symbol_bits_ = SmallestWordSize(num_symbols); + symbol_bits_ = detail::SymbolBits(num_symbols); } - XGBOOST_DEVICE value_type operator[](size_t idx) const { - if (symbol_bits_ == 8) { - return buffer_[idx]; - } else if (symbol_bits_ == 16) { - return reinterpret_cast(buffer_)[idx]; - } else { - return reinterpret_cast(buffer_)[idx]; - } + + XGBOOST_DEVICE reference operator*() const { + const int bits_per_byte = 8; + size_t start_bit_idx = ((offset_ + 1) * symbol_bits_ - 1); + size_t start_byte_idx = start_bit_idx / bits_per_byte; + start_byte_idx += detail::kPadding; + + // Read 5 bytes - the maximum we will need + uint64_t tmp = static_cast(buffer_[start_byte_idx - 4]) << 32 | + static_cast(buffer_[start_byte_idx - 3]) << 24 | + static_cast(buffer_[start_byte_idx - 2]) << 16 | + static_cast(buffer_[start_byte_idx - 1]) << 8 | + buffer_[start_byte_idx]; + int bit_shift = + (bits_per_byte - ((offset_ + 1) * symbol_bits_)) % bits_per_byte; + tmp >>= bit_shift; + // Mask off unneeded bits + uint64_t mask = (static_cast(1) << symbol_bits_) - 1; + return static_cast(tmp & mask); + } + + XGBOOST_DEVICE reference operator[](size_t idx) const { + self_type offset = (*this); + offset.offset_ += idx; + return *offset; } }; } // namespace common diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index 355dad5a9455..14a1b2bbf172 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -27,7 +27,8 @@ void EllpackPage::SetBaseRowId(size_t row_id) { impl_->SetBaseRowId(row_id); } // Bin each input data entry, store the bin indices in compressed form. __global__ void CompressBinEllpackKernel( - common::CompressedWriter writer, + common::CompressedBufferWriter wr, + common::CompressedByteT* __restrict__ buffer, // gidx_buffer const size_t* __restrict__ row_ptrs, // row offset of input data const Entry* __restrict__ entries, // One batch of input data const float* __restrict__ cuts, // HistogramCuts::cut_values_ @@ -71,7 +72,7 @@ __global__ void CompressBinEllpackKernel( bin += cut_rows[feature]; } // Write to gidx buffer. - writer.Write((irow + base_row) * row_stride + ifeature, bin); + wr.AtomicWriteSymbol(buffer, bin, (irow + base_row) * row_stride + ifeature); } // Construct an ELLPACK matrix with the given number of empty rows. 
@@ -130,17 +131,21 @@ EllpackPageImpl::EllpackPageImpl(DMatrix* dmat, const BatchParam& param) template struct WriteCompressedEllpackFunctor { - WriteCompressedEllpackFunctor(common::CompressedWriter writer, + WriteCompressedEllpackFunctor(common::CompressedByteT* buffer, + const common::CompressedBufferWriter& writer, AdapterBatchT batch, EllpackDeviceAccessor accessor, common::Span feature_types, const data::IsValidFunctor& is_valid) - : writer(writer), + : d_buffer(buffer), + writer(writer), batch(std::move(batch)), accessor(std::move(accessor)), feature_types(std::move(feature_types)), is_valid(is_valid) {} - common::CompressedWriter writer; + + common::CompressedByteT* d_buffer; + common::CompressedBufferWriter writer; AdapterBatchT batch; EllpackDeviceAccessor accessor; common::Span feature_types; @@ -159,7 +164,7 @@ struct WriteCompressedEllpackFunctor { } else { bin_idx = accessor.SearchBin(e.value, e.column_idx); } - writer.Write(output_position, bin_idx); + writer.AtomicWriteSymbol(d_buffer, bin_idx, output_position); } return 0; } @@ -213,11 +218,12 @@ void CopyDataToEllpack(const AdapterBatchT &batch, using Tuple = thrust::tuple; auto device_accessor = dst->GetDeviceAccessor(device_idx); - common::CompressedWriter writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()); + common::CompressedBufferWriter writer(device_accessor.NumSymbols()); + auto d_compressed_buffer = dst->gidx_buffer.DevicePointer(); // We redirect the scan output into this functor to do the actual writing - WriteCompressedEllpackFunctor functor(writer, - batch, device_accessor, feature_types, + WriteCompressedEllpackFunctor functor( + d_compressed_buffer, writer, batch, device_accessor, feature_types, is_valid); dh::TypedDiscard discard; thrust::transform_output_iterator< @@ -241,15 +247,18 @@ void CopyDataToEllpack(const AdapterBatchT &batch, void WriteNullValues(EllpackPageImpl* dst, int device_idx, common::Span row_counts) { // Write the null values - common::CompressedWriter writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()); auto device_accessor = dst->GetDeviceAccessor(device_idx); + common::CompressedBufferWriter writer(device_accessor.NumSymbols()); + auto d_compressed_buffer = dst->gidx_buffer.DevicePointer(); auto row_stride = dst->row_stride; - dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) mutable { + dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) { // For some reason this variable got captured as const + auto writer_non_const = writer; size_t row_idx = idx / row_stride; size_t row_offset = idx % row_stride; if (row_offset >= row_counts[row_idx]) { - writer.Write(idx, device_accessor.NullValue()); + writer_non_const.AtomicWriteSymbol(d_compressed_buffer, + device_accessor.NullValue(), idx); } }); } @@ -275,6 +284,24 @@ EllpackPageImpl::EllpackPageImpl(AdapterBatch batch, float missing, int device, ELLPACK_BATCH_SPECIALIZE(data::CudfAdapterBatch) ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch) +// A functor that copies the data from one EllpackPage to another. +struct CopyPage { + common::CompressedBufferWriter cbw; + common::CompressedByteT* dst_data_d; + common::CompressedIterator src_iterator_d; + // The number of elements to skip. 
+ size_t offset; + + CopyPage(EllpackPageImpl *dst, EllpackPageImpl const *src, size_t offset) + : cbw{dst->NumSymbols()}, dst_data_d{dst->gidx_buffer.DevicePointer()}, + src_iterator_d{src->gidx_buffer.DevicePointer(), src->NumSymbols()}, + offset(offset) {} + + __device__ void operator()(size_t element_id) { + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[element_id], + element_id + offset); + } +}; // Copy the data from the given EllpackPage to the current page. size_t EllpackPageImpl::Copy(int device, EllpackPageImpl const *page, @@ -288,18 +315,18 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl const *page, LOG(FATAL) << "Concatenating the same Ellpack."; return this->n_rows * this->row_stride; } - auto src = page->GetDeviceAccessor(device).gidx_iter; - common::CompressedWriter writer(this->gidx_buffer.DevicePointer(), this->NumSymbols()); - dh::LaunchN(num_elements, - [=] __device__(std::size_t idx) mutable { writer.Write(offset + idx, src[idx]); }); + gidx_buffer.SetDevice(device); + page->gidx_buffer.SetDevice(device); + dh::LaunchN(num_elements, CopyPage(this, page, offset)); monitor_.Stop("Copy"); return num_elements; } // A functor that compacts the rows from one EllpackPage into another. struct CompactPage { - common::CompressedWriter writer; - common::CompressedIterator src_iterator_d; + common::CompressedBufferWriter cbw; + common::CompressedByteT* dst_data_d; + common::CompressedIterator src_iterator_d; /*! \brief An array that maps the rows from the full DMatrix to the compacted * page. * @@ -315,10 +342,11 @@ struct CompactPage { size_t base_rowid; size_t row_stride; - CompactPage(EllpackPageImpl* dst, EllpackPageImpl const* src, common::Span row_indexes, - int device) - : writer(dst->gidx_buffer.DevicePointer(), dst->NumSymbols()), - src_iterator_d{src->GetDeviceAccessor(device).gidx_iter}, + CompactPage(EllpackPageImpl* dst, EllpackPageImpl const* src, + common::Span row_indexes) + : cbw{dst->NumSymbols()}, + dst_data_d{dst->gidx_buffer.DevicePointer()}, + src_iterator_d{src->gidx_buffer.DevicePointer(), src->NumSymbols()}, row_indexes(row_indexes), base_rowid{src->base_rowid}, row_stride{src->row_stride} {} @@ -330,7 +358,8 @@ struct CompactPage { size_t dst_offset = dst_row * row_stride; size_t src_offset = row_id * row_stride; for (size_t j = 0; j < row_stride; j++) { - writer.Write(dst_offset + j, src_iterator_d[src_offset + j]); + cbw.AtomicWriteSymbol(dst_data_d, src_iterator_d[src_offset + j], + dst_offset + j); } } }; @@ -344,7 +373,7 @@ void EllpackPageImpl::Compact(int device, EllpackPageImpl const* page, CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size()); gidx_buffer.SetDevice(device); page->gidx_buffer.SetDevice(device); - dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes, device)); + dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes)); monitor_.Stop("Compact"); } @@ -354,7 +383,7 @@ void EllpackPageImpl::InitCompressedData(int device) { // Required buffer size for storing data matrix in ELLPack format. 
size_t compressed_size_bytes = - common::CompressedWriter::CalculateBufferSize(row_stride * n_rows, + common::CompressedBufferWriter::CalculateBufferSize(row_stride * n_rows, num_symbols); gidx_buffer.SetDevice(device); // Don't call fill unnecessarily @@ -417,11 +446,12 @@ void EllpackPageImpl::CreateHistIndices(int device, common::DivRoundUp(row_stride, block3.y), 1); auto device_accessor = GetDeviceAccessor(device); dh::LaunchKernel {grid3, block3}( - CompressBinEllpackKernel, - common::CompressedWriter(gidx_buffer.DevicePointer(), NumSymbols()), row_ptrs.data().get(), + CompressBinEllpackKernel, common::CompressedBufferWriter(NumSymbols()), + gidx_buffer.DevicePointer(), row_ptrs.data().get(), entries_d.data().get(), device_accessor.gidx_fvalue_map.data(), - device_accessor.feature_segments.data(), feature_types, batch_row_begin, batch_nrows, - row_stride, null_gidx_value); + device_accessor.feature_segments.data(), feature_types, + batch_row_begin, batch_nrows, row_stride, + null_gidx_value); } } @@ -433,26 +463,12 @@ size_t EllpackPageImpl::MemCostBytes(size_t num_rows, size_t row_stride, const common::HistogramCuts& cuts) { // Required buffer size for storing data matrix in EtoLLPack format. size_t compressed_size_bytes = - common::CompressedWriter::CalculateBufferSize(row_stride * num_rows, + common::CompressedBufferWriter::CalculateBufferSize(row_stride * num_rows, cuts.TotalBins() + 1); return compressed_size_bytes; } EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( - int device, common::Span feature_types) { - gidx_buffer.SetDevice(device); - return {device, - cuts_, - is_dense, - row_stride, - base_rowid, - n_rows, - common::CompressedIterator(gidx_buffer.DevicePointer(), - NumSymbols()), - feature_types}; - } - -const EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( int device, common::Span feature_types) const { gidx_buffer.SetDevice(device); return {device, @@ -461,7 +477,7 @@ const EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor( row_stride, base_rowid, n_rows, - common::CompressedIterator(gidx_buffer.ConstDevicePointer(), + common::CompressedIterator(gidx_buffer.ConstDevicePointer(), NumSymbols()), feature_types}; } diff --git a/src/data/ellpack_page.cuh b/src/data/ellpack_page.cuh index 6c1dfa3c0517..7a2020c8b0b4 100644 --- a/src/data/ellpack_page.cuh +++ b/src/data/ellpack_page.cuh @@ -24,7 +24,7 @@ struct EllpackDeviceAccessor { size_t row_stride; size_t base_rowid{}; size_t n_rows{}; - common::CompressedIterator gidx_iter; + common::CompressedIterator gidx_iter; /*! \brief Minimum value for each feature. Size equals to number of features. */ common::Span min_fvalue; /*! \brief Histogram cut pointers. Size equals to (number of features + 1). */ @@ -36,7 +36,7 @@ struct EllpackDeviceAccessor { EllpackDeviceAccessor(int device, const common::HistogramCuts& cuts, bool is_dense, size_t row_stride, size_t base_rowid, - size_t n_rows,common::CompressedIterator gidx_iter, + size_t n_rows,common::CompressedIterator gidx_iter, common::Span feature_types) : is_dense(is_dense), row_stride(row_stride), @@ -194,10 +194,7 @@ class EllpackPageImpl { EllpackDeviceAccessor GetDeviceAccessor(int device, - common::Span feature_types = {}) ; - const EllpackDeviceAccessor - GetDeviceAccessor(int device, - common::Span feature_types = {}) const ; + common::Span feature_types = {}) const; private: /*! 
diff --git a/tests/cpp/common/test_compressed_iterator.cc b/tests/cpp/common/test_compressed_iterator.cc index 8c7eb3d74130..93243c0b336e 100644 --- a/tests/cpp/common/test_compressed_iterator.cc +++ b/tests/cpp/common/test_compressed_iterator.cc @@ -4,7 +4,6 @@ namespace xgboost { namespace common { - TEST(CompressedIterator, Test) { ASSERT_TRUE(detail::SymbolBits(256) == 8); ASSERT_TRUE(detail::SymbolBits(150) == 8); @@ -18,25 +17,36 @@ TEST(CompressedIterator, Test) { std::vector input(num_elements); std::generate(input.begin(), input.end(), [=]() { return rand() % alphabet_size; }); + CompressedBufferWriter cbw(alphabet_size); // Test write entire array std::vector buffer( - CompressedWriter::CalculateBufferSize(input.size(), + CompressedBufferWriter::CalculateBufferSize(input.size(), alphabet_size)); - CompressedWriter writer(buffer.data(), alphabet_size); - CompressedIterator iter(buffer.data(), alphabet_size); - for (size_t i = 0; i < input.size(); i++) { - writer.Write(i,input[i]); - } + cbw.Write(buffer.data(), input.begin(), input.end()); + CompressedIterator ci(buffer.data(), alphabet_size); std::vector output(input.size()); for (size_t i = 0; i < input.size(); i++) { - output[i] = iter[i]; + output[i] = ci[i]; } ASSERT_TRUE(input == output); + // Test write Symbol + std::vector buffer2( + CompressedBufferWriter::CalculateBufferSize(input.size(), + alphabet_size)); + for (size_t i = 0; i < input.size(); i++) { + cbw.WriteSymbol(buffer2.data(), input[i], i); + } + CompressedIterator ci2(buffer.data(), alphabet_size); + std::vector output2(input.size()); + for (size_t i = 0; i < input.size(); i++) { + output2[i] = ci2[i]; + } + ASSERT_TRUE(input == output2); } } } diff --git a/tests/cpp/common/test_gpu_compressed_iterator.cu b/tests/cpp/common/test_gpu_compressed_iterator.cu index 4716b7788e90..779202a62002 100644 --- a/tests/cpp/common/test_gpu_compressed_iterator.cu +++ b/tests/cpp/common/test_gpu_compressed_iterator.cu @@ -7,7 +7,31 @@ namespace xgboost { namespace common { -void TestCompressedIterator(){ +struct WriteSymbolFunction { + CompressedBufferWriter cbw; + unsigned char* buffer_data_d; + int* input_data_d; + WriteSymbolFunction(CompressedBufferWriter cbw, unsigned char* buffer_data_d, + int* input_data_d) + : cbw(cbw), buffer_data_d(buffer_data_d), input_data_d(input_data_d) {} + + __device__ void operator()(size_t i) { + cbw.AtomicWriteSymbol(buffer_data_d, input_data_d[i], i); + } +}; + +struct ReadSymbolFunction { + CompressedIterator ci; + int* output_data_d; + ReadSymbolFunction(CompressedIterator ci, int* output_data_d) + : ci(ci), output_data_d(output_data_d) {} + + __device__ void operator()(size_t i) { + output_data_d[i] = ci[i]; + } +}; + +TEST(CompressedIterator, TestGPU) { dh::safe_cuda(cudaSetDevice(0)); std::vector test_cases = {1, 3, 426, 21, 64, 256, 100000, INT32_MAX}; int num_elements = 1000; @@ -17,23 +41,26 @@ void TestCompressedIterator(){ for (auto alphabet_size : test_cases) { for (int i = 0; i < repetitions; i++) { std::vector input(num_elements); - std::generate(input.begin(), input.end(), [=]() { return rand() % alphabet_size; }); + std::generate(input.begin(), input.end(), + [=]() { return rand() % alphabet_size; }); + CompressedBufferWriter cbw(alphabet_size); thrust::device_vector input_d(input); thrust::device_vector buffer_d( - CompressedWriter::CalculateBufferSize(input.size(), alphabet_size)); - CompressedWriter writer(buffer_d.data().get(), alphabet_size); + CompressedBufferWriter::CalculateBufferSize(input.size(), + alphabet_size)); 
// write the data on device auto input_data_d = input_d.data().get(); + auto buffer_data_d = buffer_d.data().get(); dh::LaunchN(input_d.size(), - [=] __device__(std::size_t idx) mutable { writer.Write(idx, input_data_d[idx]); }); + WriteSymbolFunction(cbw, buffer_data_d, input_data_d)); - CompressedIterator iter(buffer_d.data().get(), alphabet_size); + // read the data on device + CompressedIterator ci(buffer_d.data().get(), alphabet_size); thrust::device_vector output_d(input.size()); auto output_data_d = output_d.data().get(); - dh::LaunchN(output_d.size(), - [=] __device__(std::size_t idx) { output_data_d[idx] = iter[idx]; }); + dh::LaunchN(output_d.size(), ReadSymbolFunction(ci, output_data_d)); std::vector output(output_d.size()); thrust::copy(output_d.begin(), output_d.end(), output.begin()); @@ -43,9 +70,5 @@ void TestCompressedIterator(){ } } -TEST(CompressedIterator, TestGPU) { -TestCompressedIterator(); -} - } // namespace common } // namespace xgboost diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index f7e4f0997e0b..deccaed50020 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -31,7 +31,7 @@ TEST(EllpackPage, BuildGidxDense) { auto page = BuildEllpackPage(kNRows, kNCols); std::vector h_gidx_buffer(page->gidx_buffer.HostVector()); - common::CompressedIterator gidx(h_gidx_buffer.data(), page->NumSymbols()); + common::CompressedIterator gidx(h_gidx_buffer.data(), page->NumSymbols()); ASSERT_EQ(page->row_stride, kNCols); @@ -63,7 +63,7 @@ TEST(EllpackPage, BuildGidxSparse) { auto page = BuildEllpackPage(kNRows, kNCols, 0.9f); std::vector h_gidx_buffer(page->gidx_buffer.HostVector()); - common::CompressedIterator gidx(h_gidx_buffer.data(), 25); + common::CompressedIterator gidx(h_gidx_buffer.data(), 25); ASSERT_LE(page->row_stride, 3); @@ -107,7 +107,7 @@ TEST(EllpackPage, FromCategoricalBasic) { std::vector const &h_gidx_buffer = ellpack.Impl()->gidx_buffer.HostVector(); - auto h_gidx_iter = common::CompressedIterator( + auto h_gidx_iter = common::CompressedIterator( h_gidx_buffer.data(), accessor.NumSymbols()); for (size_t i = 0; i < x.size(); ++i) { diff --git a/tests/cpp/data/test_iterative_dmatrix.cu b/tests/cpp/data/test_iterative_dmatrix.cu index 8124cb456984..0a83f7e8c54b 100644 --- a/tests/cpp/data/test_iterative_dmatrix.cu +++ b/tests/cpp/data/test_iterative_dmatrix.cu @@ -69,9 +69,9 @@ void TestEquivalent(float sparsity) { auto const& buffer_from_data = ellpack.Impl()->gidx_buffer; ASSERT_NE(buffer_from_data.Size(), 0); - common::CompressedIterator data_buf{ + common::CompressedIterator data_buf{ buffer_from_data.ConstHostPointer(), from_data.NumSymbols()}; - common::CompressedIterator data_iter{ + common::CompressedIterator data_iter{ buffer_from_iter.ConstHostPointer(), from_iter.NumSymbols()}; CHECK_EQ(from_data.NumSymbols(), from_iter.NumSymbols()); CHECK_EQ(from_data.n_rows * from_data.row_stride, from_data.n_rows * from_iter.row_stride); @@ -96,7 +96,7 @@ TEST(IterativeDeviceDMatrix, RowMajor) { for (auto& ellpack : m.GetBatches({})) { n_batches ++; auto impl = ellpack.Impl(); - common::CompressedIterator iterator( + common::CompressedIterator iterator( impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); auto cols = CudaArrayIterForTest::Cols(); auto rows = CudaArrayIterForTest::Rows(); @@ -144,7 +144,7 @@ TEST(IterativeDeviceDMatrix, RowMajorMissing) { 0, 256); auto &ellpack = *m.GetBatches({0, 256}).begin(); auto impl = ellpack.Impl(); - common::CompressedIterator 
iterator( + common::CompressedIterator iterator( impl->gidx_buffer.HostVector().data(), impl->NumSymbols()); EXPECT_EQ(iterator[1], impl->GetDeviceAccessor(0).NullValue()); EXPECT_EQ(iterator[5], impl->GetDeviceAccessor(0).NullValue()); diff --git a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu index a21ee93497ab..9e8cd19bec74 100644 --- a/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu +++ b/tests/cpp/tree/gpu_hist/test_gradient_based_sampler.cu @@ -99,14 +99,14 @@ TEST(GradientBasedSampler, NoSamplingExternalMemory) { EXPECT_EQ(sampled_page->n_rows, kRows); std::vector buffer(sampled_page->gidx_buffer.HostVector()); - common::CompressedIterator + common::CompressedIterator ci(buffer.data(), sampled_page->NumSymbols()); size_t offset = 0; for (auto& batch : dmat->GetBatches(param)) { auto page = batch.Impl(); std::vector page_buffer(page->gidx_buffer.HostVector()); - common::CompressedIterator + common::CompressedIterator page_ci(page_buffer.data(), page->NumSymbols()); size_t num_elements = page->n_rows * page->row_stride; for (size_t i = 0; i < num_elements; i++) { From 1e8a84206c26777961d36ffdca039ed0eee46529 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 11 Aug 2022 04:56:47 -0700 Subject: [PATCH 12/13] Revert "Fix test" This reverts commit e6220268858e3fbeba91b6b8046f2de90652de52. --- tests/cpp/data/test_ellpack_page.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index deccaed50020..a67ab1d59f02 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -23,7 +23,7 @@ TEST(EllpackPage, EmptyDMatrix) { auto impl = page.Impl(); ASSERT_EQ(impl->row_stride, 0); ASSERT_EQ(impl->Cuts().TotalBins(), 0); - ASSERT_EQ(impl->gidx_buffer.Size(), 0); + ASSERT_EQ(impl->gidx_buffer.Size(), 4); } TEST(EllpackPage, BuildGidxDense) { From 39ac148c4c772400344afce4f6374e76029d41d2 Mon Sep 17 00:00:00 2001 From: Rory Mitchell Date: Thu, 11 Aug 2022 04:57:19 -0700 Subject: [PATCH 13/13] Revert "More test fixes." This reverts commit 08cc82d3bb5cc0eadf6228ec18cd760c51552a7d. --- src/data/ellpack_page_raw_format.cu | 1 + tests/python-gpu/test_large_input.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/data/ellpack_page_raw_format.cu b/src/data/ellpack_page_raw_format.cu index 9cc4ecb76c5b..2f54b91c9bbc 100644 --- a/src/data/ellpack_page_raw_format.cu +++ b/src/data/ellpack_page_raw_format.cu @@ -41,6 +41,7 @@ class EllpackPageRawFormat : public SparsePageFormat { bytes += sizeof(impl->is_dense); fo->Write(impl->row_stride); bytes += sizeof(impl->row_stride); + CHECK(!impl->gidx_buffer.ConstHostVector().empty()); fo->Write(impl->gidx_buffer.HostVector()); bytes += impl->gidx_buffer.ConstHostSpan().size_bytes() + sizeof(uint64_t); fo->Write(impl->base_rowid); diff --git a/tests/python-gpu/test_large_input.py b/tests/python-gpu/test_large_input.py index d21e364f4e65..4c8e06a6f6a5 100644 --- a/tests/python-gpu/test_large_input.py +++ b/tests/python-gpu/test_large_input.py @@ -9,7 +9,7 @@ def test_large_input(): available_bytes, _ = cp.cuda.runtime.memGetInfo() # 15 GB - required_bytes = 1.6e+10 + required_bytes = 1.5e+10 if available_bytes < required_bytes: pytest.skip("Not enough memory on this device") n = 1000