diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h
index b2d2ad3383de..9153b84565a8 100644
--- a/include/xgboost/tree_model.h
+++ b/include/xgboost/tree_model.h
@@ -440,10 +440,10 @@ class RegTree : public Model {
    * \param right_sum The sum hess of right leaf.
    */
   void ExpandCategorical(bst_node_t nid, unsigned split_index,
-                         common::Span<uint32_t> split_cat, bool default_left,
+                         common::Span<const uint32_t> split_cat, bool default_left,
                          bst_float base_weight, bst_float left_leaf_weight,
-                         bst_float right_leaf_weight, bst_float loss_change,
-                         float sum_hess, float left_sum, float right_sum);
+                         bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
+                         float left_sum, float right_sum);
 
   bool HasCategoricalSplit() const { return !split_categories_.empty(); }
 
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index fa6cc718b7f7..7fbd27d56ada 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -57,9 +57,9 @@ class HistEvaluator {
   }
 
   /**
-   * \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
-   *        create a pseudo-category for missing value but here we just do a complete scan
-   *        to avoid making specialized histogram bin.
+   * \brief Use learned direction with one-hot split. Other implementations (LGB) create a
+   *        pseudo-category for missing values, but here we do a complete scan to avoid
+   *        making a specialized histogram bin.
    */
   void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow<GradientSumT> &hist,
                        bst_feature_t fidx, bst_node_t nidx,
@@ -76,6 +76,7 @@
     GradStats right_sum;
     // best split so far
     SplitEntry best;
+    best.is_cat = false;  // marker for whether it's updated or not.
     auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
     auto feature_sum = GradStats{
@@ -98,8 +99,8 @@
       }
 
       // missing on right (treat missing as chosen category)
-      left_sum.SetSubstract(left_sum, missing);
       right_sum.Add(missing);
+      left_sum.SetSubstract(parent.stats, right_sum);
       if (IsValid(left_sum, right_sum)) {
         auto missing_right_chg = static_cast<float>(
             evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
@@ -108,6 +109,13 @@
       }
     }
 
+    if (best.is_cat) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
+      best.cat_bits.resize(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+      cat_bits.Set(best.split_value);
+    }
+
     p_best->Update(best);
   }
 
@@ -345,25 +353,11 @@ class HistEvaluator {
         evaluator.CalcWeight(candidate.nid, param_, GradStats{candidate.split.right_sum});
 
     if (candidate.split.is_cat) {
-      std::vector<uint32_t> split_cats;
-      if (candidate.split.cat_bits.empty()) {
-        if (common::InvalidCat(candidate.split.split_value)) {
-          common::InvalidCategory();
-        }
-        auto cat = common::AsCat(candidate.split.split_value);
-        split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
-        LBitField32 cat_bits;
-        cat_bits = LBitField32(split_cats);
-        cat_bits.Set(cat);
-      } else {
-        split_cats = candidate.split.cat_bits;
-        common::CatBitField cat_bits{split_cats};
-      }
       tree.ExpandCategorical(
-          candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
-          base_weight, left_weight * param_.learning_rate, right_weight * param_.learning_rate,
-          candidate.split.loss_chg, parent_sum.GetHess(), candidate.split.left_sum.GetHess(),
-          candidate.split.right_sum.GetHess());
+          candidate.nid, candidate.split.SplitIndex(), candidate.split.cat_bits,
+          candidate.split.DefaultLeft(), base_weight, left_weight * param_.learning_rate,
+          right_weight * param_.learning_rate, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
     } else {
       tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
                       candidate.split.DefaultLeft(), base_weight,
diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h
index ba3533e84f43..a44522f88197 100644
--- a/src/tree/split_evaluator.h
+++ b/src/tree/split_evaluator.h
@@ -160,7 +160,7 @@ class TreeEvaluator {
       return;
     }
 
-    auto max_nidx = std::max(leftid, rightid);
+    size_t max_nidx = std::max(leftid, rightid);
     if (lower_bounds_.Size() <= max_nidx) {
       lower_bounds_.Resize(max_nidx * 2 + 1, -std::numeric_limits<float>::max());
     }
diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc
index ec4beadddc59..d498c54ed9ef 100644
--- a/src/tree/tree_model.cc
+++ b/src/tree/tree_model.cc
@@ -808,11 +808,9 @@ void RegTree::ExpandNode(bst_node_t nid, unsigned split_index, bst_float split_v
 }
 
 void RegTree::ExpandCategorical(bst_node_t nid, unsigned split_index,
-                                common::Span<uint32_t> split_cat, bool default_left,
-                                bst_float base_weight,
-                                bst_float left_leaf_weight,
-                                bst_float right_leaf_weight,
-                                bst_float loss_change, float sum_hess,
+                                common::Span<const uint32_t> split_cat, bool default_left,
+                                bst_float base_weight, bst_float left_leaf_weight,
+                                bst_float right_leaf_weight, bst_float loss_change, float sum_hess,
                                 float left_sum, float right_sum) {
   this->ExpandNode(nid, split_index, std::numeric_limits<float>::quiet_NaN(),
                    default_left, base_weight,
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index b4cc1eb310bb..4e73bab317dd 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -10,10 +10,10 @@
     'nthread': strategies.integers(1, 4),
     'max_depth': strategies.integers(1, 11),
    'min_child_weight': strategies.floats(0.5, 2.0),
-    'alpha': strategies.floats(0.0, 2.0),
+    'alpha': strategies.floats(1e-5, 2.0),
     'lambda': strategies.floats(1e-5, 2.0),
     'eta': strategies.floats(0.01, 0.5),
-    'gamma': strategies.floats(0.0, 2.0),
+    'gamma': strategies.floats(1e-5, 2.0),
     'seed': strategies.integers(0, 10),
     # We cannot enable subsampling as the training loss can increase
     # 'subsample': strategies.floats(0.5, 1.0),
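The EnumerateOneHot hunks above scan every category bin of a feature and try both
default directions for missing values: missing-left keeps the right child as the single
candidate category, while missing-right folds the missing statistics into the chosen
category. After the patch, left_sum is derived from parent.stats once right_sum has
absorbed the missing bucket, instead of by subtracting missing from the running
left_sum. Below is a minimal, self-contained Python sketch of that enumeration, not
XGBoost's API: Stats, gain, and enumerate_one_hot are illustrative names, the gain is
the standard second-order G^2 / (H + lambda) term, and validity checks such as
min_child_weight are omitted.

# Illustrative sketch only: mirrors the logic of EnumerateOneHot after this patch.
from dataclasses import dataclass

@dataclass
class Stats:
    grad: float = 0.0
    hess: float = 0.0

    def add(self, other: "Stats") -> "Stats":
        return Stats(self.grad + other.grad, self.hess + other.hess)

    def sub(self, other: "Stats") -> "Stats":
        return Stats(self.grad - other.grad, self.hess - other.hess)

def gain(s: Stats, lam: float = 1.0) -> float:
    # Second-order gain for one child: G^2 / (H + lambda).
    return s.grad ** 2 / (s.hess + lam)

def enumerate_one_hot(parent: Stats, bins: list, missing: Stats, lam: float = 1.0):
    """Return (loss_chg, bin index, default_left) for the best one-hot split."""
    best = (float("-inf"), None, None)
    parent_gain = gain(parent, lam)
    for i, b in enumerate(bins):
        # Missing on left: the right child is exactly this category.
        right = b
        left = parent.sub(right)
        chg = gain(left, lam) + gain(right, lam) - parent_gain
        if chg > best[0]:
            best = (chg, i, True)
        # Missing on right: treat missing as the chosen category and
        # recompute the left child from the parent, as the patch does.
        right = b.add(missing)
        left = parent.sub(right)
        chg = gain(left, lam) + gain(right, lam) - parent_gain
        if chg > best[0]:
            best = (chg, i, False)
    return best

bins = [Stats(0.9, 2.0), Stats(-2.5, 3.0), Stats(0.4, 1.0)]
parent = Stats(-0.7, 7.0)               # includes the missing bucket
missing = parent.sub(Stats(-1.2, 6.0))  # parent minus the feature sum
print(enumerate_one_hot(parent, bins, missing))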
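The block added at the end of EnumerateOneHot records the winning category as a set bit
rather than leaving it as a floating-point split value, sizing the buffer for
n_bins + 1 bits. A rough Python equivalent of that bookkeeping, assuming 32-bit storage
words as in LBitField32; compute_storage_size and set_bit are illustrative stand-ins
for CatBitField::ComputeStorageSize and CatBitField::Set, and the real container's bit
ordering is not reproduced here:

def compute_storage_size(n_bits: int, word_bits: int = 32) -> int:
    # Number of words needed to hold n_bits bits (ceiling division).
    return (n_bits + word_bits - 1) // word_bits

def set_bit(words: list, i: int, word_bits: int = 32) -> None:
    words[i // word_bits] |= 1 << (i % word_bits)

n_bins = 37
cat_bits = [0] * compute_storage_size(n_bins + 1)  # two 32-bit words
set_bit(cat_bits, 5)  # category 5 forms the right child; the rest go left
assert cat_bits == [1 << 5, 0]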