Optimise histogram kernels #8118

Merged
merged 15 commits on Aug 18, 2022
292 changes: 181 additions & 111 deletions src/tree/gpu_hist/histogram.cu
@@ -1,19 +1,18 @@
/*!
* Copyright 2020-2021 by XGBoost Contributors
*/
#include <thrust/reduce.h>
#include <thrust/iterator/transform_iterator.h>
#include <thrust/reduce.h>

#include <algorithm>
#include <ctgmath>
#include <limits>

#include "xgboost/base.h"
#include "row_partitioner.cuh"

#include "histogram.cuh"

#include "../../data/ellpack_page.cuh"
#include "../../common/device_helpers.cuh"
#include "../../data/ellpack_page.cuh"
#include "histogram.cuh"
#include "row_partitioner.cuh"
#include "xgboost/base.h"

namespace xgboost {
namespace tree {
@@ -59,12 +58,8 @@ __host__ XGBOOST_DEV_INLINE Pair operator+(Pair const& lhs, Pair const& rhs) {
} // anonymous namespace

struct Clip : public thrust::unary_function<GradientPair, Pair> {
static XGBOOST_DEV_INLINE float Pclip(float v) {
return v > 0 ? v : 0;
}
static XGBOOST_DEV_INLINE float Nclip(float v) {
return v < 0 ? abs(v) : 0;
}
static XGBOOST_DEV_INLINE float Pclip(float v) { return v > 0 ? v : 0; }
static XGBOOST_DEV_INLINE float Nclip(float v) { return v < 0 ? abs(v) : 0; }

XGBOOST_DEV_INLINE Pair operator()(GradientPair x) const {
auto pg = Pclip(x.GetGrad());
@@ -73,7 +68,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
auto ng = Nclip(x.GetGrad());
auto nh = Nclip(x.GetHess());

return { GradientPair{ pg, ph }, GradientPair{ ng, nh } };
return {GradientPair{pg, ph}, GradientPair{ng, nh}};
}
};
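
A quick host-side sketch of what the Clip functor accomplishes (illustrative only, with made-up gradient values, not code from this patch): splitting each pair into non-negative positive and negative parts lets a single reduction yield both sums, and the larger of the two bounds the absolute value of any partial gradient sum, which is what the rounding factor computed below needs.

#include <cmath>
#include <cstdio>
#include <utility>
#include <vector>

int main() {
  std::vector<std::pair<float, float>> gpairs = {{0.5f, 1.0f}, {-0.25f, 0.75f}, {-1.5f, 0.5f}};
  float pos_grad = 0.f, neg_grad = 0.f;
  for (auto const& gp : gpairs) {
    float g = gp.first;
    pos_grad += g > 0 ? g : 0;            // Pclip
    neg_grad += g < 0 ? std::abs(g) : 0;  // Nclip
  }
  // max(positive, negative) bounds |any partial sum of gradients|.
  std::printf("positive sum %.2f, negative sum %.2f, bound %.2f\n", pos_grad, neg_grad,
              std::fmax(pos_grad, neg_grad));
  return 0;
}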

@@ -82,18 +77,18 @@ HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const>
using T = typename GradientSumT::ValueT;
dh::XGBCachingDeviceAllocator<char> alloc;

thrust::device_ptr<GradientPair const> gpair_beg {gpair.data()};
thrust::device_ptr<GradientPair const> gpair_end {gpair.data() + gpair.size()};
thrust::device_ptr<GradientPair const> gpair_beg{gpair.data()};
thrust::device_ptr<GradientPair const> gpair_end{gpair.data() + gpair.size()};
auto beg = thrust::make_transform_iterator(gpair_beg, Clip());
auto end = thrust::make_transform_iterator(gpair_end, Clip());
Pair p = dh::Reduce(thrust::cuda::par(alloc), beg, end, Pair{}, thrust::plus<Pair>{});
GradientPair positive_sum {p.first}, negative_sum {p.second};
GradientPair positive_sum{p.first}, negative_sum{p.second};

auto histogram_rounding = GradientSumT {
CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()),
gpair.size()),
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
gpair.size()) };
auto histogram_rounding =
GradientSumT{CreateRoundingFactor<T>(std::max(positive_sum.GetGrad(), negative_sum.GetGrad()),
gpair.size()),
CreateRoundingFactor<T>(std::max(positive_sum.GetHess(), negative_sum.GetHess()),
gpair.size())};

using IntT = typename HistRounding<GradientSumT>::SharedSumT::ValueT;

@@ -102,8 +97,7 @@ HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const>
*/
GradientSumT to_floating_point =
histogram_rounding /
T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 -
2)); // keep 1 for sign bit
T(IntT(1) << (sizeof(typename GradientSumT::ValueT) * 8 - 2)); // keep 1 for sign bit
/**
* Factor for converting gradients from floating-point to fixed-point. For
* f64:
@@ -113,66 +107,149 @@ HistRounding<GradientSumT> CreateRoundingFactor(common::Span<GradientPair const>
 * rounding is calculated as exp(m), see the rounding factor calculation for
* details.
*/
GradientSumT to_fixed_point = GradientSumT(
T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess());
GradientSumT to_fixed_point =
GradientSumT(T(1) / to_floating_point.GetGrad(), T(1) / to_floating_point.GetHess());

return {histogram_rounding, to_fixed_point, to_floating_point};
}
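
To see what these two factors give us, here is a minimal host-side round trip (a sketch under assumed values; the power-of-two rounding is a stand-in for the real CreateRoundingFactor<T> result, and the 62-bit shift mirrors the sign-bit comment above): to_fixed_point is the reciprocal of to_floating_point, so gradients quantised to 64-bit integers in shared memory can be scaled back to truncated floating-point values when flushed to the global histogram.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  double max_abs_sum = 100.0;  // assumed bound on the absolute gradient sum
  double rounding = std::exp2(std::ceil(std::log2(max_abs_sum)));  // stand-in rounding factor
  double to_floating_point =
      rounding / static_cast<double>(int64_t{1} << 62);  // keep 1 bit for the sign
  double to_fixed_point = 1.0 / to_floating_point;

  double grad = 0.123456789;
  int64_t fixed = static_cast<int64_t>(grad * to_fixed_point);   // what the shared-memory atomics add
  double back = static_cast<double>(fixed) * to_floating_point;  // value written to the global histogram
  std::printf("grad %.12f -> fixed %lld -> back %.12f\n", grad, static_cast<long long>(fixed), back);
  return 0;
}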

template HistRounding<GradientPairPrecise>
CreateRoundingFactor(common::Span<GradientPair const> gpair);
template HistRounding<GradientPair>
CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename GradientSumT, bool use_shared_memory_histograms>
__global__ void SharedMemHistKernel(EllpackDeviceAccessor matrix,
FeatureGroupsAccessor feature_groups,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
HistRounding<GradientSumT> const rounding) {
template HistRounding<GradientPairPrecise> CreateRoundingFactor(
common::Span<GradientPair const> gpair);
template HistRounding<GradientPair> CreateRoundingFactor(common::Span<GradientPair const> gpair);

template <typename GradientSumT, int kBlockThreads, int kItemsPerThread,
int kItemsPerTile = kBlockThreads* kItemsPerThread>
class HistogramAgent {
using SharedSumT = typename HistRounding<GradientSumT>::SharedSumT;
using T = typename GradientSumT::ValueT;
SharedSumT* smem_arr_;
GradientSumT* d_node_hist_;
dh::LDGIterator<const RowPartitioner::RowIndexT> d_ridx_;
const GradientPair* d_gpair_;
const FeatureGroup group_;
const EllpackDeviceAccessor& matrix_;
const int feature_stride_;
const std::size_t n_elements_;
const HistRounding<GradientSumT>& rounding_;

public:
__device__ HistogramAgent(SharedSumT* smem_arr, GradientSumT* __restrict__ d_node_hist,
const FeatureGroup& group, const EllpackDeviceAccessor& matrix,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
const HistRounding<GradientSumT>& rounding, const GradientPair* d_gpair)
: smem_arr_(smem_arr),
d_node_hist_(d_node_hist),
d_ridx_(d_ridx.data()),
group_(group),
matrix_(matrix),
feature_stride_(matrix.is_dense ? group.num_features : matrix.row_stride),
n_elements_(feature_stride_ * d_ridx.size()),
rounding_(rounding),
d_gpair_(d_gpair) {}
__device__ void ProcessPartialTileShared(std::size_t offset) {
for (std::size_t idx = offset + threadIdx.x;
idx < min(offset + kBlockThreads * kItemsPerTile, n_elements_); idx += kBlockThreads) {
int ridx = d_ridx_[idx / feature_stride_];
int gidx =
matrix_
.gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_] -
group_.start_bin;
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
auto adjusted = rounding_.ToFixedPoint(d_gpair_[ridx]);
dh::AtomicAddGpair(smem_arr_ + gidx, adjusted);
}
}
}
// Instruction level parallelism by loop unrolling
// Allows the kernel to pipeline many operations while waiting for global memory
// Increases the throughput of this kernel significantly
__device__ void ProcessFullTileShared(std::size_t offset) {
std::size_t idx[kItemsPerThread];
int ridx[kItemsPerThread];
int gidx[kItemsPerThread];
GradientPair gpair[kItemsPerThread];
#pragma unroll
for (int i = 0; i < kItemsPerThread; i++) {
idx[i] = offset + i * kBlockThreads + threadIdx.x;
}
#pragma unroll
for (int i = 0; i < kItemsPerThread; i++) {
ridx[i] = d_ridx_[idx[i] / feature_stride_];
}
#pragma unroll
for (int i = 0; i < kItemsPerThread; i++) {
gpair[i] = d_gpair_[ridx[i]];
gidx[i] = matrix_.gidx_iter[ridx[i] * matrix_.row_stride + group_.start_feature +
idx[i] % feature_stride_];
}
#pragma unroll
for (int i = 0; i < kItemsPerThread; i++) {
if ((matrix_.is_dense || gidx[i] != matrix_.NumBins())) {
auto adjusted = rounding_.ToFixedPoint(gpair[i]);
dh::AtomicAddGpair(smem_arr_ + gidx[i] - group_.start_bin, adjusted);
}
}
}
__device__ void BuildHistogramWithShared() {
dh::BlockFill(smem_arr_, group_.num_bins, SharedSumT());
__syncthreads();

extern __shared__ char smem[];
FeatureGroup group = feature_groups[blockIdx.y];
SharedSumT *smem_arr = reinterpret_cast<SharedSumT *>(smem);
if (use_shared_memory_histograms) {
dh::BlockFill(smem_arr, group.num_bins, SharedSumT());
std::size_t offset = blockIdx.x * kItemsPerTile;
while (offset + kItemsPerTile <= n_elements_) {
ProcessFullTileShared(offset);
offset += kItemsPerTile * gridDim.x;
}
ProcessPartialTileShared(offset);

// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(0, group_.num_bins)) {
auto truncated = rounding_.ToFloatingPoint(smem_arr_[i]);
dh::AtomicAddGpair(d_node_hist_ + group_.start_bin + i, truncated);
}
}
int feature_stride = matrix.is_dense ? group.num_features : matrix.row_stride;
size_t n_elements = feature_stride * d_ridx.size();
for (auto idx : dh::GridStrideRange(static_cast<size_t>(0), n_elements)) {
int ridx = d_ridx[idx / feature_stride];
int gidx = matrix.gidx_iter[ridx * matrix.row_stride + group.start_feature +
idx % feature_stride];
if (gidx != matrix.NumBins()) {
// If we are not using shared memory, accumulate the values directly into
// global memory
gidx = use_shared_memory_histograms ? gidx - group.start_bin : gidx;
if (use_shared_memory_histograms) {
auto adjusted = rounding.ToFixedPoint(d_gpair[ridx]);
dh::AtomicAddGpair(smem_arr + gidx, adjusted);
} else {

__device__ void BuildHistogramWithGlobal() {
for (auto idx : dh::GridStrideRange(static_cast<std::size_t>(0), n_elements_)) {
int ridx = d_ridx_[idx / feature_stride_];
int gidx =
matrix_
.gidx_iter[ridx * matrix_.row_stride + group_.start_feature + idx % feature_stride_];
if (matrix_.is_dense || gidx != matrix_.NumBins()) {
// If we are not using shared memory, accumulate the values directly into
// global memory
GradientSumT truncated{
TruncateWithRoundingFactor<T>(rounding.rounding.GetGrad(),
d_gpair[ridx].GetGrad()),
TruncateWithRoundingFactor<T>(rounding.rounding.GetHess(),
d_gpair[ridx].GetHess()),
TruncateWithRoundingFactor<GradientSumT::ValueT>(rounding_.rounding.GetGrad(),
d_gpair_[ridx].GetGrad()),
TruncateWithRoundingFactor<GradientSumT::ValueT>(rounding_.rounding.GetHess(),
d_gpair_[ridx].GetHess()),
};
dh::AtomicAddGpair(d_node_hist + gidx, truncated);
dh::AtomicAddGpair(d_node_hist_ + gidx, truncated);
}
}
}
};
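
The tile loop is easier to follow with a host-side analogue (illustrative sizes only, not part of the patch): each block starts at its own tile, strides over the element range in steps of kItemsPerTile * gridDim.x, takes the unrolled full-tile path while a whole tile fits, and falls back to the bounds-checked partial path only for the tail.

#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kBlockThreads = 1024, kItemsPerThread = 8;
  constexpr std::size_t kItemsPerTile = kBlockThreads * kItemsPerThread;
  std::size_t n_elements = 20000, grid_dim_x = 2;  // assumed problem size and grid

  for (std::size_t block = 0; block < grid_dim_x; ++block) {
    std::size_t offset = block * kItemsPerTile;
    while (offset + kItemsPerTile <= n_elements) {
      std::printf("block %zu: full tile    [%zu, %zu)\n", block, offset, offset + kItemsPerTile);
      offset += kItemsPerTile * grid_dim_x;  // grid-stride step over tiles
    }
    if (offset < n_elements) {  // at most one block owns the tail
      std::printf("block %zu: partial tile [%zu, %zu)\n", block, offset, n_elements);
    }
  }
  return 0;
}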

template <typename GradientSumT, bool use_shared_memory_histograms, int kBlockThreads,
int kItemsPerThread>
__global__ void __launch_bounds__(kBlockThreads)
SharedMemHistKernel(const EllpackDeviceAccessor matrix,
const FeatureGroupsAccessor feature_groups,
common::Span<const RowPartitioner::RowIndexT> d_ridx,
GradientSumT* __restrict__ d_node_hist,
const GradientPair* __restrict__ d_gpair,
HistRounding<GradientSumT> const rounding) {
using SharedSumT = typename HistRounding<GradientSumT>::SharedSumT;
using T = typename GradientSumT::ValueT;

extern __shared__ char smem[];
const FeatureGroup group = feature_groups[blockIdx.y];
SharedSumT* smem_arr = reinterpret_cast<SharedSumT*>(smem);
auto agent = HistogramAgent<GradientSumT, kBlockThreads, kItemsPerThread>(
smem_arr, d_node_hist, group, matrix, d_ridx, rounding, d_gpair);
if (use_shared_memory_histograms) {
// Write shared memory back to global memory
__syncthreads();
for (auto i : dh::BlockStrideRange(0, group.num_bins)) {
auto truncated = rounding.ToFloatingPoint(smem_arr[i]);
dh::AtomicAddGpair(d_node_hist + group.start_bin + i, truncated);
}
agent.BuildHistogramWithShared();
} else {
agent.BuildHistogramWithGlobal();
}
}
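
For reference, the bin bookkeeping done per feature group, shown as a small host-side example (the group layout and indices are made up): blockIdx.y selects one FeatureGroup, whose shared-memory histogram holds only that group's bins, so bin indices are rebased by start_bin on the way in and added back when the block flushes to the global histogram.

#include <cstdio>

int main() {
  struct { int start_bin; int num_bins; } group{128, 64};  // assumed feature-group layout
  int global_gidx = 150;                                   // bin index produced by gidx_iter
  int shared_gidx = global_gidx - group.start_bin;         // slot used with the shared-memory array
  std::printf("global bin %d -> shared slot %d (of %d)\n", global_gidx, shared_gidx, group.num_bins);
  // On flush, shared slot i is accumulated into d_node_hist[group.start_bin + i].
  std::printf("shared slot %d flushes to global bin %d\n", shared_gidx, group.start_bin + shared_gidx);
  return 0;
}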

@@ -182,78 +259,71 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> d_ridx,
common::Span<GradientSumT> histogram,
HistRounding<GradientSumT> rounding,
bool force_global_memory) {
HistRounding<GradientSumT> rounding, bool force_global_memory) {
// decide whether to use shared memory
int device = 0;
dh::safe_cuda(cudaGetDevice(&device));
// opt into maximum shared memory for the kernel if necessary
size_t max_shared_memory = dh::MaxSharedMemoryOptin(device);

size_t smem_size = sizeof(typename HistRounding<GradientSumT>::SharedSumT) *
feature_groups.max_group_bins;
size_t smem_size =
sizeof(typename HistRounding<GradientSumT>::SharedSumT) * feature_groups.max_group_bins;
bool shared = !force_global_memory && smem_size <= max_shared_memory;
smem_size = shared ? smem_size : 0;

constexpr int kBlockThreads = 1024;
constexpr int kItemsPerThread = 8;
constexpr int kItemsPerTile = kBlockThreads * kItemsPerThread;

auto runit = [&](auto kernel) {
if (shared) {
dh::safe_cuda(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_memory));
dh::safe_cuda(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
max_shared_memory));
}

// determine the launch configuration
int min_grid_size;
int block_threads = 1024;
dh::safe_cuda(cudaOccupancyMaxPotentialBlockSize(
&min_grid_size, &block_threads, kernel, smem_size, 0));

int num_groups = feature_groups.NumGroups();
int n_mps = 0;
dh::safe_cuda(
cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
dh::safe_cuda(cudaDeviceGetAttribute(&n_mps, cudaDevAttrMultiProcessorCount, device));
int n_blocks_per_mp = 0;
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
&n_blocks_per_mp, kernel, block_threads, smem_size));
dh::safe_cuda(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&n_blocks_per_mp, kernel,
kBlockThreads, smem_size));
// This gives the number of blocks to keep the device occupied
// Use this as the maximum number of blocks
unsigned grid_size = n_blocks_per_mp * n_mps;

// TODO(canonizer): This is really a hack, find a better way to distribute
// the data among thread blocks. The intention is to generate enough thread
// blocks to fill the GPU, but avoid having too many thread blocks, as this
// is less efficient when the number of rows is low. At least one thread
// block per feature group is required. The number of thread blocks:
// - for num_groups <= num_groups_threshold, around grid_size * num_groups
// - for num_groups_threshold <= num_groups <= num_groups_threshold *
// grid_size,
// around grid_size * num_groups_threshold
// - for num_groups_threshold * grid_size <= num_groups, around num_groups
int num_groups_threshold = 4;
grid_size = common::DivRoundUp(
grid_size, common::DivRoundUp(num_groups, num_groups_threshold));

using T = typename GradientSumT::ValueT;
// Otherwise launch blocks such that each block has a minimum amount of work to do
// There are fixed costs to launching each block, e.g. zeroing shared memory
// The below amount of minimum work was found by experimentation
constexpr int kMinItemsPerBlock = kItemsPerTile;
int columns_per_group = common::DivRoundUp(matrix.row_stride, feature_groups.NumGroups());
// Average number of matrix elements processed by each group
std::size_t items_per_group = d_ridx.size() * columns_per_group;

// Allocate number of blocks such that each block has about kMinItemsPerBlock work
// Up to a maximum where the device is saturated
grid_size =
min(grid_size,
unsigned(common::DivRoundUp(items_per_group, kMinItemsPerBlock)));

dh::LaunchKernel {dim3(grid_size, num_groups),
static_cast<uint32_t>(block_threads),
smem_size} (kernel, matrix, feature_groups, d_ridx,
histogram.data(), gpair.data(), rounding);
static_cast<uint32_t>(kBlockThreads), smem_size}(
kernel, matrix, feature_groups, d_ridx, histogram.data(), gpair.data(), rounding);
};

if (shared) {
runit(SharedMemHistKernel<GradientSumT, true>);
runit(SharedMemHistKernel<GradientSumT, true, kBlockThreads, kItemsPerThread>);
} else {
runit(SharedMemHistKernel<GradientSumT, false>);
runit(SharedMemHistKernel<GradientSumT, false, kBlockThreads, kItemsPerThread>);
}

dh::safe_cuda(cudaGetLastError());
}
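
The launch-size heuristic can be sanity-checked with a host-side model (the SM count, occupancy, and data sizes below are assumptions, not measurements): start from enough blocks to saturate the device, then cap the x-dimension so each block receives at least kMinItemsPerBlock elements of its feature group.

#include <algorithm>
#include <cstddef>
#include <cstdio>

static std::size_t DivRoundUp(std::size_t a, std::size_t b) { return (a + b - 1) / b; }

int main() {
  constexpr std::size_t kMinItemsPerBlock = 1024 * 8;  // kBlockThreads * kItemsPerThread
  std::size_t n_mps = 80, n_blocks_per_mp = 2;         // assumed occupancy query results
  std::size_t n_rows = 50000, row_stride = 32, num_groups = 4;

  std::size_t grid_size = n_blocks_per_mp * n_mps;  // saturate the device
  std::size_t columns_per_group = DivRoundUp(row_stride, num_groups);
  std::size_t items_per_group = n_rows * columns_per_group;  // average work per feature group
  grid_size = std::min(grid_size, DivRoundUp(items_per_group, kMinItemsPerBlock));

  std::printf("launch grid: dim3(%zu, %zu)\n", grid_size, num_groups);
  return 0;
}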

template void BuildGradientHistogram<GradientPairPrecise>(
EllpackDeviceAccessor const& matrix,
FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair,
common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram,
HistRounding<GradientPairPrecise> rounding,
EllpackDeviceAccessor const& matrix, FeatureGroupsAccessor const& feature_groups,
common::Span<GradientPair const> gpair, common::Span<const uint32_t> ridx,
common::Span<GradientPairPrecise> histogram, HistRounding<GradientPairPrecise> rounding,
bool force_global_memory);

} // namespace tree