Remove constant memory in favour of __ldg().
RAMitchell committed Jun 28, 2022
1 parent 8caed98 commit 776ef9f
Showing 3 changed files with 42 additions and 46 deletions.
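
In short: per-launch batch metadata that was previously staged into a __constant__ symbol (copied with cudaMemcpyToSymbolAsync and bounded by the 64 KiB constant-memory limit) now lives in ordinary device memory and is read through the read-only data cache with __ldg(). The standalone sketch below is not part of the commit; the kernel and symbol names are illustrative and it only contrasts the two load paths:

    // Old path: parameters live in a __constant__ symbol and must be staged
    // with cudaMemcpyToSymbol(Async) before every launch.
    __constant__ float c_params[256];

    __global__ void OldKernel(float* out, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = c_params[i % 256];  // served by the constant cache
    }

    // New path: parameters live in plain device memory and are loaded through
    // the read-only data cache with __ldg(), so no symbol copy is needed.
    __global__ void NewKernel(float* out, const float* params, int n) {
      int i = blockIdx.x * blockDim.x + threadIdx.x;
      if (i < n) out[i] = __ldg(params + (i % 256));
    }
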
20 changes: 20 additions & 0 deletions src/common/device_helpers.cuh
@@ -1939,4 +1939,24 @@ class CUDAStream {
CUDAStreamView View() const { return CUDAStreamView{stream_}; }
void Sync() { this->View().Sync(); }
};

// Force nvcc to load data as constant
template <typename T>
class LDGIterator {
typedef typename cub::UnitWord<T>::DeviceWord DeviceWordT;
static constexpr std::size_t kNumWords = sizeof(T) / sizeof(DeviceWordT);

const T* ptr;

public:
LDGIterator(const T* ptr) : ptr(ptr) {}
__device__ T operator[](std::size_t idx) const {
DeviceWordT tmp[kNumWords];
#pragma unroll
for (int i = 0; i < kNumWords; i++) {
tmp[i] = __ldg(reinterpret_cast<const DeviceWordT*>(ptr + idx) + i);
}
return *reinterpret_cast<const T*>(tmp);
}
};
} // namespace dh
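
A hypothetical usage sketch of the new dh::LDGIterator (not part of the commit; the kernel and variable names are illustrative, and it assumes device_helpers.cuh and <cstddef> are included): the iterator is constructed on the host from a device pointer, passed to the kernel by value, and each operator[] fetches its element word by word through __ldg().

    // Illustrative only: read-only input traffic goes through __ldg() via the iterator.
    __global__ void ScaleKernel(dh::LDGIterator<float> in, float* out, std::size_t n) {
      for (std::size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += static_cast<std::size_t>(gridDim.x) * blockDim.x) {
        out[i] = 2.0f * in[i];  // operator[] returns the element by value
      }
    }

    // Launch site (host code):
    //   dh::LDGIterator<float> it(d_in);             // d_in: const float* in device memory
    //   ScaleKernel<<<grid, block>>>(it, d_out, n);

Only operator[] is a __device__ function, so the wrapped pointer must refer to device memory that stays alive for the duration of the launch.
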
57 changes: 16 additions & 41 deletions src/tree/gpu_hist/row_partitioner.cuh
@@ -36,8 +36,6 @@ struct PerNodeData {
OpDataT data;
};

__constant__ char constant_memory[kMaxUpdatePositionBatchSize * 256];

template <typename BatchIterT>
__device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx,
int* batch_idx, std::size_t* item_idx) {
@@ -52,36 +50,14 @@ __device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t g
}
}

template <typename OpDataT>
struct SharedStorage {
PerNodeData<OpDataT> data[kMaxUpdatePositionBatchSize];
// Collectively load from global memory into shared memory
template <int kBlockSize>
__device__ const PerNodeData<OpDataT>* BlockLoad(const PerNodeData<OpDataT>* d_batch_info) {
for (int i = threadIdx.x; i < kMaxUpdatePositionBatchSize; i += kBlockSize) {
data[i] = d_batch_info[i];
}
__syncthreads();
return data;
}
};

template <int kBlockSize, typename RowIndexT, typename OpDataT>
__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel(
common::Span<RowIndexT> d_ridx, const common::Span<const RowIndexT> ridx_tmp,
std::size_t total_rows) {
// Load this into shared memory
// the compiler puts it into registers otherwise
// then we get spilling to local memory
const PerNodeData<OpDataT>* batch_info =
reinterpret_cast<const PerNodeData<OpDataT>*>(constant_memory);
__shared__ cub::Uninitialized<SharedStorage<OpDataT>> shared;
auto s_batch_info = shared.Alias().BlockLoad<kBlockSize>(batch_info);

dh::LDGIterator<PerNodeData<OpDataT>> batch_info, common::Span<RowIndexT> d_ridx,
const common::Span<const RowIndexT> ridx_tmp, std::size_t total_rows) {
for (auto idx : dh::GridStrideRange<std::size_t>(0, total_rows)) {
int batch_idx;
std::size_t item_idx;
AssignBatch(s_batch_info, idx, &batch_idx, &item_idx);
AssignBatch(batch_info, idx, &batch_idx, &item_idx);
d_ridx[item_idx] = ridx_tmp[item_idx];
}
}
@@ -109,14 +85,13 @@ struct IndexFlagOp {

template <typename OpDataT>
struct WriteResultsFunctor {
dh::LDGIterator<PerNodeData<OpDataT>> batch_info;
const bst_uint* ridx_in;
bst_uint* ridx_out;
bst_uint* counts;

__device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
std::size_t scatter_address;
const PerNodeData<OpDataT>* batch_info =
reinterpret_cast<const PerNodeData<OpDataT>*>(constant_memory);
const Segment& segment = batch_info[x.batch_idx].segment;
if (x.flag) {
bst_uint num_previous_flagged = x.flag_scan - 1; // -1 because inclusive scan
@@ -138,18 +113,19 @@ struct WriteResultsFunctor {
};

template <typename RowIndexT, typename OpT, typename OpDataT>
void SortPositionBatch(common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
void SortPositionBatch(common::Span<const PerNodeData<OpDataT>> d_batch_info,
common::Span<RowIndexT> ridx, common::Span<RowIndexT> ridx_tmp,
common::Span<bst_uint> d_counts, std::size_t total_rows, OpT op,
dh::device_vector<int8_t>* tmp, cudaStream_t stream) {
WriteResultsFunctor<OpDataT> write_results{ridx.data(), ridx_tmp.data(), d_counts.data()};
dh::LDGIterator<PerNodeData<OpDataT>> batch_info_itr(d_batch_info.data());
WriteResultsFunctor<OpDataT> write_results{batch_info_itr, ridx.data(), ridx_tmp.data(),
d_counts.data()};

auto discard_write_iterator =
thrust::make_transform_output_iterator(dh::TypedDiscard<IndexFlagTuple>(), write_results);
auto counting = thrust::make_counting_iterator(0llu);
auto input_iterator =
dh::MakeTransformIterator<IndexFlagTuple>(counting, [=] __device__(size_t idx) {
const PerNodeData<OpDataT>* batch_info_itr =
reinterpret_cast<const PerNodeData<OpDataT>*>(constant_memory);
int batch_idx;
std::size_t item_idx;
AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx);
@@ -173,7 +149,7 @@ void SortPositionBatch(common::Span<RowIndexT> ridx, common::Span<RowIndexT> rid
const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread);

SortPositionCopyKernel<kBlockSize, RowIndexT, OpDataT>
<<<grid_size, kBlockSize, 0, stream>>>(ridx, ridx_tmp, total_rows);
<<<grid_size, kBlockSize, 0, stream>>>(batch_info_itr, ridx, ridx_tmp, total_rows);
}

struct NodePositionInfo {
@@ -294,25 +270,24 @@ class RowPartitioner {
CHECK_EQ(nidx.size(), op_data.size());

auto h_batch_info = pinned2_.GetSpan<PerNodeData<OpDataT>>(nidx.size());
dh::TemporaryArray<PerNodeData<OpDataT>> d_batch_info(nidx.size());

std::size_t total_rows = 0;
for (int i = 0; i < nidx.size(); i++) {
h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)};
total_rows += ridx_segments_.at(nidx.at(i)).segment.Size();
}
static_assert(sizeof(PerNodeData<OpDataT>) * kMaxUpdatePositionBatchSize <=
sizeof(constant_memory),"Not enough constant memory allocated.") ;
dh::safe_cuda(cudaMemcpyToSymbolAsync(constant_memory, h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<OpDataT>), 0,
cudaMemcpyDefault, stream_));
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<OpDataT>),
cudaMemcpyDefault, stream_));

// Temporary arrays
auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);

// Partition the rows according to the operator
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts_), total_rows, op, &tmp_,
stream_);
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts_),
total_rows, op, &tmp_, stream_);
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts_.data().get(),
sizeof(decltype(d_counts_)::value_type) * h_counts.size(),
cudaMemcpyDefault, stream_));
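
Host-side, the staging pattern that replaces the constant-memory copy looks roughly like the sketch below (names are illustrative and assumed, not from the commit; it presumes <vector>, device_helpers.cuh, and row_partitioner.cuh are available, and the updated test in the next file exercises the real call sequence): metadata is copied into an ordinary device buffer with cudaMemcpyAsync and the buffer's pointer is wrapped in an LDGIterator for the kernel, instead of being pushed into a __constant__ symbol with cudaMemcpyToSymbolAsync.

    // Illustrative helper, not part of the commit.
    template <typename OpDataT>
    dh::LDGIterator<PerNodeData<OpDataT>> StageBatchInfo(
        const std::vector<PerNodeData<OpDataT>>& h_batch_info,
        dh::device_vector<PerNodeData<OpDataT>>* d_batch_info, cudaStream_t stream) {
      // Plain device buffer instead of a __constant__ symbol.
      d_batch_info->resize(h_batch_info.size());
      dh::safe_cuda(cudaMemcpyAsync(d_batch_info->data().get(), h_batch_info.data(),
                                    sizeof(PerNodeData<OpDataT>) * h_batch_info.size(),
                                    cudaMemcpyDefault, stream));
      // The caller keeps *d_batch_info alive until the stream work using it has finished.
      return dh::LDGIterator<PerNodeData<OpDataT>>(d_batch_info->data().get());
    }
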
11 changes: 6 additions & 5 deletions tests/cpp/tree/gpu_hist/test_row_partitioner.cu
@@ -67,12 +67,13 @@ void TestSortPositionBatch(const std::vector<int>& ridx_in, const std::vector<Se
h_batch_info[i] = {segments.at(i), 0};
total_rows += segments.at(i).Size();
}
dh::safe_cuda(cudaMemcpyToSymbolAsync(constant_memory, h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<int>), 0,
cudaMemcpyDefault, nullptr));
dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(),
h_batch_info.size() * sizeof(PerNodeData<int>), cudaMemcpyDefault,
nullptr));
dh::device_vector<int8_t> tmp;
SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(ridx), dh::ToSpan(ridx_tmp),
dh::ToSpan(counts), total_rows, op, &tmp,nullptr);
SortPositionBatch<uint32_t, decltype(op), int>(dh::ToSpan(d_batch_info), dh::ToSpan(ridx),
dh::ToSpan(ridx_tmp), dh::ToSpan(counts),
total_rows, op, &tmp, nullptr);

auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; };
for (int i = 0; i < segments.size(); i++) {
