From 776ef9fb807dfc568fef0bdec561d7753886464b Mon Sep 17 00:00:00 2001
From: Rory Mitchell
Date: Tue, 28 Jun 2022 05:25:56 -0700
Subject: [PATCH] Remove constant memory in favour of __ldg().

---
 src/common/device_helpers.cuh                 | 20 +++++++
 src/tree/gpu_hist/row_partitioner.cuh         | 57 ++++++-------------
 .../cpp/tree/gpu_hist/test_row_partitioner.cu | 11 ++--
 3 files changed, 42 insertions(+), 46 deletions(-)

diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 123dc14e57be..33989a230464 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1939,4 +1939,24 @@ class CUDAStream {
   CUDAStreamView View() const { return CUDAStreamView{stream_}; }
   void Sync() { this->View().Sync(); }
 };
+
+// Force nvcc to load data as constant
+template <typename T>
+class LDGIterator {
+  typedef typename cub::UnitWord<T>::DeviceWord DeviceWordT;
+  static constexpr std::size_t kNumWords = sizeof(T) / sizeof(DeviceWordT);
+
+  const T* ptr;
+
+ public:
+  LDGIterator(const T* ptr) : ptr(ptr) {}
+  __device__ T operator[](std::size_t idx) const {
+    DeviceWordT tmp[kNumWords];
+#pragma unroll
+    for (int i = 0; i < kNumWords; i++) {
+      tmp[i] = __ldg(reinterpret_cast<const DeviceWordT*>(ptr + idx) + i);
+    }
+    return *reinterpret_cast<const T*>(tmp);
+  }
+};
 }  // namespace dh
diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh
index 3a42f9245a63..e9fb7e86add7 100644
--- a/src/tree/gpu_hist/row_partitioner.cuh
+++ b/src/tree/gpu_hist/row_partitioner.cuh
@@ -36,8 +36,6 @@ struct PerNodeData {
   OpDataT data;
 };
 
-__constant__ char constant_memory[kMaxUpdatePositionBatchSize * 256];
-
 template <typename BatchIterT>
 __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
                                             int* batch_idx, std::size_t* item_idx) {
@@ -52,36 +50,14 @@ __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t g
   }
 }
 
-template <typename OpDataT>
-struct SharedStorage {
-  PerNodeData<OpDataT> data[kMaxUpdatePositionBatchSize];
-  // Collectively load from global memory into shared memory
-  template <int kBlockSize>
-  __device__ const PerNodeData<OpDataT>* BlockLoad(const PerNodeData<OpDataT>* d_batch_info) {
-    for (int i = threadIdx.x; i < kMaxUpdatePositionBatchSize; i += kBlockSize) {
-      data[i] = d_batch_info[i];
-    }
-    __syncthreads();
-    return data;
-  }
-};
-
 template <int kBlockSize, typename OpDataT>
 __global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
-    common::Span<bst_uint> d_ridx, const common::Span<const bst_uint> ridx_tmp,
-    std::size_t total_rows) {
-  // Load this into shared memory
-  // the compiler puts it into registers otherwise
-  // then we get spilling to local memory
-  const PerNodeData<OpDataT>* batch_info =
-      reinterpret_cast<PerNodeData<OpDataT>*>(constant_memory);
-  __shared__ cub::Uninitialized<SharedStorage<OpDataT>> shared;
-  auto s_batch_info = shared.Alias().BlockLoad<kBlockSize>(batch_info);
-
+    dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<bst_uint> d_ridx,
+    const common::Span<const bst_uint> ridx_tmp, std::size_t total_rows) {
   for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
     int batch_idx;
     std::size_t item_idx;
-    AssignBatch(s_batch_info, idx, &batch_idx, &item_idx);
+    AssignBatch(batch_info, idx, &batch_idx, &item_idx);
     d_ridx[item_idx] = ridx_tmp[item_idx];
   }
 }
@@ -109,14 +85,13 @@ struct IndexFlagOp {
 
 template <typename OpDataT>
 struct WriteResultsFunctor {
+  dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
   const bst_uint* ridx_in;
   bst_uint* ridx_out;
  bst_uint* counts;
 
   __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
     std::size_t scatter_address;
-    const PerNodeData<OpDataT>* batch_info =
-        reinterpret_cast<const PerNodeData<OpDataT>*>(constant_memory);
     const Segment& segment = batch_info[x.batch_idx].segment;
     if (x.flag) {
       bst_uint num_previous_flagged = x.flag_scan - 1;  // -1 because inclusive scan
@@ -138,18 +113,19 @@
 };
 
 template <typename RowIndexT, typename OpT, typename OpDataT>
-void SortPositionBatch(common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
+void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
+                       common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
                        common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
                        dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
-  WriteResultsFunctor<OpDataT> write_results{ridx.data(), ridx_tmp.data(), d_counts.data()};
+  dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
+  WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
+                                             d_counts.data()};
 
   auto discard_write_iterator =
       thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
   auto counting = thrust::make_counting_iterator(0llu);
   auto input_iterator =
       dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(size_t idx) {
-        const PerNodeData<OpDataT>* batch_info_itr =
-            reinterpret_cast<const PerNodeData<OpDataT>*>(constant_memory);
         int batch_idx;
         std::size_t item_idx;
         AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
@@ -173,7 +149,7 @@ void SortPositionBatch(common::Span<RowIndexT> ridx, common::Span<RowIndexT> rid
   const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);
 
   SortPositionCopyKernel<kBlockSize>
-      <<<grid_size, kBlockSize, 0, stream>>>(ridx, ridx_tmp, total_rows);
+      <<<grid_size, kBlockSize, 0, stream>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
 }
 
 struct NodePositionInfo {
@@ -294,25 +270,24 @@ class RowPartitioner {
     CHECK_EQ(nidx.size(), op_data.size());
 
     auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
+    dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());
+
     std::size_t total_rows = 0;
     for (int i = 0; i < nidx.size(); i++) {
       h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
       total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
     }
-    static_assert(sizeof(PerNodeData<OpDataT>) * kMaxUpdatePositionBatchSize <=
-                      sizeof(constant_memory),
-                  "Not enough constant memory allocated.");
-    dh::safe_cuda(cudaMemcpyToSymbolAsync(constant_memory, h_batch_info.data(),
-                                          h_batch_info.size() * sizeof(PerNodeData<OpDataT>), 0,
-                                          cudaMemcpyDefault, stream_));
+    dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
+                                  h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
+                                  cudaMemcpyDefault, stream_));
 
     // Temporary arrays
     auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
 
     // Partition the rows according to the operator
     SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
-        dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts_), total_rows, op, &tmp_,
-        stream_);
+        dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts_),
+        total_rows, op, &tmp_, stream_);
     dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts_.data().get(),
                                   sizeof(decltype(d_counts_)::value_type) * h_counts.size(),
                                   cudaMemcpyDefault, stream_));
diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
index d35178c643c3..520cc3cd0b81 100644
--- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
+++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -67,12 +67,13 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector
     h_batch_info[i] = {segments.at(i), 0};
     total_rows += segments.at(i).Size();
   }
-  dh::safe_cuda(cudaMemcpyToSymbolAsync(constant_memory, h_batch_info.data(),
-                                        h_batch_info.size() * sizeof(PerNodeData<int>), 0,
-                                        cudaMemcpyDefault, nullptr));
+  dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
+                                h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
+                                nullptr));
   dh::device_vector<int8_t> tmp;
-  SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(ridx), dh::ToSpan(ridx_tmp),
-                                                 dh::ToSpan(counts), total_rows, op, &tmp, nullptr);
+  SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
+                                                 dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
+                                                 total_rows, op, &tmp, nullptr);
 
   auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
   for (int i = 0; i < segments.size(); i++) {
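
Note on the technique: as the hunks above show, the fixed-size __constant__ buffer
(staged per launch with cudaMemcpyToSymbolAsync and guarded by a static_assert on its
size) is replaced by ordinary device memory read through __ldg(), which routes loads
via the read-only data cache without the size limit or the symbol copy. A minimal,
self-contained sketch of the same idea follows. It is illustrative only, not part of
the patch: the names ReadOnlyIterator and SumKernel are hypothetical, the element type
is restricted to one __ldg() can load directly (the real dh::LDGIterator decomposes an
arbitrary T into device words via cub::UnitWord<T>), and __ldg() assumes compute
capability 3.5 or newer.

// ldg_sketch.cu - simplified, hypothetical analogue of dh::LDGIterator.
#include <cstddef>
#include <cstdio>
#include <cuda_runtime.h>

template <typename T>
class ReadOnlyIterator {
  const T* ptr_;

 public:
  explicit ReadOnlyIterator(const T* ptr) : ptr_(ptr) {}
  __device__ T operator[](std::size_t idx) const {
    // Load through the read-only data cache. The data must stay immutable for
    // the kernel's lifetime - the same contract __constant__ memory imposed,
    // but without its fixed size or the cudaMemcpyToSymbolAsync staging.
    return __ldg(ptr_ + idx);
  }
};

__global__ void SumKernel(ReadOnlyIterator<int> in, int n, int* out) {
  int acc = 0;
  for (int i = threadIdx.x; i < n; i += blockDim.x) acc += in[i];
  atomicAdd(out, acc);
}

int main() {
  constexpr int kN = 256;
  int h_in[kN];
  for (int i = 0; i < kN; i++) h_in[i] = 1;
  int *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(int));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  cudaMemset(d_out, 0, sizeof(int));
  // The iterator is passed by value to the kernel, exactly like the
  // dh::LDGIterator argument of SortPositionCopyKernel in the patch.
  SumKernel<<<1, 128>>>(ReadOnlyIterator<int>(d_in), kN, d_out);
  int h_out = 0;
  cudaMemcpy(&h_out, d_out, sizeof(int), cudaMemcpyDeviceToHost);
  printf("%d\n", h_out);  // expect 256
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}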