Minimal inclusive scan #6234

Merged · 10 commits · Nov 3, 2020
Changes from 4 commits
6 changes: 2 additions & 4 deletions cmake/Utils.cmake
@@ -144,10 +144,8 @@ function(xgboost_set_cuda_flags target)
endif (CMAKE_VERSION VERSION_GREATER_EQUAL "3.18")

if (USE_DEVICE_DEBUG)
- if (CMAKE_BUILD_TYPE MATCHES "Debug")
- target_compile_options(${target} PRIVATE
- $<$<COMPILE_LANGUAGE:CUDA>:-G;-src-in-ptx>)
- endif(CMAKE_BUILD_TYPE MATCHES "Debug")
+ target_compile_options(${target} PRIVATE
+ $<$<AND:$<CONFIG:DEBUG>,$<COMPILE_LANGUAGE:CUDA>>:-G;-src-in-ptx>)
else (USE_DEVICE_DEBUG)
target_compile_options(${target} PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-lineinfo>)
92 changes: 91 additions & 1 deletion src/common/device_helpers.cuh
@@ -307,7 +307,7 @@ class MemoryLogger {
void RegisterDeallocation(void *ptr, size_t n, int current_device) {
auto itr = device_allocations.find(ptr);
if (itr == device_allocations.end()) {
- LOG(FATAL) << "Attempting to deallocate " << n << " bytes on device "
+ LOG(WARNING) << "Attempting to deallocate " << n << " bytes on device "
<< current_device << " that was never allocated ";
}
num_deallocations++;
@@ -1149,4 +1149,94 @@ auto Reduce(Policy policy, InputIt first, InputIt second, Init init, Func reduce
}
return aggregate;
}

namespace detail {
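// Splits [begin, end) into one equally sized, block-aligned chunk per CUDA
// block and returns the calling thread's strided sub-range of its chunk.
// Every thread gets the same iteration count, so the kernels below may call
// __syncthreads() inside the loop; indices past `end` must be masked by the
// caller (the `i < size` checks in the kernels below).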
template <typename T>
__device__ xgboost::common::Range ScanBlockedRange(T begin, T end) {
size_t items_per_block =
xgboost::common::DivRoundUp(size_t(end - begin), gridDim.x);
// Round up to nearest block size
items_per_block =
xgboost::common::DivRoundUp(items_per_block, blockDim.x) * blockDim.x;
size_t local_begin = begin + items_per_block * blockIdx.x + threadIdx.x;
size_t local_end = local_begin + items_per_block;
local_begin = std::min(local_begin, local_end);
xgboost::common::Range r(local_begin, local_end);
r.Step(blockDim.x);
return r;
}

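// Pass 1: each block scans its chunk, keeping only the running aggregate,
// and writes that chunk-wide aggregate to partials[blockIdx.x].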
template <int kBlockSize, typename InputIt, typename Func, typename InputT>
__global__ void InclusiveScanReduceKernel(InputIt first, size_t size, Func op,
InputT* partials) {
using BlockScanT = cub::BlockScan<InputT, kBlockSize>;
// Allocate shared memory for BlockScan
__shared__ typename BlockScanT::TempStorage temp_storage;
// Running aggregate of this block's chunk. InputT() is assumed to be the
// identity of op, matching the padding used for out-of-range threads below.
InputT partial = InputT();
for (auto i : ScanBlockedRange(size_t(0), size)) {
InputT block_aggregate;
InputT thread_data = i < size ? first[i] : InputT();
BlockScanT(temp_storage)
.InclusiveScan(thread_data, thread_data, op, block_aggregate);
__syncthreads();
partial = op(partial, block_aggregate);
}
if (threadIdx.x == 0) {
partials[blockIdx.x] = partial;
}
}

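// Pass 2: a single block exclusive-scans the per-block aggregates in place,
// so partials[b] becomes the combined result of all chunks before chunk b.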
template <int kBlockSize, typename Func, typename InputT>
__global__ void InclusiveScanPartialsKernel(Func op, InputT* partials) {
using BlockScanT = cub::BlockScan<InputT, kBlockSize>;
// Allocate shared memory for BlockScan
__shared__ typename BlockScanT::TempStorage temp_storage;
InputT thread_data = partials[threadIdx.x];
BlockScanT(temp_storage)
.ExclusiveScan(thread_data, thread_data, InputT(), op);
__syncthreads();
partials[threadIdx.x] = thread_data;
}

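// Pass 3: each block re-scans its chunk and combines every element with the
// running offset (seeded from pass 2) before writing the final output.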
template <int kBlockSize, typename InputIt, typename OutputIt, typename Func,
typename InputT>
__global__ void InclusiveScanFinalKernel(InputIt first, OutputIt out,
size_t size, Func op,
InputT* partials) {
using BlockScanT = cub::BlockScan<InputT, kBlockSize>;
// Allocate shared memory for BlockScan
__shared__ typename BlockScanT::TempStorage temp_storage;
InputT partial = partials[blockIdx.x];
for (auto i : ScanBlockedRange(size_t(0), size)) {
InputT block_aggregate;
InputT thread_data = i < size ? first[i] : InputT();
BlockScanT(temp_storage)
.InclusiveScan(thread_data, thread_data, op, block_aggregate);
__syncthreads();
if (i < size) {
out[i] = op(partial, thread_data);
}
partial = op(partial, block_aggregate);
}
}
}  // namespace detail

// Workaround for thrust::inclusive_scan, which cannot handle input sizes
// greater than 2^31 (see https://github.com/NVIDIA/thrust/issues/1299).
template <typename InputIt, typename OutputIt, typename Func>
void InclusiveScan(InputIt first, InputIt second, OutputIt output, Func op) {
size_t size = std::distance(first, second);
const int kNumBlocks = 128;
const int kBlockSize = 256;
using InputT = typename std::iterator_traits<InputIt>::value_type;
dh::TemporaryArray<InputT> partials(kNumBlocks, InputT());

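// partials holds one aggregate per block; the second launch below runs a
// single block of kNumBlocks threads, so it can scan all partials at once.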
detail::InclusiveScanReduceKernel<kBlockSize>
<<<kNumBlocks, kBlockSize>>>(first, size, op, partials.data().get());

detail::InclusiveScanPartialsKernel<kNumBlocks>
<<<1, kNumBlocks>>>(op, partials.data().get());

detail::InclusiveScanFinalKernel<kBlockSize><<<kNumBlocks, kBlockSize>>>(
first, output, size, op, partials.data().get());
}
} // namespace dh
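For reference, the same three-phase decomposition expressed as a minimal host-side sketch. This is illustrative only: ThreePhaseInclusiveScan is a hypothetical name, not part of this PR, and it assumes op is associative and that a value-initialized T is its identity, the same assumptions the kernels above rely on.

#include <algorithm>
#include <cstddef>
#include <vector>

template <typename T, typename Func>
std::vector<T> ThreePhaseInclusiveScan(std::vector<T> const &in, Func op,
                                       std::size_t num_blocks) {
  std::size_t n = in.size();
  std::size_t per_block = (n + num_blocks - 1) / num_blocks;
  std::vector<T> out(n);
  std::vector<T> partials(num_blocks, T());
  // Pass 1: reduce each chunk to a single aggregate.
  for (std::size_t b = 0; b < num_blocks; ++b) {
    T agg = T();
    for (std::size_t i = b * per_block;
         i < std::min(n, (b + 1) * per_block); ++i) {
      agg = op(agg, in[i]);
    }
    partials[b] = agg;
  }
  // Pass 2: exclusive scan of the aggregates gives each chunk its offset.
  T running = T();
  for (std::size_t b = 0; b < num_blocks; ++b) {
    T next = op(running, partials[b]);
    partials[b] = running;
    running = next;
  }
  // Pass 3: inclusive scan of each chunk, seeded with the chunk's offset.
  for (std::size_t b = 0; b < num_blocks; ++b) {
    T acc = partials[b];
    for (std::size_t i = b * per_block;
         i < std::min(n, (b + 1) * per_block); ++i) {
      acc = op(acc, in[i]);
      out[i] = acc;
    }
  }
  return out;
}

On the GPU, passes 1 and 3 each read the full input, so the scheme trades a second pass over the data for never needing a 32-bit offset across the whole range.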
10 changes: 1 addition & 9 deletions src/data/ellpack_page.cu
@@ -206,15 +206,7 @@ void CopyDataToEllpack(const AdapterBatchT& batch, EllpackPageImpl* dst,
WriteCompressedEllpackFunctor<AdapterBatchT>, decltype(discard)>
out(discard, functor);
dh::XGBCachingDeviceAllocator<char> alloc;
- // 1000 as a safe factor for inclusive_scan, otherwise it might generate overflow and
- // lead to oom error.
- // or:
- // after reduction step 2: cudaErrorInvalidConfiguration: invalid configuration argument
- // https://github.com/NVIDIA/thrust/issues/1299
- CHECK_LE(batch.Size(), std::numeric_limits<int32_t>::max() - 1000)
- << "Known limitation, size (rows * cols) of quantile based DMatrix "
- "cannot exceed the limit of 32-bit integer.";
- thrust::inclusive_scan(thrust::cuda::par(alloc), key_value_index_iter,
+ dh::InclusiveScan(key_value_index_iter,
key_value_index_iter + batch.Size(), out,
[=] __device__(Tuple a, Tuple b) {
// Key equal
13 changes: 13 additions & 0 deletions tests/cpp/common/test_device_helpers.cu
@@ -129,6 +129,19 @@ TEST(DeviceHelpers, Reduce) {
CHECK_EQ(batched, kSize - 1);
}

TEST(DeviceHelpers, InclusiveScan) {
size_t sizes[] = {0, 1, 14, 1781, 59268};
for (auto n : sizes) {
thrust::device_vector<int> x(n, 1);
thrust::device_vector<int> out(x.size());
thrust::device_vector<int> thrust_out(x.size());

dh::InclusiveScan(x.begin(), x.end(), out.begin(), thrust::plus<int>{});
thrust::inclusive_scan(x.begin(), x.end(), thrust_out.begin());

EXPECT_TRUE(thrust::equal(out.begin(), out.end(), thrust_out.begin()));
}
}

TEST(SegmentedUnique, Regression) {
{