Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve JSON format for categorical features. #6128

Merged
merged 5 commits into from Sep 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 2 additions & 0 deletions include/xgboost/tree_model.h
Expand Up @@ -620,6 +620,8 @@ class RegTree : public Model {
};

private:
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
// vector of nodes
std::vector<Node> nodes_;
// free node space, used during training process
Expand Down
162 changes: 111 additions & 51 deletions src/tree/tree_model.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2019 by Contributors
* Copyright 2015-2020 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
Expand Down Expand Up @@ -740,7 +740,11 @@ void RegTree::Load(dmlc::Stream* fi) {
}
}
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);

split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
}

void RegTree::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
Expand Down Expand Up @@ -772,6 +776,92 @@ void RegTree::Save(dmlc::Stream* fo) const {
}
}

/*!
 * \brief Load categorical split information from the JSON model `in`.
 *
 * Reads the arrays "categories_nodes", "categories_segments",
 * "categories_sizes" and "categories" from `in`.  Entry k of the first three
 * arrays describes one categorical node: its node index, the offset of its
 * first category inside "categories", and the number of categories it owns.
 * For each such node the categories are packed into a bit field that is
 * appended to split_categories_, and the node's entry in
 * split_categories_segments_ is set to that range.  Every other node gets a
 * zero-sized segment.
 *
 * The single forward pass relies on "categories_nodes" being in ascending node
 * order, which is the order SaveCategoricalSplit writes.  The function indexes
 * split_categories_segments_ by node id without resizing it, so the caller has
 * to size it to param.num_nodes beforehand.
 *
 * \param in JSON object of one tree containing the four arrays above.
 */
void RegTree::LoadCategoricalSplit(Json const& in) {
  auto const& categories_segments = get<Array const>(in["categories_segments"]);
  auto const& categories_sizes = get<Array const>(in["categories_sizes"]);
  auto const& categories_nodes = get<Array const>(in["categories_nodes"]);
  auto const& categories = get<Array const>(in["categories"]);

  // Position of the next unread entry in the three per-node arrays.
  size_t cnt = 0;
  // Node id of the next categorical node, or -1 once all have been consumed.
  bst_node_t next_cat_node = -1;
  if (!categories_nodes.empty()) {
    next_cat_node = get<Integer const>(categories_nodes[cnt]);
  }

  // Signed index so the comparison with the -1 sentinel is not a mixed
  // signed/unsigned comparison.
  for (bst_node_t nidx = 0; nidx < param.num_nodes; ++nidx) {
    if (nidx == next_cat_node) {
      auto j_begin = get<Integer const>(categories_segments[cnt]);
      auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;

      // First pass: find the largest category so the bit field can be sized.
      bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};
      for (auto j = j_begin; j < j_end; ++j) {
        auto const& category = get<Integer const>(categories[j]);
        max_cat = std::max(max_cat, common::AsCat(category));
      }
      // NOTE(review): if ComputeStorageSize expects a number of bits rather
      // than the highest bit index, this should probably be `max_cat + 1`
      // (category 0 alone would otherwise get no storage) -- TODO confirm.
      size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
                        ? 0
                        : common::KCatBitField::ComputeStorageSize(max_cat);

      // Second pass: set one bit per category of this node.
      std::vector<uint32_t> cat_bits_storage(size);
      common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
      for (auto j = j_begin; j < j_end; ++j) {
        cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
      }

      // Append the node's bit field to the flat storage and record its range.
      auto begin = split_categories_.size();
      split_categories_.resize(begin + cat_bits_storage.size());
      std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
                split_categories_.begin() + begin);
      split_categories_segments_[nidx].beg = begin;
      split_categories_segments_[nidx].size = cat_bits_storage.size();

      // Advance exactly once.  The previous code incremented `cnt` here and
      // again inside the subscript below, which skipped every second
      // categorical node and could index past the end of categories_nodes.
      ++cnt;
      if (cnt == categories_nodes.size()) {
        next_cat_node = -1;
      } else {
        next_cat_node = get<Integer const>(categories_nodes[cnt]);
      }
    } else {
      // Not a categorical node: give it an empty segment.
      split_categories_segments_[nidx].beg = categories.size();
      split_categories_segments_[nidx].size = 0;
    }
  }
}

/*!
 * \brief Write the categorical split information of this tree into `*p_out`.
 *
 * Four arrays are stored under the keys "categories_nodes",
 * "categories_segments", "categories_sizes" and "categories".  For every node
 * whose split type is FeatureType::kCategorical, in ascending node order, one
 * entry is appended to each of the first three arrays: the node index, the
 * offset in "categories" where the categories of that node start, and how many
 * categories were written for it.  "categories" is the flat list of the bit
 * indices that are set in the bit field of each such node.
 *
 * \param p_out JSON object that receives the four arrays.
 */
void RegTree::SaveCategoricalSplit(Json* p_out) const {
auto& out = *p_out;
// Both per-node containers are expected to have one entry per node.
CHECK_EQ(this->split_types_.size(), param.num_nodes);
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);

std::vector<Json> categories_segments;
std::vector<Json> categories_sizes;
std::vector<Json> categories;
std::vector<Json> categories_nodes;

