Skip to content

Commit

Permalink
Handle missing categorical value in CPU evaluator. (#7948)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 27, 2022
1 parent 2070afe commit bde4f25
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 154 deletions.
194 changes: 102 additions & 92 deletions src/tree/hist/evaluate_splits.h
Expand Up @@ -119,13 +119,90 @@ class HistEvaluator {
p_best->Update(best);
}

/**
* \brief Enumerate with partition-based splits.
*
* The implementation is different from LightGBM. Firstly we don't have a
* pseudo-cateogry for missing value, instead of we make 2 complete scans over the
* histogram. Secondly, both scan directions generate splits in the same
* order. Following table depicts the scan process, square bracket means the gradient in
* missing values is resided on that partition:
*
* | Forward | Backward |
* |----------+----------|
* | [BCDE] A | E [ABCD] |
* | [CDE] AB | DE [ABC] |
* | [DE] ABC | CDE [AB] |
* | [E] ABCD | BCDE [A] |
*/
template <int d_step>
void EnumeratePart(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
common::GHistRow const &hist, bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");

auto const &cut_ptr = cut.Ptrs();
auto const &parent = snode_[nidx];
bst_bin_t n_bins{static_cast<bst_bin_t>(cut_ptr[fidx + 1] - cut_ptr[fidx])};

// statistics on both sides of split
GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;

auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
bst_bin_t ibegin, iend;
bst_bin_t f_begin = cut_ptr[fidx];
if (d_step > 0) {
ibegin = f_begin;
iend = ibegin + n_bins - 1;
} else {
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = f_begin;
}

bst_bin_t best_thresh{-1};
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
auto j = i - f_begin; // index local to current feature
if (d_step == 1) {
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
left_sum.SetSubstract(parent.stats, right_sum); // missing on left
} else {
left_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
right_sum.SetSubstract(parent.stats, left_sum); // missing on right
}
if (IsValid(left_sum, right_sum)) {
auto loss_chg =
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain;
// We don't have a numeric split point, nan hare is a dummy split.
if (best.Update(loss_chg, fidx, std::numeric_limits<float>::quiet_NaN(), d_step == 1, true,
left_sum, right_sum)) {
best_thresh = i;
}
}
}

if (best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins + 1);
best.cat_bits = decltype(best.cat_bits)(n, 0);
common::CatBitField cat_bits{best.cat_bits};
bst_bin_t partition = d_step == 1 ? (best_thresh - ibegin + 1) : best_thresh - iend;
std::for_each(sorted_idx.begin(), sorted_idx.begin() + partition,
[&](size_t c) { cat_bits.Set(c); });
}

p_best->Update(best);
}

