Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add categorical data support for gpu_hist updater. #6164

Merged
merged 1 commit into from Sep 29, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/common/device_helpers.cuh
Expand Up @@ -536,6 +536,21 @@ void CopyDeviceSpanToVector(std::vector<T> *dst, xgboost::common::Span<const T>
cudaMemcpyDeviceToHost));
}

template <class HContainer, class DContainer>
void CopyToD(HContainer const &h, DContainer *d) {
if (h.empty()) {
d->clear();
return;
}
d->resize(h.size());
using HVT = std::remove_cv_t<typename HContainer::value_type>;
using DVT = std::remove_cv_t<typename DContainer::value_type>;
static_assert(std::is_same<HVT, DVT>::value,
"Host and device containers must have same value type.");
dh::safe_cuda(cudaMemcpyAsync(d->data().get(), h.data(), h.size() * sizeof(HVT),
cudaMemcpyHostToDevice));
}

// Keep track of pinned memory allocation
struct PinnedMemory {
void *temp_storage{nullptr};
Expand Down
2 changes: 1 addition & 1 deletion src/common/hist_util.cu
Expand Up @@ -178,7 +178,7 @@ void ProcessBatch(int device, MetaInfo const &info, const SparsePage &page,
dh::XGBCachingDeviceAllocator<char> alloc;
const auto& host_data = page.data.ConstHostVector();
dh::device_vector<Entry> sorted_entries(host_data.begin() + begin,
host_data.begin() + end);
host_data.begin() + end);
thrust::sort(thrust::cuda::par(alloc), sorted_entries.begin(),
sorted_entries.end(), detail::EntryCompareOp());

Expand Down
127 changes: 98 additions & 29 deletions src/tree/gpu_hist/evaluate_splits.cu
@@ -1,8 +1,9 @@
/*!
* Copyright 2020 by XGBoost Contributors
*/
#include "evaluate_splits.cuh"
#include <limits>
#include "evaluate_splits.cuh"
#include "../../common/categorical.h"

namespace xgboost {
namespace tree {
Expand Down Expand Up @@ -66,13 +67,84 @@ ReduceFeature(common::Span<const GradientSumT> feature_histogram,
if (threadIdx.x == 0) {
shared_sum = local_sum;
}
__syncthreads();
cub::CTA_SYNC();
return shared_sum;
}

template <typename GradientSumT, typename TempStorageT> struct OneHotBin {
GradientSumT __device__ operator()(
bool thread_active, uint32_t scan_begin,
SumCallbackOp<GradientSumT>*,
GradientSumT const &missing,
EvaluateSplitInputs<GradientSumT> const &inputs, TempStorageT *) {
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
auto rest = inputs.parent_sum - bin - missing;
return rest;
}
};

template <typename GradientSumT>
struct UpdateOneHot {
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
bst_feature_t fidx, GradientSumT const &missing,
GradientSumT const &bin,
EvaluateSplitInputs<GradientSumT> const &inputs,
DeviceSplitCandidate *best_split) {
int split_gidx = (scan_begin + threadIdx.x);
float fvalue = inputs.feature_values[split_gidx];
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx,
GradientPair(left), GradientPair(right), true,
inputs.param);
}
};

template <typename GradientSumT, typename TempStorageT, typename ScanT>
struct NumericBin {
GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin,
SumCallbackOp<GradientSumT>* prefix_callback,
GradientSumT const &missing,
EvaluateSplitInputs<GradientSumT> inputs,
TempStorageT *temp_storage) {
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback);
return bin;
}
};

template <typename GradientSumT>
struct UpdateNumeric {
void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain,
bst_feature_t fidx, GradientSumT const &missing,
GradientSumT const &bin,
EvaluateSplitInputs<GradientSumT> const &inputs,
DeviceSplitCandidate *best_split) {
// Use pointer from cut to indicate begin and end of bins for each feature.
uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = inputs.min_fvalue[fidx];
} else {
fvalue = inputs.feature_values[split_gidx];
}
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
fidx, GradientPair(left), GradientPair(right),
false, inputs.param);
}
};

/*! \brief Find the thread with best gain. */
template <int BLOCK_THREADS, typename ReduceT, typename ScanT,
typename MaxReduceT, typename TempStorageT, typename GradientSumT>
typename MaxReduceT, typename TempStorageT, typename GradientSumT,
typename BinFn, typename UpdateFn>
__device__ void EvaluateFeature(
int fidx, EvaluateSplitInputs<GradientSumT> inputs,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
Expand All @@ -83,12 +155,14 @@ __device__ void EvaluateFeature(
uint32_t gidx_begin = inputs.feature_segments[fidx]; // begining bin
uint32_t gidx_end =
inputs.feature_segments[fidx + 1]; // end bin for i^th feature
auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin);
auto bin_fn = BinFn();
auto update_fn = UpdateFn();

// Sum histogram bins for current feature
GradientSumT const feature_sum =
ReduceFeature<BLOCK_THREADS, ReduceT, TempStorageT, GradientSumT>(
inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin),
temp_storage);
feature_hist, temp_storage);

