Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Unify the cat split storage for CPU. #7937

Merged
3 commits merged on May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/xgboost/tree_model.h
Expand Up @@ -440,10 +440,10 @@ class RegTree : public Model {
* \param right_sum The sum hess of right leaf.
*/
void ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change,
float sum_hess, float left_sum, float right_sum);
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum);

bool HasCategoricalSplit() const {
return !split_categories_.empty();
Expand Down
38 changes: 16 additions & 22 deletions src/tree/hist/evaluate_splits.h
Expand Up @@ -57,9 +57,9 @@ class HistEvaluator {
}

/**
* \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
* create a pseudo-category for missing value but here we just do a complete scan
* to avoid making specialized histogram bin.
* \brief Use learned direction with one-hot split. Other implementations (LGB) create a
* pseudo-category for missing value but here we just do a complete scan to avoid
* making specialized histogram bin.
*/
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
bst_feature_t fidx, bst_node_t nidx,
Expand All @@ -76,6 +76,7 @@ class HistEvaluator {
GradStats right_sum;
// best split so far
SplitEntry best;
best.is_cat = false; // marker for whether it's updated or not.

auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
auto feature_sum = GradStats{
Expand All @@ -98,8 +99,8 @@ class HistEvaluator {
}

// missing on right (treat missing as chosen category)
left_sum.SetSubstract(left_sum, missing);
right_sum.Add(missing);
left_sum.SetSubstract(parent.stats, right_sum);
if (IsValid(left_sum, right_sum)) {
auto missing_right_chg = static_cast<float>(
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
Expand All @@ -108,6 +109,13 @@ class HistEvaluator {
}
}

if (best.is_cat) {
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
best.cat_bits.resize(n, 0);
common::CatBitField cat_bits{best.cat_bits};
cat_bits.Set(best.split_value);
}

p_best->Update(best);
}

Expand Down Expand Up @@ -345,25 +353,11 @@ class HistEvaluator {
evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});

if (candidate.split.is_cat) {
std::vector<uint32_t> split_cats;
if (candidate.split.cat_bits.empty()) {
if (common::InvalidCat(candidate.split.split_value)) {
common::InvalidCategory();
}
auto cat = common::AsCat(candidate.split.split_value);
split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
LBitField32 cat_bits;
cat_bits = LBitField32(split_cats);
cat_bits.Set(cat);
} else {
split_cats = candidate.split.cat_bits;
common::CatBitField cat_bits{split_cats};
}
tree.ExpandCategorical(
candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
candidate.split.right_sum.GetHess());
candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
} else {
tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
candidate.split.DefaultLeft(), base_weight,
Expand Down
2 changes: 1 addition & 1 deletion src/tree/split_evaluator.h
Expand Up @@ -160,7 +160,7 @@ class TreeEvaluator {
return;
}

auto max_nidx = std::max(leftid, rightid);
size_t max_nidx = std::max(leftid, rightid);
if (lower_bounds_.Size() <= max_nidx) {
lower_bounds_.Resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
}
Expand Down
8 changes: 3 additions & 5 deletions src/tree/tree_model.cc
Expand Up @@ -808,11 +808,9 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
}

void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
common::Span<uint32_t> split_cat, bool default_left,
bst_float base_weight,
bst_float left_leaf_weight,
bst_float right_leaf_weight,
bst_float loss_change, float sum_hess,
common::Span<const uint32_t> split_cat, bool default_left,
bst_float base_weight, bst_float left_leaf_weight,
bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
float left_sum, float right_sum) {
this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
default_left, base_weight,
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_updaters.py
Expand Up @@ -10,10 +10,10 @@
'nthread': strategies.integers(1, 4),
'max_depth': strategies.integers(1, 11),
'min_child_weight': strategies.floats(0.5, 2.0),
'alpha': strategies.floats(0.0, 2.0),
'alpha': strategies.floats(1e-5, 2.0),
'lambda': strategies.floats(1e-5, 2.0),
'eta': strategies.floats(0.01, 0.5),
'gamma': strategies.floats(0.0, 2.0),
'gamma': strategies.floats(1e-5, 2.0),
'seed': strategies.integers(0, 10),
# We cannot enable subsampling as the training loss can increase
# 'subsample': strategies.floats(0.5, 1.0),
Expand Down