Refactor split evaluation kernel #8073

Merged: 12 commits, Jul 21, 2022
2 changes: 1 addition & 1 deletion src/common/device_helpers.cuh
@@ -1949,7 +1949,7 @@ class LDGIterator {
const T *ptr_;

public:
-explicit LDGIterator(const T *ptr) : ptr_(ptr) {}
+XGBOOST_DEVICE explicit LDGIterator(const T *ptr) : ptr_(ptr) {}
__device__ T operator[](std::size_t idx) const {
DeviceWordT tmp[kNumWords];
static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
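For context: XGBOOST_DEVICE is xgboost's portability macro, expanding to __host__ __device__ under nvcc and to nothing in host-only builds. Adding it to the constructor lets an LDGIterator be created inside device code rather than only dereferenced there. A minimal sketch of the pattern (ReadOnlyIter and SumKernel are hypothetical names; the macro definition is paraphrased, not copied from xgboost's headers):

    #include <cstddef>

    // Paraphrased: xgboost defines XGBOOST_DEVICE along these lines.
    #if defined(__CUDACC__)
    #define XGBOOST_DEVICE __host__ __device__
    #else
    #define XGBOOST_DEVICE
    #endif

    template <typename T>
    class ReadOnlyIter {
      const T *ptr_;

     public:
      // Without XGBOOST_DEVICE the constructor is host-only, so the iterator
      // could be passed into a kernel but not constructed inside one.
      XGBOOST_DEVICE explicit ReadOnlyIter(const T *ptr) : ptr_(ptr) {}
      __device__ T operator[](std::size_t idx) const { return __ldg(ptr_ + idx); }
    };

    __global__ void SumKernel(const float *data, float *out) {
      ReadOnlyIter<float> it{data};  // requires a device-callable constructor
      *out = it[0] + it[1];
    }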
386 changes: 190 additions & 196 deletions src/tree/gpu_hist/evaluate_splits.cu

Large diffs are not rendered by default.

1 change: 0 additions & 1 deletion src/tree/gpu_hist/evaluate_splits.cuh
@@ -51,7 +51,6 @@ struct CatAccessor {
}
};

-template <typename GradientSumT>
class GPUHistEvaluator {
using CatST = common::CatBitField::value_type; // categorical storage type
// use pinned memory to stage the categories, used for sort based splits.
8 changes: 2 additions & 6 deletions src/tree/gpu_hist/evaluator.cu
@@ -14,8 +14,7 @@

namespace xgboost {
namespace tree {
-template <typename GradientSumT>
-void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
+void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts,
common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param,
int32_t device) {
@@ -68,8 +67,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
}
}

-template <typename GradientSumT>
-common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
+common::Span<bst_feature_t const> GPUHistEvaluator::SortHistogram(
common::Span<const EvaluateSplitInputs> d_inputs, EvaluateSplitSharedInputs shared_inputs,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
dh::XGBCachingDeviceAllocator<char> alloc;
@@ -128,7 +126,5 @@ common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
return dh::ToSpan(cat_sorted_idx_);
}

-template class GPUHistEvaluator<GradientPair>;
-template class GPUHistEvaluator<GradientPairPrecise>;
} // namespace tree
} // namespace xgboost
2 changes: 1 addition & 1 deletion src/tree/param.h
@@ -255,7 +255,7 @@ XGBOOST_DEVICE inline T CalcWeight(const TrainingParams &p, T sum_grad,
// calculate the cost of loss function
template <typename TrainingParams, typename T>
XGBOOST_DEVICE inline T CalcGain(const TrainingParams &p, T sum_grad, T sum_hess) {
-if (sum_hess < p.min_child_weight) {
+if (sum_hess < p.min_child_weight || sum_hess <= 0.0) {
return T(0.0);
}
if (p.max_delta_step == 0.0f) {
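Why the extra condition: min_child_weight can be set to 0, and a custom objective can produce a zero or negative hessian sum; the closed-form gain divides by sum_hess + reg_lambda, which can then be zero or negative. A standalone sketch of the guarded formula (illustration only, assuming the usual unconstrained gain G^2 / (H + lambda)):

    // Illustration only: the unconstrained closed-form gain with the new guard.
    inline double GuardedCalcGain(double sum_grad, double sum_hess,
                                  double min_child_weight, double reg_lambda) {
      // A non-positive hessian sum would make the quotient meaningless (or
      // non-finite when sum_hess == -reg_lambda), so treat the node as
      // unsplittable and report zero gain.
      if (sum_hess < min_child_weight || sum_hess <= 0.0) {
        return 0.0;
      }
      return (sum_grad * sum_grad) / (sum_hess + reg_lambda);
    }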
37 changes: 24 additions & 13 deletions src/tree/split_evaluator.h
@@ -71,11 +71,10 @@ class TreeEvaluator {
const float* upper;
bool has_constraint;

-XGBOOST_DEVICE float CalcSplitGain(const ParamT &param, bst_node_t nidx,
-bst_feature_t fidx,
-tree::GradStats const& left,
-tree::GradStats const& right) const {
-int constraint = constraints[fidx];
+template <typename GradientSumT>
+XGBOOST_DEVICE float CalcSplitGain(const ParamT& param, bst_node_t nidx, bst_feature_t fidx,
+GradientSumT const& left, GradientSumT const& right) const {
+int constraint = has_constraint ? constraints[fidx] : 0;
const float negative_infinity = -std::numeric_limits<float>::infinity();
float wleft = this->CalcWeight(nidx, param, left);
float wright = this->CalcWeight(nidx, param, right);
@@ -92,8 +91,9 @@
}
}

+template <typename GradientSumT>
XGBOOST_DEVICE float CalcWeight(bst_node_t nodeid, const ParamT &param,
-tree::GradStats const& stats) const {
+GradientSumT const& stats) const {
float w = ::xgboost::tree::CalcWeight(param, stats);
if (!has_constraint) {
return w;
@@ -118,21 +118,32 @@
return ::xgboost::tree::CalcWeight(param, stats);
}

-XGBOOST_DEVICE float
-CalcGainGivenWeight(ParamT const &p, tree::GradStats const& stats, float w) const {
+// Fast floating point division instruction on device
+XGBOOST_DEVICE float Divide(float a, float b) const {
+#ifdef __CUDA_ARCH__
+return __fdividef(a, b);
+#else
+return a / b;
+#endif
+}

Review thread on Divide():

Reviewer (Member): Can we extract this as an independent function?

Author (Member): I'm not expecting to use it anywhere else at this moment, so I think it should stay unless you have something specific in mind. A kernel needs to be heavily bottlenecked by arithmetic before this makes a difference, and I can't think of other places in xgboost.
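A self-contained sketch of the same dispatch pattern, runnable outside xgboost (FastDivide and DivideKernel are hypothetical names). __CUDA_ARCH__ is defined only while nvcc compiles device code, so the fast approximate __fdividef intrinsic is used on the GPU and exact IEEE division on the host:

    #include <cstdio>

    // Resolves to the fast approximate intrinsic during device-side
    // compilation and to exact IEEE division during host-side compilation.
    __host__ __device__ inline float FastDivide(float a, float b) {
    #ifdef __CUDA_ARCH__
      return __fdividef(a, b);
    #else
      return a / b;
    #endif
    }

    __global__ void DivideKernel(float a, float b, float *out) {
      *out = FastDivide(a, b);
    }

    int main() {
      float *d_out;
      cudaMalloc(&d_out, sizeof(float));
      DivideKernel<<<1, 1>>>(1.0f, 3.0f, d_out);
      float h_out;
      cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
      std::printf("device: %f  host: %f\n", h_out, FastDivide(1.0f, 3.0f));
      cudaFree(d_out);
      return 0;
    }

__fdividef trades a few ulps of accuracy for throughput, which matches the author's point that it only pays off in arithmetic-bound kernels.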

+template <typename GradientSumT>
+XGBOOST_DEVICE float CalcGainGivenWeight(ParamT const& p, GradientSumT const& stats,
+float w) const {
if (stats.GetHess() <= 0) {
return .0f;
}
// Avoiding tree::CalcGainGivenWeight can significantly reduce avg floating point error.
if (p.max_delta_step == 0.0f && has_constraint == false) {
-return common::Sqr(ThresholdL1(stats.sum_grad, p.reg_alpha)) /
-(stats.sum_hess + p.reg_lambda);
+return Divide(common::Sqr(ThresholdL1(stats.GetGrad(), p.reg_alpha)),
+(stats.GetHess() + p.reg_lambda));
}
-return tree::CalcGainGivenWeight<ParamT, float>(p, stats.sum_grad,
-stats.sum_hess, w);
+return tree::CalcGainGivenWeight<ParamT, float>(p, stats.GetGrad(),
+stats.GetHess(), w);
}
+template <typename GradientSumT>
XGBOOST_DEVICE float CalcGain(bst_node_t nid, ParamT const &p,
-tree::GradStats const& stats) const {
+GradientSumT const& stats) const {
return this->CalcGainGivenWeight(p, stats, this->CalcWeight(nid, p, stats));
}
};
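For reference, the fast path above computes gain = ThresholdL1(G, alpha)^2 / (H + lambda). A host-side sketch, assuming ThresholdL1 is the usual L1 soft-thresholding helper from param.h (paraphrased, not the library's exact code):

    // Paraphrased soft-thresholding used for L1 (reg_alpha) regularization:
    // shrink the gradient sum toward zero by alpha.
    inline float ThresholdL1(float g, float alpha) {
      if (g > +alpha) return g - alpha;
      if (g < -alpha) return g + alpha;
      return 0.0f;
    }

    inline float Sqr(float x) { return x * x; }

    // Unconstrained gain with max_delta_step == 0:
    //   gain = ThresholdL1(G, alpha)^2 / (H + lambda)
    inline float FastPathGain(float G, float H, float alpha, float lambda) {
      return Sqr(ThresholdL1(G, alpha)) / (H + lambda);
    }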
2 changes: 1 addition & 1 deletion src/tree/updater_gpu_hist.cu
@@ -171,7 +171,7 @@ class DeviceHistogramStorage {
template <typename GradientSumT>
struct GPUHistMakerDevice {
private:
-GPUHistEvaluator<GradientSumT> evaluator_;
+GPUHistEvaluator evaluator_;
Context const* ctx_;

public:
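The same shape of refactor recurs across these files: the class-level GradientSumT parameter disappears, and only the members that still need to generalize over the gradient type stay templated, so users such as GPUHistMakerDevice no longer have to propagate the parameter. A minimal sketch of that before/after shape (Evaluator and its gain formula are hypothetical, for illustration only):

    // Before: genericity at class scope forces every user to carry the parameter.
    // template <typename GradientSumT>
    // class Evaluator { ... };

    // After: the class is concrete; only members that need the gradient type
    // remain templated.
    class Evaluator {
     public:
      template <typename GradientSumT>
      float CalcGain(GradientSumT const &stats) const {
        return static_cast<float>(stats.GetGrad() * stats.GetGrad() /
                                  (stats.GetHess() + 1.0));
      }
    };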
15 changes: 8 additions & 7 deletions tests/cpp/tree/gpu_hist/test_evaluate_splits.cu
@@ -62,7 +62,7 @@ void TestEvaluateSingleSplit(bool is_categorical) {
cuts.min_vals_.ConstDeviceSpan(),
};

-GPUHistEvaluator<GradientPairPrecise> evaluator{
+GPUHistEvaluator evaluator{
tparam, static_cast<bst_feature_t>(feature_set.size()), 0};
evaluator.Reset(cuts, dh::ToSpan(feature_types), feature_set.size(), tparam, 0);
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
@@ -109,7 +109,7 @@ TEST(GpuHist, EvaluateSingleSplitMissing) {
dh::ToSpan(feature_min_values),
};

-GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, feature_set.size(), 0);
+GPUHistEvaluator evaluator(tparam, feature_set.size(), 0);
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;

EXPECT_EQ(result.findex, 0);
@@ -121,7 +121,7 @@

TEST(GpuHist, EvaluateSingleSplitEmpty) {
TrainParam tparam = ZeroParam();
-GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, 1, 0);
+GPUHistEvaluator evaluator(tparam, 1, 0);
DeviceSplitCandidate result =
evaluator.EvaluateSingleSplit(EvaluateSplitInputs{}, EvaluateSplitSharedInputs{}).split;
EXPECT_EQ(result.findex, -1);
@@ -159,7 +159,7 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) {
dh::ToSpan(feature_min_values),
};

-GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, feature_min_values.size(), 0);
+GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input, shared_inputs).split;

EXPECT_EQ(result.findex, 1);
@@ -199,7 +199,7 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) {
dh::ToSpan(feature_min_values),
};

-GPUHistEvaluator<GradientPairPrecise> evaluator(tparam, feature_min_values.size(), 0);
+GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0);
DeviceSplitCandidate result = evaluator.EvaluateSingleSplit(input,shared_inputs).split;

EXPECT_EQ(result.findex, 0);
@@ -246,7 +246,7 @@ TEST(GpuHist, EvaluateSplits) {
dh::ToSpan(feature_min_values),
};

-GPUHistEvaluator<GradientPairPrecise> evaluator{
+GPUHistEvaluator evaluator{
tparam, static_cast<bst_feature_t>(feature_min_values.size()), 0};
dh::device_vector<EvaluateSplitInputs> inputs = std::vector<EvaluateSplitInputs>{input_left,input_right};
evaluator.LaunchEvaluateSplits(input_left.feature_set.size(),dh::ToSpan(inputs),shared_inputs, evaluator.GetEvaluator(),
@@ -263,7 +263,7 @@

TEST_F(TestPartitionBasedSplit, GpuHist) {
dh::device_vector<FeatureType> ft{std::vector<FeatureType>{FeatureType::kCategorical}};
-GPUHistEvaluator<GradientPairPrecise> evaluator{param_,
+GPUHistEvaluator evaluator{param_,
static_cast<bst_feature_t>(info_.num_col_), 0};

cuts_.cut_ptrs_.SetDevice(0);
@@ -287,5 +287,6 @@ TEST_F(TestPartitionBasedSplit, GpuHist) {
auto split = evaluator.EvaluateSingleSplit(input, shared_inputs).split;
ASSERT_NEAR(split.loss_chg, best_score_, 1e-16);
}

} // namespace tree
} // namespace xgboost
2 changes: 2 additions & 0 deletions tests/cpp/tree/test_evaluate_splits.h
@@ -43,6 +43,8 @@ class TestPartitionBasedSplit : public ::testing::Test {
auto &h_vals = cuts_.cut_values_.HostVector();
h_vals.resize(n_bins_);
std::iota(h_vals.begin(), h_vals.end(), 0.0);

+cuts_.min_vals_.Resize(1);

hist_.Init(cuts_.TotalBins());
hist_.AddHistRow(0);