Use integer gradients in gpu_hist split evaluation (#8274)
RAMitchell committed Oct 11, 2022
1 parent c68684f commit 210915c
Showing 12 changed files with 224 additions and 292 deletions.
4 changes: 2 additions & 2 deletions include/xgboost/base.h
@@ -264,8 +264,8 @@ using GradientPairPrecise = detail::GradientPairInternal<double>;
  * we don't accidentally use it in gain calculations.*/
 class GradientPairInt64 {
   using T = int64_t;
-  T grad_;
-  T hess_;
+  T grad_ = 0;
+  T hess_ = 0;
 
  public:
   using ValueT = T;
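Two lines of context on why this small change matters: GradientPairInt64 holds gradient and hessian sums in 64-bit fixed point, and integer addition is associative, so parallel reductions over these pairs are order-independent and bit-reproducible. A minimal sketch of the idea, with illustrative names (the real class in include/xgboost/base.h carries more operators plus the quantised accessors used in the diffs below):

```cpp
#include <cstdint>

// Minimal sketch of a fixed-point gradient pair; not XGBoost's class.
// Because the fields are integers, (a + b) + c == a + (b + c) holds
// exactly, so any parallel reduction order gives identical histograms.
class GradientPairInt64Sketch {
  using T = std::int64_t;
  T grad_ = 0;  // zero-initialised, as in the patched header
  T hess_ = 0;

 public:
  GradientPairInt64Sketch() = default;
  GradientPairInt64Sketch(T grad, T hess) : grad_{grad}, hess_{hess} {}
  T GetQuantisedGrad() const { return grad_; }
  T GetQuantisedHess() const { return hess_; }
  GradientPairInt64Sketch operator+(const GradientPairInt64Sketch &rhs) const {
    return {grad_ + rhs.grad_, hess_ + rhs.hess_};
  }
  GradientPairInt64Sketch operator-(const GradientPairInt64Sketch &rhs) const {
    return {grad_ - rhs.grad_, hess_ - rhs.hess_};
  }
};
```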
145 changes: 79 additions & 66 deletions src/tree/gpu_hist/evaluate_splits.cu
@@ -15,17 +15,20 @@ namespace xgboost {
 namespace tree {
 
 // With constraints
-XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan,
-                                       const GradientPairPrecise &missing,
-                                       const GradientPairPrecise &parent_sum,
+XGBOOST_DEVICE float LossChangeMissing(const GradientPairInt64 &scan,
+                                       const GradientPairInt64 &missing,
+                                       const GradientPairInt64 &parent_sum,
                                        const GPUTrainingParam &param, bst_node_t nidx,
                                        bst_feature_t fidx,
                                        TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
-                                       bool &missing_left_out) { // NOLINT
+                                       bool &missing_left_out, const GradientQuantiser& quantiser) { // NOLINT
   const auto left_sum = scan + missing;
-  float missing_left_gain =
-      evaluator.CalcSplitGain(param, nidx, fidx, left_sum, parent_sum - left_sum);
-  float missing_right_gain = evaluator.CalcSplitGain(param, nidx, fidx, scan, parent_sum - scan);
+  float missing_left_gain = evaluator.CalcSplitGain(
+      param, nidx, fidx, quantiser.ToFloatingPoint(left_sum),
+      quantiser.ToFloatingPoint(parent_sum - left_sum));
+  float missing_right_gain = evaluator.CalcSplitGain(
+      param, nidx, fidx, quantiser.ToFloatingPoint(scan),
+      quantiser.ToFloatingPoint(parent_sum - scan));
 
   missing_left_out = missing_left_gain > missing_right_gain;
   return missing_left_out?missing_left_gain:missing_right_gain;
@@ -42,9 +45,9 @@ template <int kBlockSize>
 class EvaluateSplitAgent {
  public:
   using ArgMaxT = cub::KeyValuePair<int, float>;
-  using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
+  using BlockScanT = cub::BlockScan<GradientPairInt64, kBlockSize>;
   using MaxReduceT = cub::WarpReduce<ArgMaxT>;
-  using SumReduceT = cub::WarpReduce<GradientPairPrecise>;
+  using SumReduceT = cub::WarpReduce<GradientPairInt64>;
 
   struct TempStorage {
     typename BlockScanT::TempStorage scan;
@@ -59,67 +62,67 @@ class EvaluateSplitAgent {
   const uint32_t gidx_end;  // end bin for i^th feature
   const dh::LDGIterator<float> feature_values;
   const GradientPairInt64 *node_histogram;
-  const GradientQuantizer &rounding;
-  const GradientPairPrecise parent_sum;
-  const GradientPairPrecise missing;
+  const GradientQuantiser &rounding;
+  const GradientPairInt64 parent_sum;
+  const GradientPairInt64 missing;
   const GPUTrainingParam &param;
   const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator;
   TempStorage *temp_storage;
-  SumCallbackOp<GradientPairPrecise> prefix_op;
+  SumCallbackOp<GradientPairInt64> prefix_op;
   static float constexpr kNullGain = -std::numeric_limits<bst_float>::infinity();
 
-  __device__ EvaluateSplitAgent(TempStorage *temp_storage, int fidx,
-                                const EvaluateSplitInputs &inputs,
-                                const EvaluateSplitSharedInputs &shared_inputs,
-                                const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
-      : temp_storage(temp_storage),
-        nidx(inputs.nidx),
-        fidx(fidx),
+  __device__ EvaluateSplitAgent(
+      TempStorage *temp_storage, int fidx, const EvaluateSplitInputs &inputs,
+      const EvaluateSplitSharedInputs &shared_inputs,
+      const TreeEvaluator::SplitEvaluator<GPUTrainingParam> &evaluator)
+      : temp_storage(temp_storage), nidx(inputs.nidx), fidx(fidx),
         min_fvalue(__ldg(shared_inputs.min_fvalue.data() + fidx)),
        gidx_begin(__ldg(shared_inputs.feature_segments.data() + fidx)),
        gidx_end(__ldg(shared_inputs.feature_segments.data() + fidx + 1)),
        feature_values(shared_inputs.feature_values.data()),
        node_histogram(inputs.gradient_histogram.data()),
        rounding(shared_inputs.rounding),
-        parent_sum(dh::LDGIterator<GradientPairPrecise>(&inputs.parent_sum)[0]),
-        param(shared_inputs.param),
-        evaluator(evaluator),
+        parent_sum(dh::LDGIterator<GradientPairInt64>(&inputs.parent_sum)[0]),
+        param(shared_inputs.param), evaluator(evaluator),
         missing(parent_sum - ReduceFeature()) {
-    static_assert(kBlockSize == 32,
-                  "This kernel relies on the assumption block_size == warp_size");
+    static_assert(
+        kBlockSize == 32,
+        "This kernel relies on the assumption block_size == warp_size");
+    // There should be no missing value gradients for a dense matrix
+    KERNEL_CHECK(!shared_inputs.is_dense || missing.GetQuantisedHess() == 0);
   }
-  __device__ GradientPairPrecise ReduceFeature() {
-    GradientPairPrecise local_sum;
-    for (int idx = gidx_begin + threadIdx.x; idx < gidx_end; idx += kBlockSize) {
+  __device__ GradientPairInt64 ReduceFeature() {
+    GradientPairInt64 local_sum;
+    for (int idx = gidx_begin + threadIdx.x; idx < gidx_end;
+         idx += kBlockSize) {
       local_sum += LoadGpair(node_histogram + idx);
     }
     local_sum = SumReduceT(temp_storage->sum_reduce).Sum(local_sum);
     // Broadcast result from thread 0
-    return {__shfl_sync(0xffffffff, local_sum.GetGrad(), 0),
-            __shfl_sync(0xffffffff, local_sum.GetHess(), 0)};
+    return {__shfl_sync(0xffffffff, local_sum.GetQuantisedGrad(), 0),
+            __shfl_sync(0xffffffff, local_sum.GetQuantisedHess(), 0)};
   }
 
   // Load using efficient 128 vector load instruction
-  __device__ __forceinline__ GradientPairPrecise LoadGpair(const GradientPairInt64 *ptr) {
+  __device__ __forceinline__ GradientPairInt64 LoadGpair(const GradientPairInt64 *ptr) {
     float4 tmp = *reinterpret_cast<const float4 *>(ptr);
-    auto gpair_int = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
-    static_assert(sizeof(decltype(gpair_int)) == sizeof(float4),
+    auto gpair = *reinterpret_cast<const GradientPairInt64 *>(&tmp);
+    static_assert(sizeof(decltype(gpair)) == sizeof(float4),
                   "Vector type size does not match gradient pair size.");
-    return rounding.ToFloatingPoint(gpair_int);
+    return gpair;
   }
 
   __device__ __forceinline__ void Numerical(DeviceSplitCandidate *__restrict__ best_split) {
     for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
       bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
-      GradientPairPrecise bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
-                                              : GradientPairPrecise();
+      GradientPairInt64 bin = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
+                                            : GradientPairInt64();
       BlockScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
       // Whether the gradient of missing values is put to the left side.
       bool missing_left = true;
       float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
-                                                     evaluator, missing_left)
+                                                     evaluator, missing_left, rounding)
                                  : kNullGain;
 
       // Find thread with best gain
       auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
       // This reduce result is only valid in thread 0
@@ -132,10 +135,10 @@ class EvaluateSplitAgent {
         int split_gidx = (scan_begin + threadIdx.x) - 1;
         float fvalue =
             split_gidx < static_cast<int>(gidx_begin) ? min_fvalue : feature_values[split_gidx];
-        GradientPairPrecise left = missing_left ? bin + missing : bin;
-        GradientPairPrecise right = parent_sum - left;
+        GradientPairInt64 left = missing_left ? bin + missing : bin;
+        GradientPairInt64 right = parent_sum - left;
         best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
-                           false, param);
+                           false, param, rounding);
       }
     }
   }
@@ -145,12 +148,12 @@
       bool thread_active = (scan_begin + threadIdx.x) < gidx_end;
 
       auto rest = thread_active ? LoadGpair(node_histogram + scan_begin + threadIdx.x)
-                                : GradientPairPrecise();
-      GradientPairPrecise bin = parent_sum - rest - missing;
+                                : GradientPairInt64();
+      GradientPairInt64 bin = parent_sum - rest - missing;
       // Whether the gradient of missing values is put to the left side.
       bool missing_left = true;
       float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
-                                                     evaluator, missing_left)
+                                                     evaluator, missing_left, rounding)
                                  : kNullGain;
 
       // Find thread with best gain
@@ -162,10 +165,10 @@ class EvaluateSplitAgent {
       if (threadIdx.x == best_thread) {
         int32_t split_gidx = (scan_begin + threadIdx.x);
         float fvalue = feature_values[split_gidx];
-        GradientPairPrecise left = missing_left ? bin + missing : bin;
-        GradientPairPrecise right = parent_sum - left;
+        GradientPairInt64 left = missing_left ? bin + missing : bin;
+        GradientPairInt64 right = parent_sum - left;
         best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir,
-                              static_cast<bst_cat_t>(fvalue), fidx, left, right, param);
+                              static_cast<bst_cat_t>(fvalue), fidx, left, right, param, rounding);
       }
     }
   }
@@ -174,11 +177,13 @@
    */
   __device__ __forceinline__ void PartitionUpdate(bst_bin_t scan_begin, bool thread_active,
                                                   bool missing_left, bst_bin_t it,
-                                                  GradientPairPrecise const &left_sum,
-                                                  GradientPairPrecise const &right_sum,
+                                                  GradientPairInt64 const &left_sum,
+                                                  GradientPairInt64 const &right_sum,
                                                   DeviceSplitCandidate *__restrict__ best_split) {
-    auto gain =
-        thread_active ? evaluator.CalcSplitGain(param, nidx, fidx, left_sum, right_sum) : kNullGain;
+    auto gain = thread_active
+                    ? evaluator.CalcSplitGain(param, nidx, fidx, rounding.ToFloatingPoint(left_sum),
+                                              rounding.ToFloatingPoint(right_sum))
+                    : kNullGain;
 
     // Find thread with best gain
     auto best = MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
@@ -191,7 +196,7 @@
       // index of best threshold inside a feature.
       auto best_thresh = it - gidx_begin;
       best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left_sum,
-                            right_sum, param);
+                            right_sum, param, rounding);
     }
   }
   /**
@@ -213,28 +218,28 @@
       bool thread_active = it < it_end;
 
       auto right_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
-                                     : GradientPairPrecise();
+                                     : GradientPairInt64();
       // No min value for cat feature, use inclusive scan.
       BlockScanT(temp_storage->scan).InclusiveSum(right_sum, right_sum, prefix_op);
-      GradientPairPrecise left_sum = parent_sum - right_sum;
+      GradientPairInt64 left_sum = parent_sum - right_sum;
 
       PartitionUpdate(scan_begin, thread_active, true, it, left_sum, right_sum, best_split);
     }
 
     // backward
     it_begin = gidx_end - 1;
     it_end = it_begin - n_bins + 1;
-    prefix_op = SumCallbackOp<GradientPairPrecise>{};  // reset
+    prefix_op = SumCallbackOp<GradientPairInt64>{};  // reset
 
     for (bst_bin_t scan_begin = it_begin; scan_begin > it_end; scan_begin -= kBlockSize) {
       auto it = scan_begin - static_cast<bst_bin_t>(threadIdx.x);
       bool thread_active = it > it_end;
 
       auto left_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
-                                    : GradientPairPrecise();
+                                    : GradientPairInt64();
       // No min value for cat feature, use inclusive scan.
       BlockScanT(temp_storage->scan).InclusiveSum(left_sum, left_sum, prefix_op);
-      GradientPairPrecise right_sum = parent_sum - left_sum;
+      GradientPairInt64 right_sum = parent_sum - left_sum;
 
       PartitionUpdate(scan_begin, thread_active, false, it, left_sum, right_sum, best_split);
     }
@@ -399,22 +404,30 @@ void GPUHistEvaluator::EvaluateSplits(
     auto const input = d_inputs[i];
     auto &split = out_splits[i];
     // Subtract parent gain here
-    // As it is constant, this is more efficient than doing it during every split evaluation
-    float parent_gain = CalcGain(shared_inputs.param, input.parent_sum);
+    // As it is constant, this is more efficient than doing it during every
+    // split evaluation
+    float parent_gain =
+        CalcGain(shared_inputs.param,
+                 shared_inputs.rounding.ToFloatingPoint(input.parent_sum));
     split.loss_chg -= parent_gain;
     auto fidx = out_splits[i].findex;
 
     if (split.is_cat) {
       SetCategoricalSplit(shared_inputs, d_sorted_idx, fidx, i,
-                          device_cats_accessor.GetNodeCatStorage(input.nidx), &out_splits[i]);
+                          device_cats_accessor.GetNodeCatStorage(input.nidx),
+                          &out_splits[i]);
     }
 
-    float base_weight = evaluator.CalcWeight(input.nidx, shared_inputs.param,
-                                             GradStats{split.left_sum + split.right_sum});
-    float left_weight =
-        evaluator.CalcWeight(input.nidx, shared_inputs.param, GradStats{split.left_sum});
-    float right_weight =
-        evaluator.CalcWeight(input.nidx, shared_inputs.param, GradStats{split.right_sum});
+    float base_weight =
+        evaluator.CalcWeight(input.nidx, shared_inputs.param,
+                             shared_inputs.rounding.ToFloatingPoint(
+                                 split.left_sum + split.right_sum));
+    float left_weight = evaluator.CalcWeight(
+        input.nidx, shared_inputs.param,
+        shared_inputs.rounding.ToFloatingPoint(split.left_sum));
+    float right_weight = evaluator.CalcWeight(
+        input.nidx, shared_inputs.param,
+        shared_inputs.rounding.ToFloatingPoint(split.right_sum));
 
     d_entries[i] = GPUExpandEntry{input.nidx, input.depth, out_splits[i],
                                   base_weight, left_weight, right_weight};
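The pattern throughout this evaluator: scans and sums stay in GradientPairInt64 end to end, and ToFloatingPoint is applied only at the boundary where a gain or weight is actually computed. A simplified, self-contained sketch of that round-trip (names and the gain formula are illustrative stand-ins, not XGBoost's implementation, which also handles monotone constraints, regularisation, and the parent-gain subtraction shown above):

```cpp
#include <cstdint>

// Illustrative stand-ins, not XGBoost's types: the point is the shape of
// the data flow, where sums stay integer and floats appear only in gains.
struct PairI64 { std::int64_t grad = 0, hess = 0; };
struct PairF64 { double grad = 0.0, hess = 0.0; };

struct QuantiserSketch {
  double g_step, h_step;  // real value of one integer unit
  PairF64 ToFloatingPoint(const PairI64 &p) const {
    return {p.grad * g_step, p.hess * h_step};
  }
};

// Simplified split gain: the subtraction right = parent - left is exact in
// integer space; conversion to floating point happens only in the final
// formula, so the scan order cannot change which split wins.
double SplitGain(const QuantiserSketch &q, const PairI64 &left,
                 const PairI64 &parent, double lambda) {
  PairI64 right{parent.grad - left.grad, parent.hess - left.hess};
  PairF64 l = q.ToFloatingPoint(left);
  PairF64 r = q.ToFloatingPoint(right);
  auto half = [&](const PairF64 &s) {
    return s.grad * s.grad / (s.hess + lambda);
  };
  return half(l) + half(r);
}
```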
5 changes: 3 additions & 2 deletions src/tree/gpu_hist/evaluate_splits.cuh
@@ -23,19 +23,20 @@ namespace tree {
 struct EvaluateSplitInputs {
   int nidx;
   int depth;
-  GradientPairPrecise parent_sum;
+  GradientPairInt64 parent_sum;
   common::Span<const bst_feature_t> feature_set;
   common::Span<const GradientPairInt64> gradient_histogram;
 };
 
 // Inputs necessary for all nodes
 struct EvaluateSplitSharedInputs {
   GPUTrainingParam param;
-  GradientQuantizer rounding;
+  GradientQuantiser rounding;
   common::Span<FeatureType const> feature_types;
   common::Span<const uint32_t> feature_segments;
   common::Span<const float> feature_values;
   common::Span<const float> min_fvalue;
+  bool is_dense;
   XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; }
   __device__ auto FeatureBins(bst_feature_t fidx) const {
     return feature_segments[fidx + 1] - feature_segments[fidx];
2 changes: 1 addition & 1 deletion src/tree/gpu_hist/expand_entry.cuh
@@ -27,7 +27,7 @@ struct GPUExpandEntry {
         left_weight{left}, right_weight{right} {}
   bool IsValid(const TrainParam& param, int num_leaves) const {
     if (split.loss_chg <= kRtEps) return false;
-    if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
+    if (split.left_sum.GetQuantisedHess() == 0 || split.right_sum.GetQuantisedHess() == 0) {
       return false;
     }
     if (split.loss_chg < param.min_split_loss) {
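Switching IsValid to GetQuantisedHess makes the empty-child test an exact integer comparison. With floating-point sums, the analogous `== 0` test is order-sensitive, which is part of the motivation for the whole patch. A tiny standalone illustration (not XGBoost code):

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Floating-point accumulation leaves order-dependent residue, so an
  // exact `== 0.0` emptiness test is fragile:
  double h = 0.1 + 0.2 - 0.3;   // 5.55e-17, not exactly 0.0
  // Quantised (integer) hessians make the same test exact:
  std::int64_t hq = 1 + 2 - 3;  // exactly 0
  std::cout << (h == 0.0) << " " << (hq == 0) << "\n";  // prints "0 1"
}
```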
10 changes: 5 additions & 5 deletions src/tree/gpu_hist/histogram.cu
@@ -72,7 +72,7 @@ struct Clip : public thrust::unary_function<GradientPair, Pair> {
   }
 };
 
-GradientQuantizer::GradientQuantizer(common::Span<GradientPair const> gpair) {
+GradientQuantiser::GradientQuantiser(common::Span<GradientPair const> gpair) {
   using GradientSumT = GradientPairPrecise;
   using T = typename GradientSumT::ValueT;
   dh::XGBCachingDeviceAllocator<char> alloc;
@@ -153,14 +153,14 @@ class HistogramAgent {
   const EllpackDeviceAccessor& matrix_;
   const int feature_stride_;
   const std::size_t n_elements_;
-  const GradientQuantizer& rounding_;
+  const GradientQuantiser& rounding_;
 
  public:
   __device__ HistogramAgent(GradientPairInt64* smem_arr,
                             GradientPairInt64* __restrict__ d_node_hist, const FeatureGroup& group,
                             const EllpackDeviceAccessor& matrix,
                             common::Span<const RowPartitioner::RowIndexT> d_ridx,
-                            const GradientQuantizer& rounding, const GradientPair* d_gpair)
+                            const GradientQuantiser& rounding, const GradientPair* d_gpair)
       : smem_arr_(smem_arr),
         d_node_hist_(d_node_hist),
         d_ridx_(d_ridx.data()),
@@ -254,7 +254,7 @@ __global__ void __launch_bounds__(kBlockThreads)
     common::Span<const RowPartitioner::RowIndexT> d_ridx,
     GradientPairInt64* __restrict__ d_node_hist,
     const GradientPair* __restrict__ d_gpair,
-    GradientQuantizer const rounding) {
+    GradientQuantiser const rounding) {
   extern __shared__ char smem[];
   const FeatureGroup group = feature_groups[blockIdx.y];
   auto smem_arr = reinterpret_cast<GradientPairInt64*>(smem);
@@ -272,7 +272,7 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
                             common::Span<GradientPair const> gpair,
                             common::Span<const uint32_t> d_ridx,
                             common::Span<GradientPairInt64> histogram,
-                            GradientQuantizer rounding, bool force_global_memory) {
+                            GradientQuantiser rounding, bool force_global_memory) {
   // decide whether to use shared memory
   int device = 0;
   dh::safe_cuda(cudaGetDevice(&device));
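The renamed GradientQuantiser constructor is also where the fixed-point scale is chosen; its body is elided in this diff. The key constraint is that quantised gradients summed over every row must not overflow int64, which suggests sizing one integer unit against the total absolute gradient mass. A host-side sketch of that idea, under the stated assumption (the real constructor runs a device reduction using the Clip functor shown at the top of this file; names here are hypothetical):

```cpp
#include <cmath>
#include <cstdint>
#include <utility>
#include <vector>

// Hypothetical sketch of choosing a fixed-point scale, not XGBoost's code.
// Invariant targeted: the quantised values summed over all rows must not
// overflow int64, so one integer unit is sized by the total |gradient| mass.
struct QuantiserScaleSketch {
  double g_to_fixed, h_to_fixed;  // multipliers from float to integer units

  explicit QuantiserScaleSketch(
      const std::vector<std::pair<float, float>> &gpair) {
    double g_abs_sum = 0.0, h_abs_sum = 0.0;
    for (const auto &[g, h] : gpair) {
      g_abs_sum += std::fabs(static_cast<double>(g));
      h_abs_sum += std::fabs(static_cast<double>(h));
    }
    // Spread the worst-case total over ~2^62 integer units for headroom.
    const double kRange = static_cast<double>(std::int64_t{1} << 62);
    g_to_fixed = g_abs_sum > 0.0 ? kRange / g_abs_sum : 1.0;
    h_to_fixed = h_abs_sum > 0.0 ? kRange / h_abs_sum : 1.0;
  }

  std::int64_t QuantiseGrad(float g) const {
    return static_cast<std::int64_t>(std::rint(g * g_to_fixed));
  }
};
```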
