Skip to content

Commit

Permalink
Handle missing values in one hot splits. (#7934)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 24, 2022
1 parent 18a38f7 commit 606be9e
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 14 deletions.
72 changes: 63 additions & 9 deletions src/tree/hist/evaluate_splits.h
Expand Up @@ -45,14 +45,72 @@ class HistEvaluator {
// then - there are no missing values
// else - there are missing values
bool static SplitContainsMissingValues(const GradStats e, const NodeEntry &snode) {
if (e.GetGrad() == snode.stats.GetGrad() &&
e.GetHess() == snode.stats.GetHess()) {
if (e.GetGrad() == snode.stats.GetGrad() && e.GetHess() == snode.stats.GetHess()) {
return false;
} else {
return true;
}
}

bool IsValid(GradStats const &left, GradStats const &right) const {
return left.GetHess() >= param_.min_child_weight && right.GetHess() >= param_.min_child_weight;
}

/**
* \brief Use learned direction with one-hot split. Other implementations (LGB, sklearn)
* create a pseudo-category for missing value but here we just do a complete scan
* to avoid making specialized histogram bin.
*/
void EnumerateOneHot(common::HistogramCuts const &cut, const common::GHistRow &hist,
bst_feature_t fidx, bst_node_t nidx,
TreeEvaluator::SplitEvaluator<TrainParam> const &evaluator,
SplitEntry *p_best) const {
const std::vector<uint32_t> &cut_ptr = cut.Ptrs();
const std::vector<bst_float> &cut_val = cut.Values();

bst_bin_t ibegin = static_cast<bst_bin_t>(cut_ptr[fidx]);
bst_bin_t iend = static_cast<bst_bin_t>(cut_ptr[fidx + 1]);
bst_bin_t n_bins = iend - ibegin;

GradStats left_sum;
GradStats right_sum;
// best split so far
SplitEntry best;

auto f_hist = hist.subspan(cut_ptr[fidx], n_bins);
auto feature_sum = GradStats{
std::accumulate(f_hist.data(), f_hist.data() + f_hist.size(), GradientPairPrecise{})};
GradStats missing;
auto const &parent = snode_[nidx];
missing.SetSubstract(parent.stats, feature_sum);

for (bst_bin_t i = ibegin; i != iend; i += 1) {
auto split_pt = cut_val[i];

// missing on left (treat missing as other categories)
right_sum = GradStats{hist[i]};
left_sum.SetSubstract(parent.stats, right_sum);
if (IsValid(left_sum, right_sum)) {
auto missing_left_chg = static_cast<float>(
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain);
best.Update(missing_left_chg, fidx, split_pt, true, true, left_sum, right_sum);
}

// missing on right (treat missing as chosen category)
left_sum.SetSubstract(left_sum, missing);
right_sum.Add(missing);
if (IsValid(left_sum, right_sum)) {
auto missing_right_chg = static_cast<float>(
evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) -
parent.root_gain);
best.Update(missing_right_chg, fidx, split_pt, false, true, left_sum, right_sum);
}
}

p_best->Update(best);
}

// Enumerate/Scan the split values of specific feature
// Returns the sum of gradients corresponding to the data points that contains
// a non-missing value for the particular feature fid.
Expand Down Expand Up @@ -102,9 +160,7 @@ class HistEvaluator {
break;
}
case kOneHot: {
// not-chosen categories go to left
right_sum = GradStats{hist[i]};
left_sum.SetSubstract(parent.stats, right_sum);
std::terminate(); // unreachable
break;
}
case kPart: {
Expand Down Expand Up @@ -151,7 +207,7 @@ class HistEvaluator {
break;
}
case kOneHot: {
split_pt = cut_val[i];
std::terminate(); // unreachable
break;
}
case kPart: {
Expand Down Expand Up @@ -188,7 +244,6 @@ class HistEvaluator {
// Normal, accumulated to left
return left_sum;
case kOneHot:
// Doesn't matter, not accumulating.
return {};
case kPart:
// Accumulated to right due to chosen cats go to right.
Expand Down Expand Up @@ -242,8 +297,7 @@ class HistEvaluator {
if (is_cat) {
auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx];
if (common::UseOneHot(n_bins, param_.max_cat_to_onehot)) {
EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best);
EnumerateOneHot(cut, histogram, fidx, nidx, evaluator, best);
} else {
std::vector<size_t> sorted_idx(n_bins);
std::iota(sorted_idx.begin(), sorted_idx.end(), 0);
Expand Down
32 changes: 28 additions & 4 deletions tests/python/test_updaters.py
Expand Up @@ -214,17 +214,19 @@ def test_max_cat(self, tree_method) -> None:
self.run_max_cat(tree_method)

def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
USE_ONEHOT = np.iinfo(np.int32).max
USE_PART = 1

onehot, label = tm.make_categorical(rows, cols, cats, True)
cat, _ = tm.make_categorical(rows, cols, cats, False)

by_etl_results = {}
by_builtin_results = {}

predictor = "gpu_predictor" if tree_method == "gpu_hist" else None
parameters = {"tree_method": tree_method, "predictor": predictor}
# Use one-hot exclusively
parameters = {
"tree_method": tree_method, "predictor": predictor, "max_cat_to_onehot": 9999
}
parameters["max_cat_to_onehot"] = USE_ONEHOT

m = xgb.DMatrix(onehot, label, enable_categorical=False)
xgb.train(
Expand Down Expand Up @@ -257,7 +259,8 @@ def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
assert tm.non_increasing(by_builtin_results["Train"]["rmse"])

by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
parameters["max_cat_to_onehot"] = 1
# switch to partition-based splits
parameters["max_cat_to_onehot"] = USE_PART
parameters["reg_lambda"] = 0
m = xgb.DMatrix(cat, label, enable_categorical=True)
xgb.train(
Expand All @@ -284,6 +287,27 @@ def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
)
assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping

# test with missing values
cat, label = tm.make_categorical(
n_samples=256, n_features=4, n_categories=8, onehot=False, sparsity=0.5
)
Xy = xgb.DMatrix(cat, label, enable_categorical=True)
evals_result = {}
# Test with onehot splits
parameters["max_cat_to_onehot"] = USE_ONEHOT
booster = xgb.train(
parameters,
Xy,
num_boost_round=16,
evals=[(Xy, "Train")],
evals_result=evals_result
)
assert tm.non_increasing(evals_result["Train"]["rmse"])
y_predt = booster.predict(Xy)

rmse = tm.root_mean_square(label, y_predt)
np.testing.assert_allclose(rmse, evals_result["Train"]["rmse"][-1])

@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 2), strategies.integers(4, 7))
@settings(deadline=None, print_blob=True)
Expand Down
15 changes: 14 additions & 1 deletion tests/python/testing.py
Expand Up @@ -302,7 +302,7 @@ def get_mq2008(dpath):

@memory.cache
def make_categorical(
n_samples: int, n_features: int, n_categories: int, onehot: bool
n_samples: int, n_features: int, n_categories: int, onehot: bool, sparsity=0.0,
):
import pandas as pd

Expand All @@ -325,6 +325,13 @@ def make_categorical(
for col in df.columns:
df[col] = df[col].cat.set_categories(categories)

if sparsity > 0.0:
for i in range(n_features):
index = rng.randint(low=0, high=n_samples-1, size=int(n_samples * sparsity))
df.iloc[index, i] = np.NaN
assert df.iloc[:, i].isnull().values.any()
assert n_categories == np.unique(df.dtypes[i].categories).size

if onehot:
return pd.get_dummies(df), label
return df, label
Expand Down Expand Up @@ -538,6 +545,12 @@ def eval_error_metric_skl(y_true: np.ndarray, y_score: np.ndarray) -> float:
return np.sum(r)


def root_mean_square(y_true: np.ndarray, y_score: np.ndarray) -> float:
err = y_score - y_true
rmse = np.sqrt(np.dot(err, err) / y_score.size)
return rmse


def softmax(x):
e = np.exp(x)
return e / np.sum(e)
Expand Down

0 comments on commit 606be9e

Please sign in to comment.