Skip to content

Commit

Permalink
More categorical tests and disable shap sparse test. (#6219)
Browse files Browse the repository at this point in the history
* Fix tree load with 32 category.
  • Loading branch information
trivialfis committed Oct 10, 2020
1 parent c991eb6 commit b5b2435
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 28 deletions.
12 changes: 9 additions & 3 deletions src/common/json.cc
Expand Up @@ -186,7 +186,9 @@ Json& JsonObject::operator[](int ind) {
}

bool JsonObject::operator==(Value const& rhs) const {
if (!IsA<JsonObject>(&rhs)) { return false; }
if (!IsA<JsonObject>(&rhs)) {
return false;
}
return object_ == Cast<JsonObject const>(&rhs)->GetObject();
}

Expand Down Expand Up @@ -275,10 +277,14 @@ Json& JsonNumber::operator[](int ind) {

bool JsonNumber::operator==(Value const& rhs) const {
if (!IsA<JsonNumber>(&rhs)) { return false; }
auto r_num = Cast<JsonNumber const>(&rhs)->GetNumber();
if (std::isinf(number_)) {
return std::isinf(Cast<JsonNumber const>(&rhs)->GetNumber());
return std::isinf(r_num);
}
if (std::isnan(number_)) {
return std::isnan(r_num);
}
return std::abs(number_ - Cast<JsonNumber const>(&rhs)->GetNumber()) < kRtEps;
return number_ - r_num == 0;
}

Value & JsonNumber::operator=(Value const &rhs) {
Expand Down
9 changes: 5 additions & 4 deletions src/tree/tree_model.cc
Expand Up @@ -792,16 +792,17 @@ void RegTree::LoadCategoricalSplit(Json const& in) {
auto j_begin = get<Integer const>(categories_segments[cnt]);
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
CHECK_NE(j_end - j_begin, 0) << nidx;

for (auto j = j_begin; j < j_end; ++j) {
auto const &category = get<Integer const>(categories[j]);
auto cat = common::AsCat(category);
max_cat = std::max(max_cat, cat);
}
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
size = size == 0 ? 1 : size;
// Have at least 1 category in split.
CHECK_NE(std::numeric_limits<bst_cat_t>::min(), max_cat);
size_t n_cats = max_cat + 1; // cat 0
size_t size = common::KCatBitField::ComputeStorageSize(n_cats);
std::vector<uint32_t> cat_bits_storage(size, 0);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto j = j_begin; j < j_end; ++j) {
Expand Down
72 changes: 72 additions & 0 deletions tests/cpp/tree/test_tree_model.cc
Expand Up @@ -5,6 +5,7 @@
#include "xgboost/json_io.h"
#include "xgboost/tree_model.h"
#include "../../../src/common/bitfield.h"
#include "../../../src/common/categorical.h"

namespace xgboost {
#if DMLC_IO_NO_ENDIAN_SWAP // skip on big-endian machines
Expand Down Expand Up @@ -150,6 +151,77 @@ TEST(Tree, ExpandCategoricalFeature) {
}
}

void GrowTree(RegTree* p_tree) {
SimpleLCG lcg;
size_t n_expands = 10;
constexpr size_t kCols = 256;
SimpleRealUniformDistribution<double> coin(0.0, 1.0);
SimpleRealUniformDistribution<double> feat(0.0, kCols);
SimpleRealUniformDistribution<double> split_cat(0.0, 128.0);
SimpleRealUniformDistribution<double> split_value(0.0, kCols);

std::stack<bst_node_t> stack;
stack.push(RegTree::kRoot);
auto& tree = *p_tree;

for (size_t i = 0; i < n_expands; ++i) {
auto is_cat = coin(&lcg) <= 0.5;
bst_node_t node = stack.top();
stack.pop();

bst_feature_t f = feat(&lcg);
if (is_cat) {
bst_cat_t cat = common::AsCat(split_cat(&lcg));
std::vector<uint32_t> split_cats(
LBitField32::ComputeStorageSize(cat + 1));
LBitField32 bitset{split_cats};
bitset.Set(cat);
tree.ExpandCategorical(node, f, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
/*left_sum=*/3.0, /*right_sum=*/4.0);
} else {
auto split = split_value(&lcg);
tree.ExpandNode(node, f, split, true, 1.0, 2.0, 3.0, 11.0, 2.0,
/*left_sum=*/3.0, /*right_sum=*/4.0);
}

stack.push(tree[node].LeftChild());
stack.push(tree[node].RightChild());
}
}

void CheckReload(RegTree const &tree) {
Json out{Object()};
tree.SaveModel(&out);

RegTree loaded_tree;
loaded_tree.LoadModel(out);
Json saved{Object()};
loaded_tree.SaveModel(&saved);

auto same = out == saved;
ASSERT_TRUE(same);
}

TEST(Tree, CategoricalIO) {
{
RegTree tree;
bst_cat_t cat = 32;
std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(cat + 1));
LBitField32 bitset{split_cats};
bitset.Set(cat);
tree.ExpandCategorical(0, 0, split_cats, true, 1.0, 2.0, 3.0, 11.0, 2.0,
/*left_sum=*/3.0, /*right_sum=*/4.0);

CheckReload(tree);
}

{
RegTree tree;
GrowTree(&tree);
CheckReload(tree);
}
}

namespace {
RegTree ConstructTree() {
RegTree tree;
Expand Down
7 changes: 6 additions & 1 deletion tests/python-gpu/test_gpu_prediction.py
Expand Up @@ -212,6 +212,10 @@ def test_shap(self, num_rounds, dataset, param):
tm.dataset_strategy, shap_parameter_strategy)
@settings(deadline=None, max_examples=20)
def test_shap_interactions(self, num_rounds, dataset, param):
if dataset.name == 'sparse':
issue = 'https://github.com/dmlc/xgboost/issues/6074'
pytest.xfail(reason=f'GPU shap with sparse is flaky: {issue}')

param.update({"predictor": "gpu_predictor", "gpu_id": 0})
param = dataset.set_params(param)
dmat = dataset.get_dmat()
Expand All @@ -220,5 +224,6 @@ def test_shap_interactions(self, num_rounds, dataset, param):
shap = bst.predict(test_dmat, pred_interactions=True)
margin = bst.predict(test_dmat, output_margin=True)
assume(len(dataset.y) > 0)
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)), margin,
assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
margin,
1e-3, 1e-3)
49 changes: 29 additions & 20 deletions tests/python-gpu/test_gpu_updaters.py
Expand Up @@ -41,7 +41,24 @@ def test_gpu_hist(self, param, num_rounds, dataset):
note(result)
assert tm.non_increasing(result['train'][dataset.metric])

def run_categorical_basic(self, cat, onehot, label, rounds):
def run_categorical_basic(self, rows, cols, rounds, cats):
import pandas as pd
rng = np.random.RandomState(1994)

pd_dict = {}
for i in range(cols):
c = rng.randint(low=0, high=cats+1, size=rows)
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)