// Enumerate/Scan the split values of specific feature
// Returns the sum of gradients corresponding to the data points that contains
// a non-missing value for the particular feature fid.
template <int d_step, SplitType split_type>
template <int d_step>
GradStats EnumerateSplit(common::HistogramCuts const &cut, common::Span<size_t const> sorted_idx,
const common::GHistRow &hist, bst_feature_t fidx,
bst_node_t nidx,
const common::GHistRow &hist, bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
static_assert(d_step == +1 || d_step == -1, "Invalid step.");
Expand All @@ -134,8 +211,6 @@ class HistEvaluator {
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();
auto const &parent = snode_[nidx];
int32_t n_bins{static_cast<int32_t>(cut_ptr.at(fidx + 1) - cut_ptr[fidx])};
auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);

// statistics on both sides of split
GradStats left_sum;
Expand All @@ -144,50 +219,28 @@ class HistEvaluator {
SplitEntry best;

// bin boundaries
CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
CHECK_LE(cut_ptr[fidx], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
CHECK_LE(cut_ptr[fidx + 1], static_cast<uint32_t>(std::numeric_limits<bst_bin_t>::max()));
// imin: index (offset) of the minimum value for feature fid need this for backward
// enumeration
const auto imin = static_cast<int32_t>(cut_ptr[fidx]);
const auto imin = static_cast<bst_bin_t>(cut_ptr[fidx]);
// ibegin, iend: smallest/largest cut points for feature fid use int to allow for
// value -1
int32_t ibegin, iend;
bst_bin_t ibegin, iend;
if (d_step > 0) {
ibegin = static_cast<int32_t>(cut_ptr[fidx]);
iend = static_cast<int32_t>(cut_ptr.at(fidx + 1));
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
iend = static_cast<bst_bin_t>(cut_ptr.at(fidx + 1));
} else {
ibegin = static_cast<int32_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<int32_t>(cut_ptr[fidx]) - 1;
ibegin = static_cast<bst_bin_t>(cut_ptr[fidx + 1]) - 1;
iend = static_cast<bst_bin_t>(cut_ptr[fidx]) - 1;
}

auto calc_bin_value = [&](auto i) {
switch (split_type) {
case kNum: {
left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
right_sum.SetSubstract(parent.stats, left_sum);
break;
}
case kOneHot: {
std::terminate(); // unreachable
break;
}
case kPart: {
auto j = d_step == 1 ? (i - ibegin) : (ibegin - i);
right_sum.Add(f_hist[sorted_idx[j]].GetGrad(), f_hist[sorted_idx[j]].GetHess());
left_sum.SetSubstract(parent.stats, right_sum);
break;
}
}
};

int32_t best_thresh{-1};
for (int32_t i = ibegin; i != iend; i += d_step) {
for (bst_bin_t i = ibegin; i != iend; i += d_step) {
// start working
// try to find a split
calc_bin_value(i);
bool improved{false};
if (left_sum.GetHess() >= param_.min_child_weight &&
right_sum.GetHess() >= param_.min_child_weight) {
left_sum.Add(hist[i].GetGrad(), hist[i].GetHess());
right_sum.SetSubstract(parent.stats, left_sum);
if (IsValid(left_sum, right_sum)) {
bst_float loss_chg;
bst_float split_pt;
if (d_step > 0) {
Expand All @@ -197,66 +250,24 @@ class HistEvaluator {
GradStats{right_sum}) -
parent.root_gain);
split_pt = cut_val[i]; // not used for partition based
improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
left_sum, right_sum);
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, left_sum, right_sum);
} else {
// backward enumeration: split at left bound of each bin
loss_chg =
static_cast<float>(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{right_sum},
GradStats{left_sum}) -
parent.root_gain);
switch (split_type) {
case kNum: {
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
break;
}
case kOneHot: {
std::terminate(); // unreachable
break;
}
case kPart: {
split_pt = cut_val[i];
break;
}
if (i == imin) {
split_pt = cut.MinValues()[fidx];
} else {
split_pt = cut_val[i - 1];
}
improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum,
right_sum, left_sum);
}
if (improved) {
best_thresh = i;
best.Update(loss_chg, fidx, split_pt, d_step == -1, false, right_sum, left_sum);
}
}
}

if (split_type == kPart && best_thresh != -1) {
auto n = common::CatBitField::ComputeStorageSize(n_bins);
best.cat_bits.resize(n, 0);
common::CatBitField cat_bits{best.cat_bits};

if (d_step == 1) {
std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1),
[&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); });
} else {
std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh),
[&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); });
}
}
p_best->Update(best);

switch (split_type) {
case kNum:
// Normal, accumulated to left
return left_sum;
case kOneHot:
return {};
case kPart:
// Accumulated to right due to chosen cats go to right.
return right_sum;
}
return left_sum;
}

