Skip to content

Commit

Permalink
Add max_cat_threshold to GPU and handle missing cat values. (#8212)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 6, 2022
1 parent 441ffc0 commit b5eb36f
Show file tree
Hide file tree
Showing 10 changed files with 545 additions and 121 deletions.
142 changes: 83 additions & 59 deletions src/tree/gpu_hist/evaluate_splits.cu
Expand Up @@ -43,9 +43,9 @@ class EvaluateSplitAgent {
public:
using ArgMaxT = cub::KeyValuePair<int, float>;
using BlockScanT = cub::BlockScan<GradientPairPrecise, kBlockSize>;
using MaxReduceT =
cub::WarpReduce<ArgMaxT>;
using MaxReduceT = cub::WarpReduce<ArgMaxT>;
using SumReduceT = cub::WarpReduce<GradientPairPrecise>;

struct TempStorage {
typename BlockScanT::TempStorage scan;
typename MaxReduceT::TempStorage max_reduce;
Expand Down Expand Up @@ -159,49 +159,81 @@ class EvaluateSplitAgent {
if (threadIdx.x == best_thread) {
int32_t split_gidx = (scan_begin + threadIdx.x);
float fvalue = feature_values[split_gidx];
GradientPairPrecise left =
missing_left ? bin + missing : bin;
GradientPairPrecise left = missing_left ? bin + missing : bin;
GradientPairPrecise right = parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right,
true, param);
best_split->UpdateCat(gain, missing_left ? kLeftDir : kRightDir,
static_cast<bst_cat_t>(fvalue), fidx, left, right, param);
}
}
}
/**
 * \brief Select the thread with the highest split gain and let it record its candidate.
 *
 * Every calling thread takes part in the reduction and the shuffle below; only the
 * winning thread writes to best_split.
 *
 * \param scan_begin    Start of the current scan window (not read by this function).
 * \param thread_active Whether this thread carries a real candidate; inactive threads
 *                      enter the reduction with kNullGain.
 * \param missing_left  Chooses kLeftDir (true) or kRightDir (false) as the direction
 *                      stored with the split.
 * \param it            Bin index handled by this thread; recorded relative to gidx_begin.
 * \param left_sum      Gradient sum forwarded as the left-side statistic.
 * \param right_sum     Gradient sum forwarded as the right-side statistic.
 * \param best_split    Candidate that the winning thread updates through UpdateCat.
 */
__device__ __forceinline__ void PartitionUpdate(bst_bin_t scan_begin, bool thread_active,
                                                bool missing_left, bst_bin_t it,
                                                GradientPairPrecise const &left_sum,
                                                GradientPairPrecise const &right_sum,
                                                DeviceSplitCandidate *__restrict__ best_split) {
  // Threads without a candidate take part with the null gain.
  auto split_gain =
      !thread_active ? kNullGain : evaluator.CalcSplitGain(param, nidx, fidx, left_sum, right_sum);

  // Arg-max over (thread index, gain) pairs.
  MaxReduceT reducer(temp_storage->max_reduce);
  auto winner = reducer.Reduce({threadIdx.x, split_gain}, cub::ArgMax());
  // The reduced pair is only valid in thread 0, so its key is broadcast to the
  // rest of the warp.
  auto winning_thread = __shfl_sync(0xffffffff, winner.key, 0);

  // Everyone except the winner is done here.
  if (threadIdx.x != winning_thread) {
    return;
  }

  assert(thread_active);
  auto direction = missing_left ? kLeftDir : kRightDir;
  // Index of the best threshold inside the feature.
  auto thresh_in_feature = it - gidx_begin;
  best_split->UpdateCat(split_gain, direction, thresh_in_feature, fidx, left_sum, right_sum,
                        param);
}
/**
* \brief Partition-based split for categorical feature.
*/
__device__ __forceinline__ void Partition(DeviceSplitCandidate *__restrict__ best_split,
bst_feature_t * __restrict__ sorted_idx,
std::size_t offset) {
for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += kBlockSize) {
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;

auto rest = thread_active
? LoadGpair(node_histogram + sorted_idx[scan_begin + threadIdx.x] - offset)
: GradientPairPrecise();
common::Span<bst_feature_t> sorted_idx,
std::size_t node_offset,
GPUTrainingParam const &param) {
bst_bin_t n_bins_feature = gidx_end - gidx_begin;
auto n_bins = std::min(param.max_cat_threshold, n_bins_feature);

bst_bin_t it_begin = gidx_begin;
bst_bin_t it_end = it_begin + n_bins - 1;

// forward
for (bst_bin_t scan_begin = it_begin; scan_begin < it_end; scan_begin += kBlockSize) {
auto it = scan_begin + static_cast<bst_bin_t>(threadIdx.x);
bool thread_active = it < it_end;

auto right_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
: GradientPairPrecise();
// No min value for cat feature, use inclusive scan.
BlockScanT(temp_storage->scan).InclusiveSum(rest, rest, prefix_op);
GradientPairPrecise bin = parent_sum - rest - missing;
BlockScanT(temp_storage->scan).InclusiveSum(right_sum, right_sum, prefix_op);
GradientPairPrecise left_sum = parent_sum - right_sum;

// Whether the gradient of missing values is put to the left side.
bool missing_left = true;
float gain = thread_active ? LossChangeMissing(bin, missing, parent_sum, param, nidx, fidx,
evaluator, missing_left)
: kNullGain;
PartitionUpdate(scan_begin, thread_active, true, it, left_sum, right_sum, best_split);
}

// backward
it_begin = gidx_end - 1;
it_end = it_begin - n_bins + 1;
prefix_op = SumCallbackOp<GradientPairPrecise>{}; // reset

// Find thread with best gain
auto best =
MaxReduceT(temp_storage->max_reduce).Reduce({threadIdx.x, gain}, cub::ArgMax());
// This reduce result is only valid in thread 0
// broadcast to the rest of the warp
auto best_thread = __shfl_sync(0xffffffff, best.key, 0);
// Best thread updates the split
if (threadIdx.x == best_thread) {
GradientPairPrecise left = missing_left ? bin + missing : bin;
GradientPairPrecise right = parent_sum - left;
auto best_thresh =
threadIdx.x + (scan_begin - gidx_begin); // index of best threshold inside a feature.
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left,
right, true, param);
}
for (bst_bin_t scan_begin = it_begin; scan_begin > it_end; scan_begin -= kBlockSize) {
auto it = scan_begin - static_cast<bst_bin_t>(threadIdx.x);
bool thread_active = it > it_end;

auto left_sum = thread_active ? LoadGpair(node_histogram + sorted_idx[it] - node_offset)
: GradientPairPrecise();
// No min value for cat feature, use inclusive scan.
BlockScanT(temp_storage->scan).InclusiveSum(left_sum, left_sum, prefix_op);
GradientPairPrecise right_sum = parent_sum - left_sum;

PartitionUpdate(scan_begin, thread_active, false, it, left_sum, right_sum, best_split);
}
}
};
Expand Down Expand Up @@ -242,7 +274,7 @@ __global__ __launch_bounds__(kBlockSize) void EvaluateSplitsKernel(
auto total_bins = shared_inputs.feature_values.size();
size_t offset = total_bins * input_idx;
auto node_sorted_idx = sorted_idx.subspan(offset, total_bins);
agent.Partition(&best_split, node_sorted_idx.data(), offset);
agent.Partition(&best_split, node_sorted_idx, offset, shared_inputs.param);
}
} else {
agent.Numerical(&best_split);
Expand Down Expand Up @@ -273,36 +305,28 @@ __device__ void SetCategoricalSplit(const EvaluateSplitSharedInputs &shared_inpu

// Simple case for one hot split
if (common::UseOneHot(shared_inputs.FeatureBins(fidx), shared_inputs.param.max_cat_to_onehot)) {
out_split.split_cats.Set(common::AsCat(out_split.fvalue));
out_split.split_cats.Set(common::AsCat(out_split.thresh));
return;
}

// partition-based split
auto node_sorted_idx = d_sorted_idx.subspan(shared_inputs.feature_values.size() * input_idx,
shared_inputs.feature_values.size());
size_t node_offset = input_idx * shared_inputs.feature_values.size();
auto best_thresh = out_split.PopBestThresh();
auto const best_thresh = out_split.thresh;
if (best_thresh == -1) {
return;
}
auto f_sorted_idx = node_sorted_idx.subspan(shared_inputs.feature_segments[fidx],
shared_inputs.FeatureBins(fidx));
if (out_split.dir != kLeftDir) {
// forward, missing on right
auto beg = dh::tcbegin(f_sorted_idx);
// Don't put all the categories into one side
auto boundary = std::min(static_cast<size_t>((best_thresh + 1)), (f_sorted_idx.size() - 1));
boundary = std::max(boundary, static_cast<size_t>(1ul));
auto end = beg + boundary;
thrust::for_each(thrust::seq, beg, end, [&](auto c) {
auto cat = shared_inputs.feature_values[c - node_offset];
assert(!out_split.split_cats.Check(cat) && "already set");
out_split.SetCat(cat);
});
} else {
assert((f_sorted_idx.size() - best_thresh + 1) != 0 && " == 0");
thrust::for_each(thrust::seq, dh::tcrbegin(f_sorted_idx),
dh::tcrbegin(f_sorted_idx) + (f_sorted_idx.size() - best_thresh), [&](auto c) {
auto cat = shared_inputs.feature_values[c - node_offset];
out_split.SetCat(cat);
});
}
bool forward = out_split.dir == kLeftDir;
bst_bin_t partition = forward ? best_thresh + 1 : best_thresh;
auto beg = dh::tcbegin(f_sorted_idx);
assert(partition > 0 && "Invalid partition.");
thrust::for_each(thrust::seq, beg, beg + partition, [&](size_t c) {
auto cat = shared_inputs.feature_values[c - node_offset];
out_split.SetCat(cat);
});
}

void GPUHistEvaluator::LaunchEvaluateSplits(
Expand Down
3 changes: 2 additions & 1 deletion src/tree/gpu_hist/evaluate_splits.cuh
Expand Up @@ -141,7 +141,8 @@ class GPUHistEvaluator {
*/
common::Span<CatST const> GetHostNodeCats(bst_node_t nidx) const {
copy_stream_.View().Sync();
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(nidx * node_categorical_storage_size_, node_categorical_storage_size_);
auto cats_out = common::Span<CatST const>{h_split_cats_}.subspan(
nidx * node_categorical_storage_size_, node_categorical_storage_size_);
return cats_out;
}
/**
Expand Down
31 changes: 18 additions & 13 deletions src/tree/hist/evaluate_splits.h
Expand Up @@ -143,8 +143,12 @@ class HistEvaluator {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");

auto const &cut_ptr = cut.Ptrs();
auto const &cut_val = cut.Values();
auto const &parent = snode_[nidx];
bst_bin_t n_bins_feature{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};

bst_bin_t f_begin = cut_ptr[fidx];
bst_bin_t f_end = cut_ptr[fidx + 1];
bst_bin_t n_bins_feature{f_end - f_begin};
auto n_bins = std::min(param_.max_cat_threshold, n_bins_feature);

// statistics on both sides of split
Expand All @@ -153,19 +157,18 @@ class HistEvaluator {
// best split so far
SplitEntry best;

auto f_hist = hist.subspan(cut_ptr[fidx], n_bins_feature);
bst_bin_t ibegin, iend;
bst_bin_t f_begin = cut_ptr[fidx];
auto f_hist = hist.subspan(f_begin, n_bins_feature);
bst_bin_t it_begin, it_end;
if (d_step > 0) {
ibegin = f_begin;
iend = ibegin + n_bins - 1;
it_begin = f_begin;
it_end = it_begin + n_bins - 1;
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = ibegin - n_bins + 1;
it_begin = f_end - 1;
it_end = it_begin - n_bins + 1;
}

bst_bin_t best_thresh{-1};
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
for (bst_bin_t i = it_begin; i != it_end; i += d_step) {
auto j = i - f_begin; // index local to current feature
if (d_step == 1) {
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
Expand All @@ -187,13 +190,15 @@ class HistEvaluator {
}

if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature + 1);
auto n = common::CatBitField::ComputeStorageSize(n_bins_feature);
best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : (best_thresh - f_begin);
bst_bin_t partition = d_step == 1 ? (best_thresh - it_begin + 1) : (best_thresh - f_begin);
CHECK_GT(partition, 0);
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
[&](size_t c) { cat_bits.Set(c); });
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition, [&](size_t c) {
auto cat = cut_val[c + f_begin];
cat_bits.Set(cat);
});
}

p_best->Update(best);
Expand Down
44 changes: 27 additions & 17 deletions src/tree/updater_gpu_common.cuh
Expand Up @@ -29,6 +29,7 @@ struct GPUTrainingParam {
float max_delta_step;
float learning_rate;
uint32_t max_cat_to_onehot;
bst_bin_t max_cat_threshold;

GPUTrainingParam() = default;

Expand All @@ -38,7 +39,8 @@ struct GPUTrainingParam {
reg_alpha(param.reg_alpha),
max_delta_step(param.max_delta_step),
learning_rate{param.learning_rate},
max_cat_to_onehot{param.max_cat_to_onehot} {}
max_cat_to_onehot{param.max_cat_to_onehot},
max_cat_threshold{param.max_cat_threshold} {}
};

/**
Expand All @@ -57,6 +59,9 @@ struct DeviceSplitCandidate {
DefaultDirection dir {kLeftDir};
int findex {-1};
float fvalue {0};
// categorical split, either it's the split category for OHE or the threshold for partition-based
// split.
bst_cat_t thresh{-1};

common::CatBitField split_cats;
bool is_cat { false };
Expand All @@ -75,22 +80,6 @@ struct DeviceSplitCandidate {
*this = other;
}
}
/**
* \brief The largest encoded category in the split bitset
*/
bst_cat_t MaxCat() const {
// Reuse the fvalue for categorical values.
return static_cast<bst_cat_t>(fvalue);
}
/**
* \brief Return the best threshold for cat split, reset the value after return.
*/
XGBOOST_DEVICE size_t PopBestThresh() {
// fvalue is also being used for storing the threshold for categorical split
auto best_thresh = static_cast<size_t>(this->fvalue);
this->fvalue = 0;
return best_thresh;
}

template <typename T>
XGBOOST_DEVICE void SetCat(T c) {
Expand All @@ -116,13 +105,34 @@ struct DeviceSplitCandidate {
findex = findex_in;
}
}

/**
 * \brief Record a categorical split candidate when it improves on the stored one.
 *
 * The candidate is accepted only if its loss change is strictly greater than the
 * stored loss_chg and both child hessian sums reach param.min_child_weight.
 * On acceptance fvalue is set to NaN and the split is flagged as categorical.
 *
 * \param loss_chg_in  Loss change of the candidate.
 * \param dir_in       Default direction stored with the candidate.
 * \param thresh_in    Value stored into thresh.
 * \param findex_in    Feature index of the candidate.
 * \param left_sum_in  Gradient statistic of the left child.
 * \param right_sum_in Gradient statistic of the right child.
 * \param param        Training parameters; only min_child_weight is read here.
 */
XGBOOST_DEVICE void UpdateCat(float loss_chg_in, DefaultDirection dir_in, bst_cat_t thresh_in,
                              bst_feature_t findex_in, GradientPairPrecise left_sum_in,
                              GradientPairPrecise right_sum_in, GPUTrainingParam const& param) {
  bool const improves = loss_chg_in > loss_chg;
  bool const left_ok = left_sum_in.GetHess() >= param.min_child_weight;
  bool const right_ok = right_sum_in.GetHess() >= param.min_child_weight;
  if (!(improves && left_ok && right_ok)) {
    return;
  }

  // Identity of the split.
  findex = findex_in;
  thresh = thresh_in;
  dir = dir_in;
  is_cat = true;
  // A categorical split carries no numerical split value.
  fvalue = std::numeric_limits<float>::quiet_NaN();

  // Statistics of the split.
  loss_chg = loss_chg_in;
  left_sum = left_sum_in;
  right_sum = right_sum_in;
}

/**
 * \brief Whether the stored loss change is strictly positive.
 */
XGBOOST_DEVICE bool IsValid() const {
  bool const has_positive_gain = loss_chg > 0.0f;
  return has_positive_gain;
}

friend std::ostream& operator<<(std::ostream& os, DeviceSplitCandidate const& c) {
os << "loss_chg:" << c.loss_chg << ", "
<< "dir: " << c.dir << ", "
<< "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", "
<< "thresh: " << c.thresh << ", "
<< "is_cat: " << c.is_cat << ", "
<< "left sum: " << c.left_sum << ", "
<< "right sum: " << c.right_sum << std::endl;
Expand Down
12 changes: 7 additions & 5 deletions src/tree/updater_gpu_hist.cu
Expand Up @@ -601,13 +601,14 @@ struct GPUHistMakerDevice {

auto is_cat = candidate.split.is_cat;
if (is_cat) {
CHECK_LT(candidate.split.fvalue, std::numeric_limits<bst_cat_t>::max())
<< "Categorical feature value too large.";
std::vector<uint32_t> split_cats;
// should be set to nan in evaluation split.
CHECK(common::CheckNAN(candidate.split.fvalue));
std::vector<common::CatBitField::value_type> split_cats;

CHECK_GT(candidate.split.split_cats.Bits().size(), 0);
auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid);
auto max_cat = candidate.split.MaxCat();
split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0);
auto n_bins_feature = page->Cuts().FeatureBins(candidate.split.findex);
split_cats.resize(common::CatBitField::ComputeStorageSize(n_bins_feature), 0);
CHECK_LE(split_cats.size(), h_cats.size());
std::copy(h_cats.data(), h_cats.data() + split_cats.size(), split_cats.data());

Expand All @@ -616,6 +617,7 @@ struct GPUHistMakerDevice {
base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else {
CHECK(!common::CheckNAN(candidate.split.fvalue));
tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue,
candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight,
candidate.split.loss_chg, parent_sum.GetHess(),
Expand Down

0 comments on commit b5eb36f

Please sign in to comment.