for (size_t i = 0; i < nodes_.size(); ++i) {
if (this->split_types_[i] == FeatureType::kCategorical) {
categories_nodes.emplace_back(i);
// Offset of the first category of this node in the flat output list.
auto begin = categories.size();
categories_segments.emplace_back(static_cast<Integer::Int>(begin));
auto segment = split_categories_segments_[i];
auto node_categories =
this->GetSplitCategories().subspan(segment.beg, segment.size);
common::KCatBitField const cat_bits(node_categories);
// This inner `i` shadows the node index above; here it is a bit position,
// and each set bit is written out as one category value.
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories.emplace_back(static_cast<Integer::Int>(i));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@trivialfis What's the performance implication of saving individual categories? Is it better than saving bitmaps directly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it's not. If saving bitmap is preferred I can make the change.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, if saving individual categories produces acceptable performance, let's keep it. It's easier to parse by a human.

}
}
size_t size = categories.size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
}
}

out["categories_segments"] = categories_segments;
out["categories_sizes"] = categories_sizes;
out["categories_nodes"] = categories_nodes;
out["categories"] = categories;
}

void RegTree::LoadModel(Json const& in) {
FromJson(in["tree_param"], &param);
auto n_nodes = param.num_nodes;
Expand Down Expand Up @@ -799,13 +889,9 @@ void RegTree::LoadModel(Json const& in) {

bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
std::vector<Json> split_type;
std::vector<Json> categories;
if (has_cat) {
split_type = get<Array const>(in["split_type"]);
categories = get<Array const>(in["categories"]);
}


stats_.clear();
nodes_.clear();

Expand Down Expand Up @@ -833,46 +919,40 @@ void RegTree::LoadModel(Json const& in) {
if (has_cat) {
split_types_[i] =
static_cast<FeatureType>(get<Integer const>(split_type[i]));
auto const& j_categories = get<Array const>(categories[i]);
bst_cat_t max_cat { std::numeric_limits<bst_cat_t>::min() };
for (auto const& j_cat : j_categories) {
auto cat = common::AsCat(get<Integer const>(j_cat));
max_cat = std::max(max_cat, cat);
}
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
std::vector<uint32_t> cat_bits_storage(size);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto const& j_cat : j_categories) {
cat_bits.Set(common::AsCat(get<Integer const>(j_cat)));
}
auto begin = split_categories_.size();
split_categories_.resize(begin + cat_bits_storage.size());
std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
split_categories_.begin() + begin);
split_categories_segments_[i].beg = begin;
split_categories_segments_[i].size = cat_bits_storage.size();
}
}

if (has_cat) {
this->LoadCategoricalSplit(in);
} else {
this->split_categories_segments_.resize(this->param.num_nodes);
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
}

deleted_nodes_.clear();
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) {
deleted_nodes_.push_back(i);
}
}

// easier access to [] operator
auto& self = *this;
for (auto nid = 1; nid < n_nodes; ++nid) {
auto parent = self[nid].Parent();
CHECK_NE(parent, RegTree::kInvalidNodeId);
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
}
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
CHECK_EQ(this->split_categories_segments_.size(), param.num_nodes);
}

void RegTree::SaveModel(Json* p_out) const {
/* Here we treat leaf nodes and internal nodes equally. Some information, like the
* child node id, doesn't make sense for a leaf node, but we have to save it anyway to
* avoid creating a huge map. One difficulty is that XGBoost has deleted nodes created
* by the pruner, and this pruner can be used inside another updater, so leaves are not
* necessarily at the end of the node array.
*/
auto& out = *p_out;
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
Expand All @@ -894,8 +974,7 @@ void RegTree::SaveModel(Json* p_out) const {
std::vector<Json> conds(n_nodes);
std::vector<Json> default_left(n_nodes);
std::vector<Json> split_type(n_nodes);

std::vector<Json> categories(n_nodes);
CHECK_EQ(this->split_types_.size(), param.num_nodes);

for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
Expand All @@ -911,25 +990,12 @@ void RegTree::SaveModel(Json* p_out) const {
conds[i] = n.SplitCond();
default_left[i] = n.DefaultLeft();

std::vector<Json> categories_temp;
// This condition is only for compatibility with models from older versions of XGBoost
// that don't have categorical data support.
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
split_type[i] = static_cast<I>(this->NodeSplitType(i));
auto beg = this->GetSplitCategoriesPtr().at(i).beg;
auto size = this->GetSplitCategoriesPtr().at(i).size;
auto node_categories = this->GetSplitCategories().subspan(beg, size);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories_temp.emplace_back(static_cast<Integer::Int>(i));
}
}
}
categories[i] = Array(categories_temp);
split_type[i] = static_cast<I>(this->NodeSplitType(i));
}

this->SaveCategoricalSplit(&out);

out["split_type"] = std::move(split_type);
out["loss_changes"] = std::move(loss_changes);
out["sum_hessian"] = std::move(sum_hessian);
out["base_weights"] = std::move(base_weights);
Expand All @@ -940,12 +1006,6 @@ void RegTree::SaveModel(Json* p_out) const {
out["split_indices"] = std::move(indices);
out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left);

out["categories"] = categories;

if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
out["split_type"] = std::move(split_type);
}
}

void RegTree::FillNodeMeanValues() {
Expand Down