Expand Down Expand Up @@ -316,14 +327,13 @@ class HistEvaluator {
evaluator.CalcWeightCat(param_, feat_hist[r]);
return ret;
});
EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<+1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
EnumeratePart<-1>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best);
}
} else {
auto grad_stats =
EnumerateSplit<+1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
auto grad_stats = EnumerateSplit<+1>(cut, {}, histogram, fidx, nidx, evaluator, best);
if (SplitContainsMissingValues(grad_stats, snode_[nidx])) {
EnumerateSplit<-1, kNum>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1>(cut, {}, histogram, fidx, nidx, evaluator, best);
}
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/tree/hist/expand_entry.h
Expand Up @@ -50,12 +50,11 @@ struct CPUExpandEntry {
}

friend std::ostream& operator<<(std::ostream& os, const CPUExpandEntry& e) {
os << "ExpandEntry: \n";
os << "ExpandEntry:\n";
os << "nidx: " << e.nid << "\n";
os << "depth: " << e.depth << "\n";
os << "loss: " << e.split.loss_chg << "\n";
os << "left_sum: " << e.split.left_sum << "\n";
os << "right_sum: " << e.split.right_sum << "\n";
os << "split:\n" << e.split << std::endl;
return os;
}
};
Expand Down
38 changes: 8 additions & 30 deletions src/tree/param.h
Expand Up @@ -367,12 +367,14 @@ struct SplitEntryContainer {

SplitEntryContainer() = default;

friend std::ostream& operator<<(std::ostream& os, SplitEntryContainer const& s) {
os << "loss_chg: " << s.loss_chg << ", "
<< "split index: " << s.SplitIndex() << ", "
<< "split value: " << s.split_value << ", "
<< "left_sum: " << s.left_sum << ", "
<< "right_sum: " << s.right_sum;
friend std::ostream &operator<<(std::ostream &os, SplitEntryContainer const &s) {
os << "loss_chg: " << s.loss_chg << "\n"
<< "dft_left: " << s.DefaultLeft() << "\n"
<< "split_index: " << s.SplitIndex() << "\n"
<< "split_value: " << s.split_value << "\n"
<< "is_cat: " << s.is_cat << "\n"
<< "left_sum: " << s.left_sum << "\n"
<< "right_sum: " << s.right_sum << std::endl;
return os;
}
/*!\return feature index to split on */
Expand Down Expand Up @@ -446,30 +448,6 @@ struct SplitEntryContainer {
}
}

/*!
* \brief Update with partition based categorical split.
*
* \return Whether the proposed split is better and can replace current split.
*/
bool Update(float new_loss_chg, bst_feature_t split_index, common::KCatBitField cats,
bool default_left, GradientT const &left_sum, GradientT const &right_sum) {
if (this->NeedReplace(new_loss_chg, split_index)) {
this->loss_chg = new_loss_chg;
if (default_left) {
split_index |= (1U << 31);
}
this->sindex = split_index;
cat_bits.resize(cats.Bits().size());
std::copy(cats.Bits().begin(), cats.Bits().end(), cat_bits.begin());
this->is_cat = true;
this->left_sum = left_sum;
this->right_sum = right_sum;
return true;
} else {
return false;
}
}

/*! \brief same as update, used by AllReduce*/
inline static void Reduce(SplitEntryContainer &dst, // NOLINT(*)
const SplitEntryContainer &src) { // NOLINT(*)
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/tree/hist/test_evaluate_splits.cc
Expand Up @@ -147,9 +147,9 @@ auto CompareOneHotAndPartition(bool onehot) {
auto dmat =
RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix();

int32_t n_threads = 16;
auto sampler = std::make_shared<common::ColumnSampler>();
auto evaluator = HistEvaluator<CPUExpandEntry>{param, dmat->Info(), n_threads, sampler};
auto evaluator =
HistEvaluator<CPUExpandEntry>{param, dmat->Info(), common::OmpGetNumThreads(0), sampler};
std::vector<CPUExpandEntry> entries(1);

for (auto const &gmat : dmat->GetBatches<GHistIndexMatrix>({32, param.sparse_threshold})) {
Expand Down
5 changes: 4 additions & 1 deletion tests/cpp/tree/test_evaluate_splits.h
Expand Up @@ -2,11 +2,14 @@
* Copyright 2022 by XGBoost Contributors
*/
#include <gtest/gtest.h>
#include <xgboost/data.h>

#include <algorithm> // next_permutation
#include <numeric> // iota

#include "../../../src/tree/hist/evaluate_splits.h"
#include "../../../src/common/hist_util.h" // HistogramCuts,HistCollection
#include "../../../src/tree/param.h" // TrainParam
#include "../../../src/tree/split_evaluator.h"
#include "../helpers.h"

namespace xgboost {
Expand Down
10 changes: 10 additions & 0 deletions tests/python-gpu/test_gpu_updaters.py
Expand Up @@ -77,6 +77,16 @@ def test_sparse(self, dataset):
def test_categorical(self, rows, cols, rounds, cats):
self.cputest.run_categorical_basic(rows, cols, rounds, cats, "gpu_hist")

@given(
strategies.integers(10, 400),
strategies.integers(3, 8),
strategies.integers(4, 7)
)
@settings(deadline=None, print_blob=True)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical_missing(self, rows, cols, cats):
self.cputest.run_categorical_missing(rows, cols, cats, "gpu_hist")

def test_max_cat(self) -> None:
self.cputest.run_max_cat("gpu_hist")

Expand Down

0 comments on commit bde4f25

Please sign in to comment.