Skip to content

Commit

Permalink
Unify the cat split storage for CPU. (#7937)
Browse files Browse the repository at this point in the history
* Unify the cat split storage for CPU.

* Cleanup.

* Workaround.
  • Loading branch information
trivialfis committed May 26, 2022
1 parent 755d9d4 commit 18cbeba
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 33 deletions.
6 changes: 3 additions & 3 deletions include/xgboost/tree_model.h
Expand Up @@ -440,10 +440,10 @@ class RegTree : public Model {
* \param right_sum The hessian sum of the right leaf.
*/
void ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum);

bool HasCategoricalSplit() const {
return !split_categories_.empty();
Expand Down
38 changes: 16 additions & 22 deletions src/tree/hist/evaluate_splits.h
Expand Up @@ -57,9 +57,9 @@ class HistEvaluator {
}

/**
* \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
* create a pseudo-category for missing value but here we just do a complete scan
* to avoid making specialized histogram bin.
* \brief Use learned direction with one-hot split. Other implementations (LGB) create a
* pseudo-category for missing values, but here we just do a complete scan to avoid
* making a specialized histogram bin.
*/
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
bst_feature_t fidx, bst_node_t nidx,
Expand All @@ -76,6 +76,7 @@ class HistEvaluator {
GradStats right_sum;
// best split so far
SplitEntry best;
best.is_cat = false; // marker for whether it's updated or not.

auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
auto feature_sum = GradStats{
Expand All @@ -98,8 +99,8 @@ class HistEvaluator {
}

// missing on right (treat missing as chosen category)
left_sum.SetSubstract(left_sum, missing);
right_sum.Add(missing);
left_sum.SetSubstract(parent.stats, right_sum);
if (IsValid(left_sum, right_sum)) {
auto missing_right_chg = static_cast<float>(
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
Expand All @@ -108,6 +109,13 @@ class HistEvaluator {
}
}

if (best.is_cat) {
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
best.cat_bits.resize(n, 0);
common::CatBitField cat_bits{best.cat_bits};
cat_bits.Set(best.split_value);
}

p_best->Update(best);
}

Expand Down Expand Up @@ -345,25 +353,11 @@ class HistEvaluator {
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});

if (candidate.split.is_cat) {
std::vector<uint32_t> split_cats;
if (candidate.split.cat_bits.empty()) {
if (common::InvalidCat(candidate.split.split_value)) {
common::InvalidCategory();
}
auto cat = common::AsCat(candidate.split.split_value);
split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cat_bits;
cat_bits = LBitField32(split_cats);
cat_bits.Set(cat);
} else {
split_cats = candidate.split.cat_bits;
common::CatBitField cat_bits{split_cats};
}
tree.ExpandCategorical(
candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else {
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
candidate.split.DefaultLeft(), base_weight,
Expand Down
2 changes: 1 addition & 1 deletion src/tree/split_evaluator.h
Expand Up @@ -160,7 +160,7 @@ class TreeEvaluator {
return;
}

auto max_nidx = std::max(leftid, rightid);
size_t max_nidx = std::max(leftid, rightid);
if (lower_bounds_.Size() <= max_nidx) {
lower_bounds_.Resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
}
Expand Down
8 changes: 3 additions & 5 deletions src/tree/tree_model.cc
Expand Up @@ -808,11 +808,9 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
}

void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
bst_float base_weight,
bst_float left_leaf_weight,
bst_float right_leaf_weight,
bst_float loss_change, float sum_hess,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum) {
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
default_left, base_weight,
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_updaters.py
Expand Up @@ -10,10 +10,10 @@
'nthread': strategies.integers(1, 4),
'max_depth': strategies.integers(1, 11),
'min_child_weight': strategies.floats(0.5, 2.0),
'alpha': strategies.floats(0.0, 2.0),
'alpha': strategies.floats(1e-5, 2.0),
'lambda': strategies.floats(1e-5, 2.0),
'eta': strategies.floats(0.01, 0.5),
'gamma': strategies.floats(0.0, 2.0),
'gamma': strategies.floats(1e-5, 2.0),
'seed': strategies.integers(0, 10),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
Expand Down

0 comments on commit 18cbeba

Please sign in to comment.