GradientSumT const missing = inputs.parent_sum - feature_sum;
float const null_gain = -std::numeric_limits<bst_float>::infinity();
Expand All @@ -97,12 +171,7 @@ __device__ void EvaluateFeature(
for (int scan_begin = gidx_begin; scan_begin < gidx_end;
scan_begin += BLOCK_THREADS) {
bool thread_active = (scan_begin + threadIdx.x) < gidx_end;

// Gradient value for current bin.
GradientSumT bin = thread_active
? inputs.gradient_histogram[scan_begin + threadIdx.x]
: GradientSumT();
ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op);
auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage);

// Whether the gradient of missing values is put to the left side.
bool missing_left = true;
Expand All @@ -127,24 +196,14 @@ __device__ void EvaluateFeature(
block_max = best;
}

__syncthreads();
cub::CTA_SYNC();

// Best thread updates split
if (threadIdx.x == block_max.key) {
int split_gidx = (scan_begin + threadIdx.x) - 1;
float fvalue;
if (split_gidx < static_cast<int>(gidx_begin)) {
fvalue = inputs.min_fvalue[fidx];
} else {
fvalue = inputs.feature_values[split_gidx];
}
GradientSumT left = missing_left ? bin + missing : bin;
GradientSumT right = inputs.parent_sum - left;
best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue,
fidx, GradientPair(left), GradientPair(right),
inputs.param);
update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs,
best_split);
}
__syncthreads();
cub::CTA_SYNC();
}
}

Expand Down Expand Up @@ -186,11 +245,21 @@ __global__ void EvaluateSplitsKernel(
// One block for each feature. Features are sampled, so fidx != blockIdx.x
int fidx = inputs.feature_set[is_left ? blockIdx.x
: blockIdx.x - left.feature_set.size()];
if (common::IsCat(inputs.feature_types, fidx)) {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
TempStorage, GradientSumT,
OneHotBin<GradientSumT, TempStorage>,
UpdateOneHot<GradientSumT>>(fidx, inputs, evaluator, &best_split,
&temp_storage);
} else {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT,
TempStorage, GradientSumT,
NumericBin<GradientSumT, TempStorage, BlockScanT>,
UpdateNumeric<GradientSumT>>(fidx, inputs, evaluator, &best_split,
&temp_storage);
}

EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT>(
fidx, inputs, evaluator, &best_split, &temp_storage);

__syncthreads();
cub::CTA_SYNC();

if (threadIdx.x == 0) {
// Record best loss for each feature
Expand Down
1 change: 1 addition & 0 deletions src/tree/gpu_hist/evaluate_splits.cuh
Expand Up @@ -18,6 +18,7 @@ struct EvaluateSplitInputs {
GradientSumT parent_sum;
GPUTrainingParam param;
common::Span<const bst_feature_t> feature_set;
common::Span<FeatureType const> feature_types;
common::Span<const uint32_t> feature_segments;
common::Span<const float> feature_values;
common::Span<const float> min_fvalue;
Expand Down
4 changes: 4 additions & 0 deletions src/tree/updater_gpu_common.cuh
Expand Up @@ -59,6 +59,7 @@ struct DeviceSplitCandidate {
DefaultDirection dir {kLeftDir};
int findex {-1};
float fvalue {0};
bool is_cat { false };

GradientPair left_sum;
GradientPair right_sum;
Expand All @@ -79,13 +80,15 @@ struct DeviceSplitCandidate {
float fvalue_in, int findex_in,
GradientPair left_sum_in,
GradientPair right_sum_in,
bool cat,
const GPUTrainingParam& param) {
if (loss_chg_in > loss_chg &&
left_sum_in.GetHess() >= param.min_child_weight &&
right_sum_in.GetHess() >= param.min_child_weight) {
loss_chg = loss_chg_in;
dir = dir_in;
fvalue = fvalue_in;
is_cat = cat;
left_sum = left_sum_in;
right_sum = right_sum_in;
findex = findex_in;
Expand All @@ -98,6 +101,7 @@ struct DeviceSplitCandidate {
<< "dir: " << c.dir << ", "
<< "findex: " << c.findex << ", "
<< "fvalue: " << c.fvalue << ", "
<< "is_cat: " << c.is_cat << ", "
<< "left sum: " << c.left_sum << ", "
<< "right sum: " << c.right_sum << std::endl;
return os;
Expand Down