df = pd.DataFrame(pd_dict)
label = df.iloc[:, 0]
for i in range(0, cols-1):
label += df.iloc[:, i]
label += 1
df = df.astype('category')
onehot = pd.get_dummies(df)
cat = df

by_etl_results = {}
by_builtin_results = {}

Expand All @@ -64,28 +81,20 @@ def run_categorical_basic(self, cat, onehot, label, rounds):
rtol=1e-3)
assert tm.non_increasing(by_builtin_results['Train']['rmse'])

@given(strategies.integers(10, 400), strategies.integers(5, 10),
strategies.integers(1, 5), strategies.integers(4, 8))
@given(strategies.integers(10, 400), strategies.integers(3, 8),
strategies.integers(1, 5), strategies.integers(4, 7))
@settings(deadline=None)
@pytest.mark.skipif(**tm.no_pandas())
def test_categorical(self, rows, cols, rounds, cats):
import pandas as pd
rng = np.random.RandomState(1994)

pd_dict = {}
for i in range(cols):
c = rng.randint(low=0, high=cats+1, size=rows)
pd_dict[str(i)] = pd.Series(c, dtype=np.int64)

df = pd.DataFrame(pd_dict)
label = df.iloc[:, 0]
for i in range(0, cols-1):
label += df.iloc[:, i]
label += 1
df = df.astype('category')
x = pd.get_dummies(df)

self.run_categorical_basic(df, x, label, rounds)
self.run_categorical_basic(rows, cols, rounds, cats)

def test_categorical_32_cat(self):
'''32 hits the bound of integer bitset, so special test'''
rows = 1000
cols = 10
cats = 32
rounds = 4
self.run_categorical_basic(rows, cols, rounds, cats)

@pytest.mark.skipif(**tm.no_cupy())
@given(parameter_strategy, strategies.integers(1, 20),
Expand Down

0 comments on commit b5b2435

Please sign in to comment.