diff --git a/src/common/categorical.h b/src/common/categorical.h
index fedada7bd700..4cbbbf72ba60 100644
--- a/src/common/categorical.h
+++ b/src/common/categorical.h
@@ -5,11 +5,12 @@
 #ifndef XGBOOST_COMMON_CATEGORICAL_H_
 #define XGBOOST_COMMON_CATEGORICAL_H_
 
+#include "bitfield.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
-#include "xgboost/span.h"
 #include "xgboost/parameter.h"
-#include "bitfield.h"
+#include "xgboost/span.h"
+#include "xgboost/task.h"
 
 namespace xgboost {
 namespace common {
@@ -47,6 +48,15 @@ inline void InvalidCategory() {
          "should be non-negative.";
 }
 
+/*!
+ * \brief Whether we should use one-hot encoding for categorical data.
+ */
+inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) {
+  bool use_one_hot = n_cats < max_cat_to_onehot ||
+                     (task.task != ObjInfo::kRegression && task.task != ObjInfo::kBinary);
+  return use_one_hot;
+}
+
 struct IsCatOp {
   XGBOOST_DEVICE bool operator()(FeatureType ft) {
     return ft == FeatureType::kCategorical;
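Note: `UseOneHot` is the single decision point for choosing between the two categorical split strategies. A minimal sketch of how a caller consults it (the wrapper function below is hypothetical; `UseOneHot` and `ObjInfo` come from this patch):

```cpp
// Sketch, not part of the patch: decide the encoding for one feature.
#include "../../src/common/categorical.h"  // path assumed relative to a test

bool ShouldOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot) {
  xgboost::ObjInfo task{xgboost::ObjInfo::kRegression};
  // A feature with fewer categories than the threshold uses one-hot splits.
  // So does any task other than regression/binary, where sorting categories
  // by gradient statistics is not well defined.
  return xgboost::common::UseOneHot(n_cats, max_cat_to_onehot, task);
}
```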
diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h
index 24b99ed4a8c7..e8e6f50d88b2 100644
--- a/src/tree/hist/evaluate_splits.h
+++ b/src/tree/hist/evaluate_splits.h
@@ -6,13 +6,16 @@
 #include <algorithm>
 #include <limits>
+#include <numeric>
 #include <memory>
 #include <utility>
 #include <vector>
 
+#include "xgboost/task.h"
 #include "../param.h"
 #include "../constraints.h"
 #include "../split_evaluator.h"
+#include "../../common/categorical.h"
 #include "../../common/random.h"
 #include "../../common/hist_util.h"
 #include "../../data/gradient_index.h"
@@ -36,13 +39,13 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
   int32_t n_threads_ {0};
   FeatureInteractionConstraintHost interaction_constraints_;
   std::vector<NodeEntry> snode_;
+  ObjInfo task_;
 
   // if sum of statistics for non-missing values in the node
   // is equal to sum of statistics for all values:
   // then - there are no missing values
   // else - there are missing values
-  bool static SplitContainsMissingValues(const GradStats e,
-                                         const NodeEntry &snode) {
+  bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
     if (e.GetGrad() == snode.stats.GetGrad() &&
         e.GetHess() == snode.stats.GetHess()) {
       return false;
@@ -50,38 +53,40 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
       return true;
     }
   }
+  enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 };
 
   // Enumerate/Scan the split values of specific feature
   // Returns the sum of gradients corresponding to the data points that contains
   // a non-missing value for the particular feature fid.
-  template <int d_step>
-  GradStats EnumerateSplit(
-      common::HistogramCuts const &cut, const common::GHistRow<GradientSumT> &hist,
-      const NodeEntry &snode, SplitEntry *p_best, bst_feature_t fidx,
-      bst_node_t nidx,
-      TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator) const {
+  template <int d_step, SplitType split_type>
+  GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
+                           const common::GHistRow<GradientSumT> &hist, bst_feature_t fidx,
+                           bst_node_t nidx,
+                           TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
+                           SplitEntry *p_best) const {
     static_assert(d_step == +1 || d_step == -1, "Invalid step.");
 
     // aliases
     const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
     const std::vector<bst_float> &cut_val = cut.Values();
+    auto const &parent = snode_[nidx];
+    int32_t n_bins{static_cast<int32_t>(cut_ptr.at(fidx + 1) - cut_ptr[fidx])};
+    auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
 
     // statistics on both sides of split
-    GradStats c;
-    GradStats e;
+    GradStats left_sum;
+    GradStats right_sum;
     // best split so far
     SplitEntry best;
 
     // bin boundaries
-    CHECK_LE(cut_ptr[fidx],
-             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    CHECK_LE(cut_ptr[fidx + 1],
-             static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    // imin: index (offset) of the minimum value for feature fid
-    // need this for backward enumeration
+    CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+    // imin: index (offset) of the minimum value for feature fid; need this for backward
+    // enumeration
     const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
-    // ibegin, iend: smallest/largest cut points for feature fid
-    // use int to allow for value -1
+    // ibegin, iend: smallest/largest cut points for feature fid; use int to allow for
+    // value -1
     int32_t ibegin, iend;
     if (d_step > 0) {
       ibegin = static_cast<int32_t>(cut_ptr[fidx]);
@@ -91,49 +96,118 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
       iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
     }
 
+    auto calc_bin_value = [&](auto i) {
+      switch (split_type) {
+        case kNum: {
+          left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
+          right_sum.SetSubstract(parent.stats, left_sum);
+          break;
+        }
+        case kOneHot: {
+          // not-chosen categories go to left
+          right_sum = GradStats{hist[i]};
+          left_sum.SetSubstract(parent.stats, right_sum);
+          break;
+        }
+        case kPart: {
+          auto j = d_step == 1 ? (i - ibegin) : (ibegin - i);
+          right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
+          left_sum.SetSubstract(parent.stats, right_sum);
+          break;
+        }
+        default: {
+          std::terminate();
+        }
+      }
+    };
+
+    int32_t best_thresh{-1};
     for (int32_t i = ibegin; i != iend; i += d_step) {
       // start working
       // try to find a split
-      e.Add(hist[i].GetGrad(), hist[i].GetHess());
-      if (e.GetHess() >= param_.min_child_weight) {
-        c.SetSubstract(snode.stats, e);
-        if (c.GetHess() >= param_.min_child_weight) {
-          bst_float loss_chg;
-          bst_float split_pt;
-          if (d_step > 0) {
-            // forward enumeration: split at right bound of each bin
-            loss_chg = static_cast<bst_float>(
-                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{e},
-                                        GradStats{c}) -
-                snode.root_gain);
-            split_pt = cut_val[i];
-            best.Update(loss_chg, fidx, split_pt, d_step == -1, e, c);
-          } else {
-            // backward enumeration: split at left bound of each bin
-            loss_chg = static_cast<bst_float>(
-                evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{c},
-                                        GradStats{e}) -
-                snode.root_gain);
-            if (i == imin) {
-              // for leftmost bin, left bound is the smallest feature value
-              split_pt = cut.MinValues()[fidx];
-            } else {
-              split_pt = cut_val[i - 1];
-            }
-            best.Update(loss_chg, fidx, split_pt, d_step == -1, c, e);
-          }
-        }
-      }
+      calc_bin_value(i);
+      bool improved{false};
+      if (left_sum.GetHess() >= param_.min_child_weight &&
+          right_sum.GetHess() >= param_.min_child_weight) {
+        bst_float loss_chg;
+        bst_float split_pt;
+        if (d_step > 0) {
+          // forward enumeration: split at right bound of each bin
+          loss_chg = static_cast<bst_float>(
+              evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum},
+                                      GradStats{right_sum}) -
+              parent.root_gain);
+          split_pt = cut_val[i];
+          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
+                                 left_sum, right_sum);
+        } else {
+          // backward enumeration: split at left bound of each bin
+          loss_chg = static_cast<bst_float>(
+              evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
+                                      GradStats{left_sum}) -
+              parent.root_gain);
+          switch (split_type) {
+            case kNum: {
+              if (i == imin) {
+                // for leftmost bin, left bound is the smallest feature value
+                split_pt = cut.MinValues()[fidx];
+              } else {
+                split_pt = cut_val[i - 1];
+              }
+              break;
+            }
+            case kOneHot: {
+              split_pt = cut_val[i];
+              break;
+            }
+            case kPart: {
+              split_pt = cut_val[i];
+              break;
+            }
+          }
+          improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
+                                 right_sum, left_sum);
+        }
+        if (improved) {
+          best_thresh = i;
+        }
+      }
     }
+
+    if (split_type == kPart && best_thresh != -1) {
+      auto n = common::CatBitField::ComputeStorageSize(n_bins);
+      best.cat_bits.resize(n, 0);
+      common::CatBitField cat_bits{best.cat_bits};
+
+      if (d_step == 1) {
+        std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
+                      [&cat_bits](size_t c) { cat_bits.Set(c); });
+      } else {
+        std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
+                      [&cat_bits](size_t c) { cat_bits.Set(c); });
+      }
+    }
     p_best->Update(best);
 
-    return e;
+    switch (split_type) {
+      case kNum:
+        // Normal, accumulated to left
+        return left_sum;
+      case kOneHot:
+        // Doesn't matter, not accumulating.
+        return {};
+      case kPart:
+        // Accumulated to right, because the chosen categories go to the right branch.
+        return right_sum;
+    }
+    return left_sum;
   }
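Background for the `kPart` path above: it is the classic "sort categories by gradient statistics, then scan prefixes" optimal-partition trick. A standalone sketch, decoupled from XGBoost's types (all names below are illustrative, not from the patch):

```cpp
// Simplified optimal bipartition of categories for a second-order objective.
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

struct Stat { double grad, hess; };

// Returns, per category, whether it is sent to the right branch.
std::vector<bool> BestPartition(std::vector<Stat> const& hist, double lambda) {
  std::vector<size_t> sorted_idx(hist.size());
  std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
  // Order categories by their optimal leaf weight -g / (h + lambda),
  // mirroring what CalcWeightCat computes in the patch.
  std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
    return -hist[l].grad / (hist[l].hess + lambda) <
           -hist[r].grad / (hist[r].hess + lambda);
  });
  Stat total{0, 0};
  for (auto const& s : hist) { total.grad += s.grad; total.hess += s.hess; }

  auto gain = [&](Stat const& s) { return s.grad * s.grad / (s.hess + lambda); };
  Stat right{0, 0};
  double best = 0;
  size_t best_thresh = 0;
  // Scan sorted categories; each prefix is a candidate right branch.
  for (size_t j = 0; j < sorted_idx.size(); ++j) {
    right.grad += hist[sorted_idx[j]].grad;
    right.hess += hist[sorted_idx[j]].hess;
    Stat left{total.grad - right.grad, total.hess - right.hess};
    double chg = gain(left) + gain(right) - gain(total);
    if (chg > best) { best = chg; best_thresh = j + 1; }
  }
  std::vector<bool> cats(hist.size(), false);
  for (size_t j = 0; j < best_thresh; ++j) { cats[sorted_idx[j]] = true; }
  return cats;
}
```

The forward/backward `d_step` pair in the real code plays the same role as in numerical splits: it decides which side the missing-value bucket joins.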
 
  public:
   void EvaluateSplits(const common::HistCollection<GradientSumT> &hist,
-                      common::HistogramCuts const &cut, const RegTree &tree,
-                      std::vector<ExpandEntry>* p_entries) {
+                      common::HistogramCuts const &cut,
+                      common::Span<FeatureType const> feature_types,
+                      const RegTree &tree,
+                      std::vector<ExpandEntry> *p_entries) {
     auto& entries = *p_entries;
     // All nodes are on the same level, so we can store the shared ptr.
     std::vector<std::shared_ptr<HostDeviceVector<bst_feature_t>>> features(
         entries.size());
@@ -150,7 +224,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
           return features[nidx_in_set]->Size();
         },
         grain_size);
 
-    std::vector<ExpandEntry> tloc_candidates(omp_get_max_threads() * entries.size());
+    std::vector<ExpandEntry> tloc_candidates(n_threads_ * entries.size());
     for (size_t i = 0; i < entries.size(); ++i) {
       for (decltype(n_threads_) j = 0; j < n_threads_; ++j) {
         tloc_candidates[i * n_threads_ + j] = entries[i];
@@ -167,12 +241,37 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
       auto features_set = features[nidx_in_set]->ConstHostSpan();
       for (auto fidx_in_set = r.begin(); fidx_in_set < r.end(); fidx_in_set++) {
         auto fidx = features_set[fidx_in_set];
-        if (interaction_constraints_.Query(nidx, fidx)) {
-          auto grad_stats = EnumerateSplit<+1>(cut, histogram, snode_[nidx],
-                                               best, fidx, nidx, evaluator);
+        bool is_cat = common::IsCat(feature_types, fidx);
+        if (!interaction_constraints_.Query(nidx, fidx)) {
+          continue;
+        }
+        if (is_cat) {
+          auto n_bins = cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx];
+          if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) {
+            EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
+            EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
+          } else {
+            auto const &cut_ptr = cut.Ptrs();
+            std::vector<size_t> sorted_idx(n_bins);
+            std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
+            auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins);
+            std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) {
+              auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) <
+                         evaluator.CalcWeightCat(param_, feat_hist[r]);
+              static_assert(std::is_same<decltype(ret), bool>::value, "");
+              return ret;
+            });
+            auto grad_stats =
+                EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+            if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
+              EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
+            }
+          }
+        } else {
+          auto grad_stats =
+              EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
           if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
-            EnumerateSplit<-1>(cut, histogram, snode_[nidx], best, fidx, nidx,
-                               evaluator);
+            EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
           }
         }
       }
@@ -187,7 +286,7 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
     }
   }
 
   // Add splits to tree, handles all statistic
-  void ApplyTreeSplit(ExpandEntry candidate, RegTree *p_tree) {
+  void ApplyTreeSplit(ExpandEntry const& candidate, RegTree *p_tree) {
     auto evaluator = tree_evaluator_.GetEvaluator();
     RegTree &tree = *p_tree;
 
@@ -201,13 +300,31 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
     auto right_weight = evaluator.CalcWeight(
         candidate.nid, param_, GradStats{candidate.split.right_sum});
 
-    tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(),
-                    candidate.split.split_value, candidate.split.DefaultLeft(),
-                    base_weight, left_weight * param_.learning_rate,
-                    right_weight * param_.learning_rate,
-                    candidate.split.loss_chg, parent_sum.GetHess(),
-                    candidate.split.left_sum.GetHess(),
-                    candidate.split.right_sum.GetHess());
+    if (candidate.split.is_cat) {
+      std::vector<uint32_t> split_cats;
+      if (candidate.split.cat_bits.empty()) {
+        CHECK_LT(candidate.split.split_value, std::numeric_limits<bst_cat_t>::max())
+            << "Categorical feature value too large.";
+        auto cat = common::AsCat(candidate.split.split_value);
+        split_cats.resize(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0);
+        LBitField32 cat_bits;
+        cat_bits = LBitField32(split_cats);
+        cat_bits.Set(cat);
+      } else {
+        split_cats = candidate.split.cat_bits;
+      }
+
+      tree.ExpandCategorical(
+          candidate.nid, candidate.split.SplitIndex(), split_cats, candidate.split.DefaultLeft(),
+          base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(),
+          candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
+    } else {
+      tree.ExpandNode(candidate.nid, candidate.split.SplitIndex(), candidate.split.split_value,
+                      candidate.split.DefaultLeft(), base_weight,
+                      left_weight * param_.learning_rate, right_weight * param_.learning_rate,
+                      candidate.split.loss_chg, parent_sum.GetHess(),
+                      candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess());
+    }
 
     // Set up child constraints
     auto left_child = tree[candidate.nid].LeftChild();
@@ -249,14 +366,14 @@ template <typename GradientSumT, typename ExpandEntry> class HistEvaluator {
  public:
   // The column sampler must be constructed by caller since we need to preserve the rng
   // for the entire training session.
-  explicit HistEvaluator(TrainParam const &param, MetaInfo const &info,
-                         int32_t n_threads,
-                         std::shared_ptr<common::ColumnSampler> sampler,
+  explicit HistEvaluator(TrainParam const &param, MetaInfo const &info, int32_t n_threads,
+                         std::shared_ptr<common::ColumnSampler> sampler, ObjInfo task,
                          bool skip_0_index = false)
-      : param_{param}, column_sampler_{std::move(sampler)},
-        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_),
-                        GenericParameter::kCpuId},
-        n_threads_{n_threads} {
+      : param_{param},
+        column_sampler_{std::move(sampler)},
+        tree_evaluator_{param, static_cast<bst_feature_t>(info.num_col_), GenericParameter::kCpuId},
+        n_threads_{n_threads},
+        task_{task} {
     interaction_constraints_.Configure(param, info.num_col_);
     column_sampler_->Init(info.num_col_, info.feature_weigths.HostVector(),
                           param_.colsample_bynode, param_.colsample_bylevel,
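Note on the one-hot fallback in `ApplyTreeSplit`: when `cat_bits` is empty, the single chosen category is converted into a bit field on the spot. A sketch of the encoding, mirroring the usage above (the wrapper function is hypothetical; `LBitField32`, `ComputeStorageSize`, and `Set` come from `src/common/bitfield.h`):

```cpp
// Sketch, assumptions labeled: encode a set of right-branch categories.
#include <cstdint>
#include <vector>
#include "../../src/common/bitfield.h"  // path assumed

std::vector<uint32_t> EncodeRightCategories(std::vector<int32_t> const &cats,
                                            int32_t n_cats) {
  // One 32-bit word stores 32 categories; size the storage first.
  std::vector<uint32_t> storage(xgboost::LBitField32::ComputeStorageSize(n_cats), 0);
  xgboost::LBitField32 bits{storage};
  for (auto c : cats) {
    bits.Set(c);  // a set bit means: this category goes to the right child
  }
  return storage;  // this is what lands in SplitEntry::cat_bits
}
```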
LightGBM)"); + DMLC_DECLARE_FIELD(max_cat_to_onehot) + .set_default(4) + .set_lower_bound(1) + .describe("Maximum number of categories to use one-hot encoding based split."); DMLC_DECLARE_FIELD(min_child_weight) .set_lower_bound(0.0f) .set_default(1.0f) @@ -384,6 +392,8 @@ struct SplitEntryContainer { /*! \brief split index */ bst_feature_t sindex{0}; bst_float split_value{0.0f}; + std::vector cat_bits; + bool is_cat{false}; GradientT left_sum; GradientT right_sum; @@ -433,6 +443,8 @@ struct SplitEntryContainer { this->loss_chg = e.loss_chg; this->sindex = e.sindex; this->split_value = e.split_value; + this->is_cat = e.is_cat; + this->cat_bits = e.cat_bits; this->left_sum = e.left_sum; this->right_sum = e.right_sum; return true; @@ -449,9 +461,8 @@ struct SplitEntryContainer { * \return whether the proposed split is better and can replace current split */ bool Update(bst_float new_loss_chg, unsigned split_index, - bst_float new_split_value, bool default_left, - const GradientT &left_sum, - const GradientT &right_sum) { + bst_float new_split_value, bool default_left, bool is_cat, + const GradientT &left_sum, const GradientT &right_sum) { if (this->NeedReplace(new_loss_chg, split_index)) { this->loss_chg = new_loss_chg; if (default_left) { @@ -459,6 +470,31 @@ struct SplitEntryContainer { } this->sindex = split_index; this->split_value = new_split_value; + this->is_cat = is_cat; + this->left_sum = left_sum; + this->right_sum = right_sum; + return true; + } else { + return false; + } + } + + /*! + * \brief Update with partition based categorical split. + * + * \return Whether the proposed split is better and can replace current split. + */ + bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats, + bool default_left, GradientT const &left_sum, GradientT const &right_sum) { + if (this->NeedReplace(new_loss_chg, split_index)) { + this->loss_chg = new_loss_chg; + if (default_left) { + split_index |= (1U << 31); + } + this->sindex = split_index; + cat_bits.resize(cats.Bits().size()); + std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin()); + this->is_cat = true; this->left_sum = left_sum; this->right_sum = right_sum; return true; diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index 069718a27378..4fdf70145a95 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -92,7 +92,7 @@ class TreeEvaluator { XGBOOST_DEVICE float CalcWeight(bst_node_t nodeid, const ParamT ¶m, tree::GradStats const& stats) const { - float w = xgboost::tree::CalcWeight(param, stats); + float w = ::xgboost::tree::CalcWeight(param, stats); if (!has_constraint) { return w; } @@ -107,6 +107,12 @@ class TreeEvaluator { return w; } } + + template + XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const { + return ::xgboost::tree::CalcWeight(param, stats); + } + XGBOOST_DEVICE float CalcGainGivenWeight(ParamT const &p, tree::GradStats const& stats, float w) const { if (stats.GetHess() <= 0) { diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc index 3b0a74f3605f..52fb85821543 100644 --- a/src/tree/updater_colmaker.cc +++ b/src/tree/updater_colmaker.cc @@ -336,10 +336,10 @@ class ColMaker: public TreeUpdater { bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f; if ( proposed_split == fvalue ) { e.best.Update(loss_chg, fid, e.last_fvalue, - d_step == -1, c, e.stats); + d_step == -1, false, c, e.stats); } else { e.best.Update(loss_chg, fid, proposed_split, - d_step == -1, c, e.stats); + 
diff --git a/src/tree/updater_colmaker.cc b/src/tree/updater_colmaker.cc
index 3b0a74f3605f..52fb85821543 100644
--- a/src/tree/updater_colmaker.cc
+++ b/src/tree/updater_colmaker.cc
@@ -336,10 +336,10 @@ class ColMaker: public TreeUpdater {
           bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f;
           if ( proposed_split == fvalue ) {
             e.best.Update(loss_chg, fid, e.last_fvalue,
-                          d_step == -1, c, e.stats);
+                          d_step == -1, false, c, e.stats);
           } else {
             e.best.Update(loss_chg, fid, proposed_split,
-                          d_step == -1, c, e.stats);
+                          d_step == -1, false, c, e.stats);
           }
         } else {
           loss_chg = static_cast<bst_float>(
@@ -348,10 +348,10 @@ class ColMaker: public TreeUpdater {
           bst_float proposed_split = (fvalue + e.last_fvalue) * 0.5f;
           if ( proposed_split == fvalue ) {
             e.best.Update(loss_chg, fid, e.last_fvalue,
-                          d_step == -1, e.stats, c);
+                          d_step == -1, false, e.stats, c);
           } else {
             e.best.Update(loss_chg, fid, proposed_split,
-                          d_step == -1, e.stats, c);
+                          d_step == -1, false, e.stats, c);
           }
         }
       }
@@ -430,14 +430,14 @@ class ColMaker: public TreeUpdater {
           loss_chg = static_cast<bst_float>(
               evaluator.CalcSplitGain(param_, nid, fid, c, e.stats) -
               snode_[nid].root_gain);
-          e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1, c,
-                        e.stats);
+          e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
+                        false, c, e.stats);
         } else {
           loss_chg = static_cast<bst_float>(
               evaluator.CalcSplitGain(param_, nid, fid, e.stats, c) -
               snode_[nid].root_gain);
           e.best.Update(loss_chg, fid, e.last_fvalue + delta, d_step == -1,
-                        e.stats, c);
+                        false, e.stats, c);
         }
       }
     }
diff --git a/src/tree/updater_histmaker.cc b/src/tree/updater_histmaker.cc
index ac040f14ebda..7552e03034ed 100644
--- a/src/tree/updater_histmaker.cc
+++ b/src/tree/updater_histmaker.cc
@@ -173,7 +173,8 @@ class HistMaker: public BaseMaker {
       if (c.sum_hess >= param_.min_child_weight) {
         double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
                           CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
-        if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i], false, s, c)) {
+        if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i],
+                         false, false, s, c)) {
           *left_sum = s;
         }
       }
@@ -187,7 +188,8 @@ class HistMaker: public BaseMaker {
       if (c.sum_hess >= param_.min_child_weight) {
         double loss_chg = CalcGain(param_, s.GetGrad(), s.GetHess()) +
                           CalcGain(param_, c.GetGrad(), c.GetHess()) - root_gain;
-        if (best->Update(static_cast<bst_float>(loss_chg), fid, hist.cut[i-1], true, c, s)) {
+        if (best->Update(static_cast<bst_float>(loss_chg), fid,
+                         hist.cut[i - 1], true, false, c, s)) {
           *left_sum = c;
         }
       }
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 00c35612f2a0..1207a57102f5 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -168,9 +168,11 @@ void QuantileHistMaker::Builder<GradientSumT>::InitRoot(
   std::vector<CPUExpandEntry> entries{node};
   builder_monitor_.Start("EvaluateSplits");
+  auto ft = p_fmat->Info().feature_types.ConstHostSpan();
   for (auto const &gmat : p_fmat->GetBatches<GHistIndexMatrix>(
            BatchParam{GenericParameter::kCpuId, param_.max_bin})) {
-    evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, *p_tree, &entries);
+    evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft,
+                               *p_tree, &entries);
     break;
   }
   builder_monitor_.Stop("EvaluateSplits");
@@ -272,8 +274,9 @@ void QuantileHistMaker::Builder<GradientSumT>::ExpandTree(
   }
 
   builder_monitor_.Start("EvaluateSplits");
-  evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), gmat.cut,
-                             *p_tree, &nodes_to_evaluate);
+  auto ft = p_fmat->Info().feature_types.ConstHostSpan();
+  evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(),
+                             gmat.cut, ft, *p_tree, &nodes_to_evaluate);
   builder_monitor_.Stop("EvaluateSplits");
 
   for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
@@ -529,11 +532,11 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(
   // store a pointer to the tree
   p_last_tree_ = &tree;
   if (data_layout_ == DataLayout::kDenseDataOneBased) {
-    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
-                                                                     column_sampler_, true});
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, task_, true});
   } else {
-    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{param_, info, this->nthread_,
-                                                                     column_sampler_, false});
+    evaluator_.reset(new HistEvaluator<GradientSumT, CPUExpandEntry>{
+        param_, info, this->nthread_, column_sampler_, task_, false});
   }
 
   if (data_layout_ == DataLayout::kDenseDataZeroBased
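The mechanical changes in the legacy updaters above all thread a new fifth argument through `SplitEntry::Update`; it flags the split as categorical, and these updaters only produce numerical threshold splits, so they always pass `false`. A sketch of the extended call shape (values hypothetical):

```cpp
// Sketch of the extended numeric Update signature after this patch.
#include "../src/tree/param.h"  // path assumed

void NumericUpdateExample() {
  xgboost::tree::SplitEntry best;
  xgboost::tree::GradStats left, right;
  best.Update(/*new_loss_chg=*/1.5f, /*split_index=*/3u,
              /*new_split_value=*/0.5f, /*default_left=*/true,
              /*is_cat=*/false, left, right);
}
```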
this->nthread_, column_sampler_, task_, true}); } else { - evaluator_.reset(new HistEvaluator{param_, info, this->nthread_, - column_sampler_, false}); + evaluator_.reset(new HistEvaluator{ + param_, info, this->nthread_, column_sampler_, task_, false}); } if (data_layout_ == DataLayout::kDenseDataZeroBased diff --git a/tests/cpp/categorical_helpers.h b/tests/cpp/categorical_helpers.h new file mode 100644 index 000000000000..f4470a6c910e --- /dev/null +++ b/tests/cpp/categorical_helpers.h @@ -0,0 +1,44 @@ +/*! + * Copyright 2021 by XGBoost Contributors + * + * \brief Utilities for testing categorical data support. + */ +#include +#include + +#include "xgboost/span.h" +#include "helpers.h" +#include "../../src/common/categorical.h" + +namespace xgboost { +inline std::vector OneHotEncodeFeature(std::vector x, + size_t num_cat) { + std::vector ret(x.size() * num_cat, 0); + size_t n_rows = x.size(); + for (size_t r = 0; r < n_rows; ++r) { + bst_cat_t cat = common::AsCat(x[r]); + ret.at(num_cat * r + cat) = 1; + } + return ret; +} + +template +void ValidateCategoricalHistogram(size_t n_categories, + common::Span onehot, + common::Span cat) { + auto cat_sum = std::accumulate(cat.cbegin(), cat.cend(), GradientPairPrecise{}); + for (size_t c = 0; c < n_categories; ++c) { + auto zero = onehot[c * 2]; + auto one = onehot[c * 2 + 1]; + + auto chosen = cat[c]; + auto not_chosen = cat_sum - chosen; + + ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps); + ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps); + + ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps); + ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps); + } +} +} // namespace xgboost diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index 586ff762fc2b..c124ab5055e6 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -5,6 +5,13 @@ #include "../../../src/common/quantile.cuh" namespace xgboost { +namespace { +struct IsSorted { + XGBOOST_DEVICE bool operator()(common::SketchEntry const& a, common::SketchEntry const& b) const { + return a.value < b.value; + } +}; +} namespace common { TEST(GPUQuantile, Basic) { constexpr size_t kRows = 1000, kCols = 100, kBins = 256; @@ -52,9 +59,15 @@ void TestSketchUnique(float sparsity) { ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back()); sketch.Unique(); - ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(), - sketch.Data().data() + sketch.Data().size(), - detail::SketchUnique{})); + + std::vector h_data(sketch.Data().size()); + thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin()); + + for (size_t i = 1; i < h_columns_ptr.size(); ++i) { + auto begin = h_columns_ptr[i - 1]; + auto column = common::Span(h_data).subspan(begin, h_columns_ptr[i] - begin); + ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{})); + } }); } @@ -84,8 +97,7 @@ void TestQuantileElemRank(int32_t device, Span in, if (with_error) { ASSERT_GE(in_column[idx].rmin + in_column[idx].rmin * kRtEps, prev_rmin); - ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, - prev_rmax); + ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, prev_rmax); ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, rmin_next); } else { @@ -169,7 +181,7 @@ TEST(GPUQuantile, MergeEmpty) { TEST(GPUQuantile, MergeBasic) { constexpr size_t kRows = 1000, kCols = 100; - RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) { + 
diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu
index 586ff762fc2b..c124ab5055e6 100644
--- a/tests/cpp/common/test_quantile.cu
+++ b/tests/cpp/common/test_quantile.cu
@@ -5,6 +5,13 @@
 #include "../../../src/common/quantile.cuh"
 
 namespace xgboost {
+namespace {
+struct IsSorted {
+  XGBOOST_DEVICE bool operator()(common::SketchEntry const& a, common::SketchEntry const& b) const {
+    return a.value < b.value;
+  }
+};
+}  // anonymous namespace
 namespace common {
 TEST(GPUQuantile, Basic) {
   constexpr size_t kRows = 1000, kCols = 100, kBins = 256;
@@ -52,9 +59,15 @@ void TestSketchUnique(float sparsity) {
     ASSERT_EQ(sketch.Data().size(), h_columns_ptr.back());
 
     sketch.Unique();
-    ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch.Data().data(),
-                                  sketch.Data().data() + sketch.Data().size(),
-                                  detail::SketchUnique{}));
+
+    std::vector<SketchEntry> h_data(sketch.Data().size());
+    thrust::copy(dh::tcbegin(sketch.Data()), dh::tcend(sketch.Data()), h_data.begin());
+
+    for (size_t i = 1; i < h_columns_ptr.size(); ++i) {
+      auto begin = h_columns_ptr[i - 1];
+      auto column = common::Span<SketchEntry>{h_data}.subspan(begin, h_columns_ptr[i] - begin);
+      ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{}));
+    }
   });
 }
@@ -84,8 +97,7 @@ void TestQuantileElemRank(int32_t device, Span<SketchEntry const> in,
       if (with_error) {
         ASSERT_GE(in_column[idx].rmin + in_column[idx].rmin * kRtEps,
                   prev_rmin);
-        ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps,
-                  prev_rmax);
+        ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps, prev_rmax);
         ASSERT_GE(in_column[idx].rmax + in_column[idx].rmin * kRtEps,
                   rmin_next);
       } else {
@@ -169,7 +181,7 @@ TEST(GPUQuantile, MergeEmpty) {
 
 TEST(GPUQuantile, MergeBasic) {
   constexpr size_t kRows = 1000, kCols = 100;
-  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const& info) {
+  RunWithSeedsAndBins(kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) {
     HostDeviceVector<FeatureType> ft;
     SketchContainer sketch_0(ft, n_bins, kCols, kRows, 0);
     HostDeviceVector<float> storage_0;
@@ -265,9 +277,16 @@ void TestMergeDuplicated(int32_t n_bins, size_t cols, size_t rows, float frac) {
   ASSERT_EQ(h_columns_ptr.back(), sketch_1.Data().size() + size_before_merge);
 
   sketch_0.Unique();
-  ASSERT_TRUE(thrust::is_sorted(thrust::device, sketch_0.Data().data(),
-                                sketch_0.Data().data() + sketch_0.Data().size(),
-                                detail::SketchUnique{}));
+  columns_ptr = sketch_0.ColumnsPtr();
+  dh::CopyDeviceSpanToVector(&h_columns_ptr, columns_ptr);
+
+  std::vector<SketchEntry> h_data(sketch_0.Data().size());
+  dh::CopyDeviceSpanToVector(&h_data, sketch_0.Data());
+  for (size_t i = 1; i < h_columns_ptr.size(); ++i) {
+    auto begin = h_columns_ptr[i - 1];
+    auto column = Span<SketchEntry>{h_data}.subspan(begin, h_columns_ptr[i] - begin);
+    ASSERT_TRUE(std::is_sorted(column.begin(), column.end(), IsSorted{}));
+  }
 }
 
 TEST(GPUQuantile, MergeDuplicated) {
diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h
index 083766cfbb5a..8118248dc939 100644
--- a/tests/cpp/common/test_quantile.h
+++ b/tests/cpp/common/test_quantile.h
@@ -48,7 +48,9 @@ template <typename Fn> void RunWithSeedsAndBins(size_t rows, Fn fn) {
   std::vector<MetaInfo> infos(2);
   auto& h_weights = infos.front().weights_.HostVector();
   h_weights.resize(rows);
-  std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); });
+
+  SimpleRealUniformDistribution<float> weight_dist(0, 10);
+  std::generate(h_weights.begin(), h_weights.end(), [&]() { return weight_dist(&lcg); });
 
   for (auto seed : seeds) {
     for (auto n_bin : bins) {
diff --git a/tests/cpp/helpers.cc b/tests/cpp/helpers.cc
index 559a999e3e12..db4454c9eb1e 100644
--- a/tests/cpp/helpers.cc
+++ b/tests/cpp/helpers.cc
@@ -172,12 +172,10 @@ SimpleLCG::StateType SimpleLCG::operator()() {
   state_ = (alpha_ * state_) % mod_;
   return state_;
 }
-SimpleLCG::StateType SimpleLCG::Min() const {
-  return seed_ * alpha_;
-}
-SimpleLCG::StateType SimpleLCG::Max() const {
-  return max_value_;
-}
+SimpleLCG::StateType SimpleLCG::Min() const { return min(); }
+SimpleLCG::StateType SimpleLCG::Max() const { return max(); }
+// Make sure it's compile time constant.
+static_assert(SimpleLCG::max() - SimpleLCG::min(), "");
 
 void RandomDataGenerator::GenerateDense(HostDeviceVector<float> *out) const {
   xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
@@ -291,6 +289,7 @@ void RandomDataGenerator::GenerateCSR(
   xgboost::SimpleRealUniformDistribution<bst_float> dist(lower_, upper_);
   float sparsity = sparsity_ * (upper_ - lower_) + lower_;
+  SimpleRealUniformDistribution<bst_float> cat(0.0, max_cat_);
 
   h_rptr.emplace_back(0);
   for (size_t i = 0; i < rows_; ++i) {
@@ -298,7 +297,11 @@ void RandomDataGenerator::GenerateCSR(
     for (size_t j = 0; j < cols_; ++j) {
       auto g = dist(&lcg);
       if (g >= sparsity) {
-        g = dist(&lcg);
+        if (common::IsCat(ft_, j)) {
+          g = common::AsCat(cat(&lcg));
+        } else {
+          g = dist(&lcg);
+        }
         h_value.emplace_back(g);
         rptr++;
         h_cols.emplace_back(j);
@@ -347,11 +350,15 @@ RandomDataGenerator::GenerateDMatrix(bool with_label, bool float_label,
   }
   if (device_ >= 0) {
     out->Info().labels_.SetDevice(device_);
+    out->Info().feature_types.SetDevice(device_);
     for (auto const& page : out->GetBatches<SparsePage>()) {
       page.data.SetDevice(device_);
      page.offset.SetDevice(device_);
     }
   }
+  if (!ft_.empty()) {
+    out->Info().feature_types.HostVector() = ft_;
+  }
   return out;
 }
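Context for the `min()`/`max()` additions (see also the helpers.h change below): the standard library's shuffling and distribution facilities require the UniformRandomBitGenerator interface, which is `result_type`, `operator()`, and constexpr static `min()`/`max()`. A sketch of what this enables (the wrapper function is hypothetical):

```cpp
// Sketch: SimpleLCG can now be handed directly to <algorithm> utilities.
#include <algorithm>
#include <vector>

template <typename Rng>
void ShuffleWithLcg(std::vector<int>* data, Rng* rng) {
  // Compiles only if Rng satisfies UniformRandomBitGenerator,
  // which SimpleLCG does after this patch (used by the new tests).
  std::shuffle(data->begin(), data->end(), *rng);
}
```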
diff --git a/tests/cpp/helpers.h b/tests/cpp/helpers.h
index 47150ab4e6a5..c424d65ced05 100644
--- a/tests/cpp/helpers.h
+++ b/tests/cpp/helpers.h
@@ -106,42 +106,39 @@ bool IsNear(std::vector<xgboost::bst_float>::const_iterator _beg1,
  */
 class SimpleLCG {
  private:
-  using StateType = int64_t;
+  using StateType = uint64_t;
   static StateType constexpr kDefaultInit = 3;
-  static StateType constexpr default_alpha_ = 61;
-  static StateType constexpr max_value_ = ((StateType)1 << 32) - 1;
+  static StateType constexpr kDefaultAlpha = 61;
+  static StateType constexpr kMaxValue = (static_cast<StateType>(1) << 32) - 1;
 
   StateType state_;
   StateType const alpha_;
   StateType const mod_;
-  StateType seed_;
+
+ public:
+  using result_type = StateType;  // NOLINT
 
  public:
-  SimpleLCG() : state_{kDefaultInit},
-                alpha_{default_alpha_}, mod_{max_value_}, seed_{state_}{}
+  SimpleLCG() : state_{kDefaultInit}, alpha_{kDefaultAlpha}, mod_{kMaxValue} {}
   SimpleLCG(SimpleLCG const& that) = default;
   SimpleLCG(SimpleLCG&& that) = default;
 
-  void Seed(StateType seed) {
-    seed_ = seed;
-  }
+  void Seed(StateType seed) { state_ = seed % mod_; }
   /*!
    * \brief Initialize SimpleLCG.
    *
    * \param state  Initial state, can also be considered as seed. If set to
    *               zero, SimpleLCG will use internal default value.
-   * \param alpha  multiplier
-   * \param mod    modulo
    */
-  explicit SimpleLCG(StateType state,
-                     StateType alpha=default_alpha_, StateType mod=max_value_)
-      : state_{state == 0 ? kDefaultInit : state},
-        alpha_{alpha}, mod_{mod} , seed_{state} {}
+  explicit SimpleLCG(StateType state)
+      : state_{state == 0 ? kDefaultInit : state}, alpha_{kDefaultAlpha}, mod_{kMaxValue} {}
 
   StateType operator()();
   StateType Min() const;
   StateType Max() const;
+
+  constexpr result_type static min() { return 0; }  // NOLINT
+  constexpr result_type static max() { return kMaxValue; }  // NOLINT
 };
 
 template
@@ -217,10 +214,12 @@ class RandomDataGenerator {
   float upper_;
 
   int32_t device_;
-  int32_t seed_;
+  uint64_t seed_;
   SimpleLCG lcg_;
 
   size_t bins_;
+  std::vector<FeatureType> ft_;
+  bst_cat_t max_cat_;
 
   Json ArrayInterfaceImpl(HostDeviceVector<float> *storage, size_t rows,
                           size_t cols) const;
@@ -242,7 +241,7 @@ class RandomDataGenerator {
     device_ = d;
     return *this;
   }
-  RandomDataGenerator& Seed(int32_t s) {
+  RandomDataGenerator& Seed(uint64_t s) {
     seed_ = s;
     lcg_.Seed(seed_);
     return *this;
@@ -251,6 +250,16 @@ class RandomDataGenerator {
     bins_ = b;
     return *this;
  }
+  RandomDataGenerator& Type(common::Span<FeatureType> ft) {
+    CHECK_EQ(ft.size(), cols_);
+    ft_.resize(ft.size());
+    std::copy(ft.cbegin(), ft.cend(), ft_.begin());
+    return *this;
+  }
+  RandomDataGenerator& MaxCategory(bst_cat_t cat) {
+    max_cat_ = cat;
+    return *this;
+  }
 
   void GenerateDense(HostDeviceVector<float>* out) const;
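The new `Type` and `MaxCategory` setters compose with the existing fluent interface. A hypothetical snippet, matching how the new tests below use them:

```cpp
// Sketch: generate a DMatrix where every column is categorical with
// values in [0, 8). kRows/kCols are made up for the example.
size_t constexpr kRows = 64, kCols = 4;
std::vector<xgboost::FeatureType> ft(kCols, xgboost::FeatureType::kCategorical);
auto dmat = xgboost::RandomDataGenerator{kRows, kCols, /*sparsity=*/0.0}
                .Seed(3)
                .Type(ft)
                .MaxCategory(8)
                .GenerateDMatrix();
```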
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 5c659965759b..3b543a48d7cc 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -1,9 +1,11 @@
 #include <gtest/gtest.h>
 #include <vector>
-#include "../../helpers.h"
+
 #include "../../../../src/common/categorical.h"
-#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
 #include "../../../../src/tree/gpu_hist/histogram.cuh"
+#include "../../../../src/tree/gpu_hist/row_partitioner.cuh"
+#include "../../categorical_helpers.h"
+#include "../../helpers.h"
 
 namespace xgboost {
 namespace tree {
@@ -99,16 +101,6 @@ TEST(Histogram, GPUDeterministic) {
   }
 }
 
-std::vector<float> OneHotEncodeFeature(std::vector<float> x, size_t num_cat) {
-  std::vector<float> ret(x.size() * num_cat, 0);
-  size_t n_rows = x.size();
-  for (size_t r = 0; r < n_rows; ++r) {
-    bst_cat_t cat = common::AsCat(x[r]);
-    ret.at(num_cat * r + cat) = 1;
-  }
-  return ret;
-}
-
 // Test 1 vs rest categorical histogram is equivalent to one hot encoded data.
 void TestGPUHistogramCategorical(size_t num_categories) {
   size_t constexpr kRows = 340;
@@ -123,7 +115,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
   auto gpair = GenerateRandomGradients(kRows, 0, 2);
   gpair.SetDevice(0);
   auto rounding = CreateRoundingFactor<GradientPairPrecise>(gpair.DeviceSpan());
-  // Generate hist with cat data.
+  /**
+   * Generate hist with cat data.
+   */
   for (auto const &batch : cat_m->GetBatches<EllpackPage>(batch_param)) {
     auto* page = batch.Impl();
     FeatureGroups single_group(page->Cuts());
@@ -133,7 +127,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
                            rounding);
   }
 
-  // Generate hist with one hot encoded data.
+  /**
+   * Generate hist with one hot encoded data.
+   */
   auto x_encoded = OneHotEncodeFeature(x, num_categories);
   auto encode_m = GetDMatrixFromData(x_encoded, kRows, num_categories);
   dh::device_vector<GradientPairPrecise> encode_hist(2 * num_categories);
@@ -152,20 +148,9 @@ void TestGPUHistogramCategorical(size_t num_categories) {
 
   std::vector<GradientPairPrecise> h_encode_hist(encode_hist.size());
   thrust::copy(encode_hist.begin(), encode_hist.end(), h_encode_hist.begin());
-
-  for (size_t c = 0; c < num_categories; ++c) {
-    auto zero = h_encode_hist[c * 2];
-    auto one = h_encode_hist[c * 2 + 1];
-
-    auto chosen = h_cat_hist[c];
-    auto not_chosen = cat_sum - chosen;
-
-    ASSERT_LE(RelError(zero.GetGrad(), not_chosen.GetGrad()), kRtEps);
-    ASSERT_LE(RelError(zero.GetHess(), not_chosen.GetHess()), kRtEps);
-
-    ASSERT_LE(RelError(one.GetGrad(), chosen.GetGrad()), kRtEps);
-    ASSERT_LE(RelError(one.GetHess(), chosen.GetHess()), kRtEps);
-  }
+  ValidateCategoricalHistogram(num_categories,
+                               common::Span<GradientPairPrecise>{h_encode_hist},
+                               common::Span<GradientPairPrecise>{h_cat_hist});
 }
 
 TEST(Histogram, GPUHistCategorical) {
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index cb0171269305..115dcb0297dd 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -7,7 +7,6 @@
 namespace xgboost {
 namespace tree {
-
 template <typename GradientSumT> void TestEvaluateSplits() {
   int static constexpr kRows = 8, kCols = 16;
   auto orig = omp_get_max_threads();
@@ -16,14 +15,12 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   auto sampler = std::make_shared<common::ColumnSampler>();
 
   TrainParam param;
-  param.UpdateAllowUnknown(Args{{}});
-  param.min_child_weight = 0;
-  param.reg_lambda = 0;
+  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
 
   auto dmat = RandomDataGenerator(kRows, kCols, 0).Seed(3).GenerateDMatrix();
 
-  auto evaluator =
-      HistEvaluator<GradientSumT, CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
   common::HistCollection<GradientSumT> hist;
   std::vector<GradientPair> row_gpairs = {
       {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
@@ -39,7 +36,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   std::iota(row_indices.begin(), row_indices.end(), 0);
   row_set_collection.Init();
 
-  auto hist_builder = GHistBuilder<GradientSumT>(n_threads, gmat.cut.Ptrs().back());
+  auto hist_builder = GHistBuilder<GradientSumT>(omp_get_max_threads(), gmat.cut.Ptrs().back());
   hist.Init(gmat.cut.Ptrs().back());
   hist.AddHistRow(0);
   hist.AllocateAllData();
@@ -58,7 +55,7 @@ template <typename GradientSumT> void TestEvaluateSplits() {
   entries.front().depth = 0;
 
   evaluator.InitRoot(GradStats{total_gpair});
-  evaluator.EvaluateSplits(hist, gmat.cut, tree, &entries);
+  evaluator.EvaluateSplits(hist, gmat.cut, {}, tree, &entries);
 
   auto best_loss_chg =
       evaluator.Evaluator().CalcSplitGain(
@@ -96,8 +93,8 @@ TEST(HistEvaluator, Apply) {
   param.UpdateAllowUnknown(Args{{}});
   auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
   auto sampler = std::make_shared<common::ColumnSampler>();
-  auto evaluator_ =
-      HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler};
+  auto evaluator_ = HistEvaluator<float, CPUExpandEntry>{param, dmat->Info(), 4, sampler,
+                                                         ObjInfo{ObjInfo::kRegression}};
 
   CPUExpandEntry entry{0, 0, 10.0f};
   entry.split.left_sum = GradStats{0.4, 0.6f};
@@ -108,5 +105,142 @@ TEST(HistEvaluator, Apply) {
   ASSERT_EQ(tree.Stat(tree[0].LeftChild()).sum_hess, 0.6f);
   ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f);
 }
+
+TEST(HistEvaluator, CategoricalPartition) {
+  int static constexpr kRows = 128, kCols = 1;
+  using GradientSumT = double;
+  std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
+
+  TrainParam param;
+  param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
+
+  size_t n_cats{8};
+
+  auto dmat =
+      RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
+
+  int32_t n_threads = 16;
+  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
+
+  for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 32})) {
+    common::HistCollection<GradientSumT> hist;
+
+    std::vector<CPUExpandEntry> entries(1);
+    entries.front().nid = 0;
+    entries.front().depth = 0;
+
+    hist.Init(gmat.cut.TotalBins());
+    hist.AddHistRow(0);
+    hist.AllocateAllData();
+    auto node_hist = hist[0];
+    ASSERT_EQ(node_hist.size(), n_cats);
+    ASSERT_EQ(node_hist.size(), gmat.cut.Ptrs().back());
+
+    GradientPairPrecise total_gpair;
+    for (size_t i = 0; i < node_hist.size(); ++i) {
+      node_hist[i] = {static_cast<double>(node_hist.size() - i), 1.0};
+      total_gpair += node_hist[i];
+    }
+    SimpleLCG lcg;
+    std::shuffle(node_hist.begin(), node_hist.end(), lcg);
+
+    RegTree tree;
+    evaluator.InitRoot(GradStats{total_gpair});
+    evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries);
+    ASSERT_TRUE(entries.front().split.is_cat);
+
+    auto run_eval = [&](auto fn) {
+      for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) {
+        GradStats left, right;
+        for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) {
+          auto loss_chg = evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) -
+                          evaluator.Stats().front().root_gain;
+          fn(loss_chg);
+          left.Add(node_hist[j].GetGrad(), node_hist[j].GetHess());
+          right.SetSubstract(GradStats{total_gpair}, left);
+        }
+      }
+    };
+    // Assert that it's the best split
+    auto best_loss_chg = entries.front().split.loss_chg;
+    run_eval([&](auto loss_chg) {
+      // Approximate test that the gain returned by the optimal partition is greater
+      // than any numerical split.
+      ASSERT_GT(best_loss_chg, loss_chg);
+    });
+
+    // node_hist is captured by reference in run_eval, so sorting it re-creates
+    // the partition-based enumeration order for the re-implementation below.
+    std::sort(node_hist.begin(), node_hist.end(), [&](auto l, auto r) {
+      return evaluator.Evaluator().CalcWeightCat(param, l) <
+             evaluator.Evaluator().CalcWeightCat(param, r);
+    });
+
+    double reimpl = 0;
+    run_eval([&](auto loss_chg) { reimpl = std::max(loss_chg, reimpl); });
+    CHECK_EQ(reimpl, best_loss_chg);
+  }
+}
+
+namespace {
+auto CompareOneHotAndPartition(bool onehot) {
+  int static constexpr kRows = 128, kCols = 1;
+  using GradientSumT = double;
+  std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
+
+  TrainParam param;
+  if (onehot) {
+    // force use one-hot
+    param.UpdateAllowUnknown(
+        Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "100"}});
+  } else {
+    param.UpdateAllowUnknown(
+        Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}, {"max_cat_to_onehot", "1"}});
+  }
+
+  size_t n_cats{2};
+
+  auto dmat =
+      RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();
+
+  int32_t n_threads = 16;
+  auto sampler = std::make_shared<common::ColumnSampler>();
+  auto evaluator = HistEvaluator<GradientSumT, CPUExpandEntry>{
+      param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}};
+  std::vector<CPUExpandEntry> entries(1);
+
+  for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 32})) {
+    common::HistCollection<GradientSumT> hist;
+
+    entries.front().nid = 0;
+    entries.front().depth = 0;
+
+    hist.Init(gmat.cut.TotalBins());
+    hist.AddHistRow(0);
+    hist.AllocateAllData();
+    auto node_hist = hist[0];
+
+    CHECK_EQ(node_hist.size(), n_cats);
+    CHECK_EQ(node_hist.size(), gmat.cut.Ptrs().back());
+
+    GradientPairPrecise total_gpair;
+    for (size_t i = 0; i < node_hist.size(); ++i) {
+      node_hist[i] = {static_cast<double>(node_hist.size() - i), 1.0};
+      total_gpair += node_hist[i];
+    }
+    RegTree tree;
+    evaluator.InitRoot(GradStats{total_gpair});
+    evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries);
+  }
+  return entries.front();
+}
+}  // anonymous namespace
+
+TEST(HistEvaluator, Categorical) {
+  auto with_onehot = CompareOneHotAndPartition(true);
+  auto with_part = CompareOneHotAndPartition(false);
+
+  ASSERT_EQ(with_onehot.split.loss_chg, with_part.split.loss_chg);
+}
 }  // namespace tree
 }  // namespace xgboost
diff --git a/tests/cpp/tree/test_param.cc b/tests/cpp/tree/test_param.cc
index b4cc4005e3ad..d4194bb74c58 100644
--- a/tests/cpp/tree/test_param.cc
+++ b/tests/cpp/tree/test_param.cc
@@ -88,14 +88,14 @@ TEST(Param, SplitEntry) {
   xgboost::tree::SplitEntry se2;
   EXPECT_FALSE(se1.Update(se2));
 
-  EXPECT_FALSE(se2.Update(-1, 100, 0, true, xgboost::tree::GradStats(),
+  EXPECT_FALSE(se2.Update(-1, 100, 0, true, false, xgboost::tree::GradStats(),
                           xgboost::tree::GradStats()));
-  ASSERT_TRUE(se2.Update(1, 100, 0, true, xgboost::tree::GradStats(),
+  ASSERT_TRUE(se2.Update(1, 100, 0, true, false, xgboost::tree::GradStats(),
                          xgboost::tree::GradStats()));
   ASSERT_TRUE(se1.Update(se2));
 
   xgboost::tree::SplitEntry se3;
-  se3.Update(2, 101, 0, false, xgboost::tree::GradStats(),
+  se3.Update(2, 101, 0, false, false, xgboost::tree::GradStats(),
              xgboost::tree::GradStats());
   xgboost::tree::SplitEntry::Reduce(se2, se3);
   EXPECT_EQ(se2.SplitIndex(), 101);