Always use partition based categorical splits. (#7857)
trivialfis committed May 3, 2022
1 parent 90cce38 commit 317d7be
Showing 13 changed files with 79 additions and 104 deletions.
27 changes: 12 additions & 15 deletions doc/tutorials/categorical.rst
@@ -72,23 +72,20 @@ Optimal Partitioning
.. versionadded:: 1.6

Optimal partitioning is a technique for partitioning the categorical predictors for each
-node split, the proof of optimality for numerical objectives like ``RMSE`` was first
-introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling
-regression and binary classification tasks `[2] <#references>`__, later LightGBM `[3]
-<#references>`__ brought it to the context of gradient boosting trees and now is also
-adopted in XGBoost as an optional feature for handling categorical splits. More
-specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to
-partition a set of discrete values into groups based on the distances between a measure of
-these values, one only needs to look at sorted partitions instead of enumerating all
-possible permutations. In the context of decision trees, the discrete values are
-categories, and the measure is the output leaf value. Intuitively, we want to group the
-categories that output similar leaf values. During split finding, we first sort the
-gradient histogram to prepare the contiguous partitions then enumerate the splits
+node split, the proof of optimality for numerical output was first introduced by `[1]
+<#references>`__. The algorithm is used in decision trees `[2] <#references>`__, later
+LightGBM `[3] <#references>`__ brought it to the context of gradient boosting trees and
+now is also adopted in XGBoost as an optional feature for handling categorical
+splits. More specifically, the proof by Fisher `[1] <#references>`__ states that, when
+trying to partition a set of discrete values into groups based on the distances between a
+measure of these values, one only needs to look at sorted partitions instead of
+enumerating all possible permutations. In the context of decision trees, the discrete
+values are categories, and the measure is the output leaf value. Intuitively, we want to
+group the categories that output similar leaf values. During split finding, we first sort
+the gradient histogram to prepare the contiguous partitions then enumerate the splits
according to these sorted values. One of the related parameters for XGBoost is
``max_cat_to_one_hot``, which controls whether one-hot encoding or partitioning should be
-used for each feature, see :doc:`/parameter` for details. When objective is not
-regression or binary classification, XGBoost will fallback to using onehot encoding
-instead.
+used for each feature, see :doc:`/parameter` for details.


**********************
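
As an illustration of the sorted-partition idea in the updated tutorial text, the following is a minimal, self-contained C++ sketch (not XGBoost's implementation; the per-category statistics and the simplified gain formula are made up for the example): sort the categories by their gradient-to-hessian ratio, then scan only the prefix partitions of that order for the best split.

#include <algorithm>
#include <iostream>
#include <vector>

struct CatStat {
  int category;
  double grad;  // sum of gradients for this category
  double hess;  // sum of hessians for this category
};

int main() {
  // Hypothetical per-category gradient histogram for a single feature.
  std::vector<CatStat> hist{{0, -3.0, 2.0}, {1, 1.0, 2.0}, {2, -0.5, 1.0}, {3, 4.0, 2.0}};

  // Sort categories by grad/hess, a proxy for the leaf value each category
  // would receive on its own.
  std::sort(hist.begin(), hist.end(),
            [](CatStat const &a, CatStat const &b) { return a.grad / a.hess < b.grad / b.hess; });

  double total_grad = 0.0, total_hess = 0.0;
  for (auto const &c : hist) {
    total_grad += c.grad;
    total_hess += c.hess;
  }

  // Simplified split score: G^2 / H for each side, minus the unsplit score.
  auto score = [](double g, double h) { return h > 0.0 ? g * g / h : 0.0; };

  // Enumerate only the sorted (prefix) partitions instead of all 2^k subsets.
  double best_gain = 0.0;
  std::size_t best_cut = 0;
  double left_grad = 0.0, left_hess = 0.0;
  for (std::size_t i = 0; i + 1 < hist.size(); ++i) {
    left_grad += hist[i].grad;
    left_hess += hist[i].hess;
    double gain = score(left_grad, left_hess) +
                  score(total_grad - left_grad, total_hess - left_hess) -
                  score(total_grad, total_hess);
    if (gain > best_gain) {
      best_gain = gain;
      best_cut = i;
    }
  }

  std::cout << "categories going left:";
  for (std::size_t i = 0; i <= best_cut; ++i) {
    std::cout << " " << hist[i].category;
  }
  std::cout << ", gain = " << best_gain << "\n";
  return 0;
}

For k categories this scans only k - 1 candidate partitions instead of the 2^(k-1) - 1 arbitrary binary subsets.
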
3 changes: 0 additions & 3 deletions include/xgboost/task.h
@@ -38,9 +38,6 @@ struct ObjInfo {
ObjInfo(Task t) : task{t} {} // NOLINT
ObjInfo(Task t, bool khess, bool zhess) : task{t}, const_hess{khess}, zero_hess(zhess) {}

-XGBOOST_DEVICE bool UseOneHot() const {
-return (task != ObjInfo::kRegression && task != ObjInfo::kBinary);
-}
/**
* \brief Use adaptive tree if the objective doesn't have valid hessian value.
*/
5 changes: 2 additions & 3 deletions src/common/categorical.h
@@ -12,7 +12,6 @@
#include "xgboost/data.h"
#include "xgboost/parameter.h"
#include "xgboost/span.h"
#include "xgboost/task.h"

namespace xgboost {
namespace common {
@@ -82,8 +81,8 @@ inline void InvalidCategory() {
/*!
* \brief Whether should we use onehot encoding for categorical data.
*/
-XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
-bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot();
+XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot) {
+bool use_one_hot = n_cats < max_cat_to_onehot;
return use_one_hot;
}

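
With the objective-based check removed, the decision between one-hot and partition-based splits now depends only on the category count and ``max_cat_to_onehot``. A standalone sketch of that rule and how a caller might exercise it (the parameter value below is hypothetical, not XGBoost's default):

#include <cstdint>
#include <iostream>

// Same rule as the simplified common::UseOneHot above: the objective no
// longer plays a role, only the number of categories and the parameter.
inline bool UseOneHot(std::uint32_t n_cats, std::uint32_t max_cat_to_onehot) {
  return n_cats < max_cat_to_onehot;
}

int main() {
  std::uint32_t max_cat_to_onehot = 4;  // hypothetical parameter value
  for (std::uint32_t n_cats : {2u, 4u, 16u}) {
    std::cout << n_cats << " categories -> "
              << (UseOneHot(n_cats, max_cat_to_onehot) ? "one-hot split" : "partition-based split")
              << "\n";
  }
  return 0;
}
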
30 changes: 14 additions & 16 deletions src/tree/gpu_hist/evaluate_splits.cu
@@ -199,13 +199,11 @@ __device__ void EvaluateFeature(
}

template <int BLOCK_THREADS, typename GradientSumT>
-__global__ void EvaluateSplitsKernel(
-EvaluateSplitInputs<GradientSumT> left,
-EvaluateSplitInputs<GradientSumT> right,
-ObjInfo task,
-common::Span<bst_feature_t> sorted_idx,
-TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
-common::Span<DeviceSplitCandidate> out_candidates) {
+__global__ void EvaluateSplitsKernel(EvaluateSplitInputs<GradientSumT> left,
+EvaluateSplitInputs<GradientSumT> right,
+common::Span<bst_feature_t> sorted_idx,
+TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
+common::Span<DeviceSplitCandidate> out_candidates) {
// KeyValuePair here used as threadIdx.x -> gain_value
using ArgMaxT = cub::KeyValuePair<int, float>;
using BlockScanT =
@@ -241,7 +239,7 @@ __global__ void EvaluateSplitsKernel(

if (common::IsCat(inputs.feature_types, fidx)) {
auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx];
-if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) {
+if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot)) {
EvaluateFeature<BLOCK_THREADS, SumReduceT, BlockScanT, MaxReduceT, TempStorage, GradientSumT,
kOneHot>(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage);
} else {
@@ -310,7 +308,7 @@ __device__ void SortBasedSplit(EvaluateSplitInputs<GradientSumT> const &input,

template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
-EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
+EvaluateSplitInputs<GradientSumT> left, EvaluateSplitInputs<GradientSumT> right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits) {
if (!split_cats_.empty()) {
@@ -323,7 +321,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(
// One block for each feature
uint32_t constexpr kBlockThreads = 256;
dh::LaunchKernel {static_cast<uint32_t>(combined_num_features), kBlockThreads, 0}(
-EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, task, this->SortedIdx(left),
+EvaluateSplitsKernel<kBlockThreads, GradientSumT>, left, right, this->SortedIdx(left),
evaluator, dh::ToSpan(feature_best_splits));

// Reduce to get best candidate for left and right child over all features
@@ -365,15 +363,15 @@ void GPUHistEvaluator<GradientSumT>::CopyToHost(EvaluateSplitInputs<GradientSumT
}

template <typename GradientSumT>
-void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
+void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
common::Span<GPUExpandEntry> out_entries) {
auto evaluator = this->tree_evaluator_.template GetEvaluator<GPUTrainingParam>();

dh::TemporaryArray<DeviceSplitCandidate> splits_out_storage(2);
auto out_splits = dh::ToSpan(splits_out_storage);
-this->EvaluateSplits(left, right, task, evaluator, out_splits);
+this->EvaluateSplits(left, right, evaluator, out_splits);

auto d_sorted_idx = this->SortedIdx(left);
auto d_entries = out_entries;
@@ -385,7 +383,7 @@ void GPUHistEvaluator<GradientSumT>::EvaluateSplits(GPUExpandEntry candidate, Ob
auto fidx = out_splits[i].findex;

if (split.is_cat &&
-!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
+!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
bool is_left = i == 0;
auto out = is_left ? cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2);
SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]);
@@ -405,11 +403,11 @@

template <typename GradientSumT>
GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
-EvaluateSplitInputs<GradientSumT> input, float weight, ObjInfo task) {
+EvaluateSplitInputs<GradientSumT> input, float weight) {
dh::TemporaryArray<DeviceSplitCandidate> splits_out(1);
auto out_split = dh::ToSpan(splits_out);
auto evaluator = tree_evaluator_.GetEvaluator<GPUTrainingParam>();
-this->EvaluateSplits(input, {}, task, evaluator, out_split);
+this->EvaluateSplits(input, {}, evaluator, out_split);

auto cats_out = this->DeviceCatStorage(input.nidx);
auto d_sorted_idx = this->SortedIdx(input);
@@ -421,7 +419,7 @@ GPUExpandEntry GPUHistEvaluator<GradientSumT>::EvaluateSingleSplit(
auto fidx = out_split[i].findex;

if (split.is_cat &&
-!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) {
+!common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot)) {
SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]);
}

9 changes: 4 additions & 5 deletions src/tree/gpu_hist/evaluate_splits.cuh
@@ -114,7 +114,7 @@ class GPUHistEvaluator {
/**
* \brief Reset the evaluator, should be called before any use.
*/
-void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft, ObjInfo task,
+void Reset(common::HistogramCuts const &cuts, common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param, int32_t device);

/**
@@ -150,21 +150,20 @@

// impl of evaluate splits, contains CUDA kernels so it's public
void EvaluateSplits(EvaluateSplitInputs<GradientSumT> left,
-EvaluateSplitInputs<GradientSumT> right, ObjInfo task,
+EvaluateSplitInputs<GradientSumT> right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator,
common::Span<DeviceSplitCandidate> out_splits);
/**
* \brief Evaluate splits for left and right nodes.
*/
-void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task,
+void EvaluateSplits(GPUExpandEntry candidate,
EvaluateSplitInputs<GradientSumT> left,
EvaluateSplitInputs<GradientSumT> right,
common::Span<GPUExpandEntry> out_splits);
/**
* \brief Evaluate splits for root node.
*/
-GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight,
-ObjInfo task);
+GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs<GradientSumT> input, float weight);
};
} // namespace tree
} // namespace xgboost
6 changes: 3 additions & 3 deletions src/tree/gpu_hist/evaluator.cu
@@ -16,12 +16,12 @@ namespace xgboost {
namespace tree {
template <typename GradientSumT>
void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
-common::Span<FeatureType const> ft, ObjInfo task,
+common::Span<FeatureType const> ft,
bst_feature_t n_features, TrainParam const &param,
int32_t device) {
param_ = param;
tree_evaluator_ = TreeEvaluator{param, n_features, device};
-if (cuts.HasCategorical() && !task.UseOneHot()) {
+if (cuts.HasCategorical()) {
dh::XGBCachingDeviceAllocator<char> alloc;
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
auto beg = thrust::make_counting_iterator<size_t>(1ul);
@@ -34,7 +34,7 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
auto idx = i - 1;
if (common::IsCat(ft, idx)) {
auto n_bins = ptrs[i] - ptrs[idx];
-bool use_sort = !common::UseOneHot(n_bins, to_onehot, task);
+bool use_sort = !common::UseOneHot(n_bins, to_onehot);
return use_sort;
}
return false;
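
The predicate inside ``Reset`` now marks a feature for sort-based splits from its bin count alone. Below is a host-side sketch of the same per-feature decision, assuming made-up cut pointers, feature types, and parameter value; the real code evaluates this on the device via thrust.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical cut pointers: feature i owns histogram bins [ptrs[i], ptrs[i + 1]).
  std::vector<std::uint32_t> ptrs{0, 3, 5, 25};
  std::vector<bool> is_cat{true, false, true};  // assumed feature types
  std::uint32_t max_cat_to_onehot = 8;          // assumed parameter value

  std::size_t n_sort_based = 0;
  for (std::size_t i = 1; i < ptrs.size(); ++i) {
    std::size_t fidx = i - 1;
    bool use_sort = false;
    if (is_cat[fidx]) {
      std::uint32_t n_bins = ptrs[i] - ptrs[fidx];
      use_sort = !(n_bins < max_cat_to_onehot);  // mirrors !common::UseOneHot(n_bins, ...)
    }
    n_sort_based += use_sort ? 1 : 0;
    std::cout << "feature " << fidx << ": "
              << (use_sort ? "sort-based categorical split" : "one-hot or numerical split") << "\n";
  }
  std::cout << n_sort_based << " feature(s) use the sort-based path\n";
  return 0;
}
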
10 changes: 3 additions & 7 deletions src/tree/hist/evaluate_splits.h
@@ -11,7 +11,6 @@
#include <utility>
#include <vector>

#include "xgboost/task.h"
#include "../param.h"
#include "../constraints.h"
#include "../split_evaluator.h"
@@ -39,7 +38,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
int32_t n_threads_ {0};
FeatureInteractionConstraintHost interaction_constraints_;
std::vector<NodeEntry> snode_;
-ObjInfo task_;

// if sum of statistics for non-missing values in the node
// is equal to sum of statistics for all values:
@@ -244,7 +242,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
}
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
-if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
+if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
} else {
@@ -345,7 +343,6 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {

auto Evaluator() const { return tree_evaluator_.GetEvaluator(); }
auto const& Stats() const { return snode_; }
-auto Task() const { return task_; }

float InitRoot(GradStats const& root_sum) {
snode_.resize(1);
@@ -363,12 +360,11 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
// The column sampler must be constructed by caller since we need to preserve the rng
// for the entire training session.
explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
-std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task)
+std::shared_ptr<common::ColumnSampler> sampler)
: param_{param},
column_sampler_{std::move(sampler)},
tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
-n_threads_{n_threads},
-task_{task} {
+n_threads_{n_threads} {
interaction_constraints_.Configure(param, info.num_col_);
column_sampler_->Init(info.num_col_, info.feature_weights.HostVector(), param_.colsample_bynode,
param_.colsample_bylevel, param_.colsample_bytree);
17 changes: 8 additions & 9 deletions src/tree/updater_approx.cc
@@ -29,10 +29,8 @@ DMLC_REGISTRY_FILE_TAG(updater_approx);

namespace {
// Return the BatchParam used by DMatrix.
-template <typename GradientSumT>
-auto BatchSpec(TrainParam const &p, common::Span<float> hess,
-HistEvaluator<GradientSumT, CPUExpandEntry> const &evaluator) {
-return BatchParam{p.max_bin, hess, !evaluator.Task().const_hess};
+auto BatchSpec(TrainParam const &p, common::Span<float> hess, ObjInfo const task) {
+return BatchParam{p.max_bin, hess, !task.const_hess};
}

auto BatchSpec(TrainParam const &p, common::Span<float> hess) {
@@ -47,7 +45,8 @@ class GloablApproxBuilder {
std::shared_ptr<common::ColumnSampler> col_sampler_;
HistEvaluator<GradientSumT, CPUExpandEntry> evaluator_;
HistogramBuilder<GradientSumT, CPUExpandEntry> histogram_builder_;
-GenericParameter const *ctx_;
+Context const *ctx_;
+ObjInfo const task_;

std::vector<ApproxRowPartitioner> partitioner_;
// Pointer to last updated tree, used for update prediction cache.
@@ -65,8 +64,7 @@
int32_t n_total_bins = 0;
partitioner_.clear();
// Generating the GHistIndexMatrix is quite slow, is there a way to speed it up?
-for (auto const &page :
-p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, evaluator_))) {
+for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(BatchSpec(param_, hess, task_))) {
if (n_total_bins == 0) {
n_total_bins = page.cut.TotalBins();
feature_values_ = page.cut;
@@ -158,7 +156,7 @@ class GloablApproxBuilder {
void LeafPartition(RegTree const &tree, common::Span<float> hess,
std::vector<bst_node_t> *p_out_position) {
monitor_->Start(__func__);
-if (!evaluator_.Task().UpdateTreeLeaf()) {
+if (!task_.UpdateTreeLeaf()) {
return;
}
for (auto const &part : partitioner_) {
@@ -173,8 +171,9 @@
common::Monitor *monitor)
: param_{std::move(param)},
col_sampler_{std::move(column_sampler)},
-evaluator_{param_, info, ctx->Threads(), col_sampler_, task},
+evaluator_{param_, info, ctx->Threads(), col_sampler_},
ctx_{ctx},
+task_{task},
monitor_{monitor} {}

void UpdateTree(DMatrix *p_fmat, std::vector<GradientPair> const &gpair, common::Span<float> hess,
