Skip to content

Commit

Permalink
Use single precision in gain calculation, use pointers instead of spa…
Browse files Browse the repository at this point in the history
…n. (#8051)
  • Loading branch information
RAMitchell committed Jul 12, 2022
1 parent a5bc8e2 commit 0bdaca2
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 21 deletions.
2 changes: 1 addition & 1 deletion src/tree/gpu_hist/evaluate_splits.cuh
Expand Up @@ -68,7 +68,7 @@ class GPUHistEvaluator {
// storage for sorted index of feature histogram, used for sort based splits.
dh::device_vector<bst_feature_t> cat_sorted_idx_;
// cached input for sorting the histogram, used for sort based splits.
using SortPair = thrust::tuple<uint32_t, double>;
using SortPair = thrust::tuple<uint32_t, float>;
dh::device_vector<SortPair> sort_input_;
// cache for feature index
dh::device_vector<bst_feature_t> feature_idx_;
Expand Down
2 changes: 1 addition & 1 deletion src/tree/gpu_hist/evaluator.cu
Expand Up @@ -89,7 +89,7 @@ common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
input.gradient_histogram[j]);
return thrust::make_tuple(i, lw);
}
return thrust::make_tuple(i, 0.0);
return thrust::make_tuple(i, 0.0f);
});
// Sort an array segmented according to
// - nodes
Expand Down
36 changes: 17 additions & 19 deletions src/tree/split_evaluator.h
Expand Up @@ -66,21 +66,21 @@ class TreeEvaluator {

template <typename ParamT>
struct SplitEvaluator {
common::Span<int const> constraints;
common::Span<float const> lower;
common::Span<float const> upper;
const int* constraints;
const float* lower;
const float* upper;
bool has_constraint;

XGBOOST_DEVICE double CalcSplitGain(const ParamT &param, bst_node_t nidx,
XGBOOST_DEVICE float CalcSplitGain(const ParamT &param, bst_node_t nidx,
bst_feature_t fidx,
tree::GradStats const& left,
tree::GradStats const& right) const {
int constraint = constraints[fidx];
const double negative_infinity = -std::numeric_limits<double>::infinity();
double wleft = this->CalcWeight(nidx, param, left);
double wright = this->CalcWeight(nidx, param, right);
const float negative_infinity = -std::numeric_limits<float>::infinity();
float wleft = this->CalcWeight(nidx, param, left);
float wright = this->CalcWeight(nidx, param, right);

double gain = this->CalcGainGivenWeight(param, left, wleft) +
float gain = this->CalcGainGivenWeight(param, left, wleft) +
this->CalcGainGivenWeight(param, right, wright);

if (constraint == 0) {
Expand All @@ -101,17 +101,17 @@ class TreeEvaluator {

if (nodeid == kRootParentId) {
return w;
} else if (w < lower(nodeid)) {
} else if (w < lower[nodeid]) {
return lower[nodeid];
} else if (w > upper(nodeid)) {
} else if (w > upper[nodeid]) {
return upper[nodeid];
} else {
return w;
}
}

template <typename GradientSumT>
XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const {
XGBOOST_DEVICE float CalcWeightCat(ParamT const& param, GradientSumT const& stats) const {
// FIXME(jiamingy): This is a temporary solution until we have categorical feature
// specific regularization parameters. During sorting we should try to avoid any
// regularization.
Expand Down Expand Up @@ -141,15 +141,13 @@ class TreeEvaluator {
/* Get a view to the evaluator that can be passed down to device. */
template <typename ParamT = TrainParam> auto GetEvaluator() const {
if (device_ != GenericParameter::kCpuId) {
auto constraints = monotone_.ConstDeviceSpan();
return SplitEvaluator<ParamT>{
constraints, lower_bounds_.ConstDeviceSpan(),
upper_bounds_.ConstDeviceSpan(), has_constraint_};
auto constraints = monotone_.ConstDevicePointer();
return SplitEvaluator<ParamT>{constraints, lower_bounds_.ConstDevicePointer(),
upper_bounds_.ConstDevicePointer(), has_constraint_};
} else {
auto constraints = monotone_.ConstHostSpan();
return SplitEvaluator<ParamT>{constraints, lower_bounds_.ConstHostSpan(),
upper_bounds_.ConstHostSpan(),
has_constraint_};
auto constraints = monotone_.ConstHostPointer();
return SplitEvaluator<ParamT>{constraints, lower_bounds_.ConstHostPointer(),
upper_bounds_.ConstHostPointer(), has_constraint_};
}
}

Expand Down

0 comments on commit 0bdaca2

Please sign in to comment.