Skip to content

Commit

Permalink
Improve JSON format for categorical features. (#6128)
Browse files Browse the repository at this point in the history
* Gather categories for all nodes.
  • Loading branch information
trivialfis committed Sep 21, 2020
1 parent 210c131 commit 7065779
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 51 deletions.
2 changes: 2 additions & 0 deletions include/xgboost/tree_model.h
Expand Up @@ -620,6 +620,8 @@ class RegTree : public Model {
};

private:
void LoadCategoricalSplit(Json const& in);
void SaveCategoricalSplit(Json* p_out) const;
// vector of nodes
std::vector<Node> nodes_;
// free node space, used during training process
Expand Down
162 changes: 111 additions & 51 deletions src/tree/tree_model.cc
@@ -1,5 +1,5 @@
/*!
* Copyright 2015-2019 by Contributors
* Copyright 2015-2020 by Contributors
* \file tree_model.cc
* \brief model structure for tree
*/
Expand Down Expand Up @@ -740,7 +740,11 @@ void RegTree::Load(dmlc::Stream* fi) {
}
}
CHECK_EQ(static_cast<int>(deleted_nodes_.size()), param.num_deleted);

split_types_.resize(param.num_nodes, FeatureType::kNumerical);
split_categories_segments_.resize(param.num_nodes);
}

void RegTree::Save(dmlc::Stream* fo) const {
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
Expand Down Expand Up @@ -772,6 +776,92 @@ void RegTree::Save(dmlc::Stream* fo) const {
}
}

void RegTree::LoadCategoricalSplit(Json const& in) {
auto const& categories_segments = get<Array const>(in["categories_segments"]);
auto const& categories_sizes = get<Array const>(in["categories_sizes"]);
auto const& categories_nodes = get<Array const>(in["categories_nodes"]);
auto const& categories = get<Array const>(in["categories"]);

size_t cnt = 0;
bst_node_t last_cat_node = -1;
if (!categories_nodes.empty()) {
last_cat_node = get<Integer const>(categories_nodes[cnt]);
}
for (size_t nidx = 0; nidx < param.num_nodes; ++nidx) {
if (nidx == last_cat_node) {
auto j_begin = get<Integer const>(categories_segments[cnt]);
auto j_end = get<Integer const>(categories_sizes[cnt]) + j_begin;
bst_cat_t max_cat{std::numeric_limits<bst_cat_t>::min()};

for (auto j = j_begin; j < j_end; ++j) {
auto const &category = get<Integer const>(categories[j]);
auto cat = common::AsCat(category);
max_cat = std::max(max_cat, cat);
}
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
std::vector<uint32_t> cat_bits_storage(size);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto j = j_begin; j < j_end; ++j) {
cat_bits.Set(common::AsCat(get<Integer const>(categories[j])));
}

auto begin = split_categories_.size();
split_categories_.resize(begin + cat_bits_storage.size());
std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
split_categories_.begin() + begin);
split_categories_segments_[nidx].beg = begin;
split_categories_segments_[nidx].size = cat_bits_storage.size();

++cnt;
if (cnt == categories_nodes.size()) {
last_cat_node = -1;
} else {
last_cat_node = get<Integer const>(categories_nodes[++cnt]);
}
} else {
split_categories_segments_[nidx].beg = categories.size();
split_categories_segments_[nidx].size = 0;
}
}
}

void RegTree::SaveCategoricalSplit(Json* p_out) const {
auto& out = *p_out;
CHECK_EQ(this->split_types_.size(), param.num_nodes);
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);

std::vector<Json> categories_segments;
std::vector<Json> categories_sizes;
std::vector<Json> categories;
std::vector<Json> categories_nodes;

for (size_t i = 0; i < nodes_.size(); ++i) {
if (this->split_types_[i] == FeatureType::kCategorical) {
categories_nodes.emplace_back(i);
auto begin = categories.size();
categories_segments.emplace_back(static_cast<Integer::Int>(begin));
auto segment = split_categories_segments_[i];
auto node_categories =
this->GetSplitCategories().subspan(segment.beg, segment.size);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories.emplace_back(static_cast<Integer::Int>(i));
}
}
size_t size = categories.size() - begin;
categories_sizes.emplace_back(static_cast<Integer::Int>(size));
}
}

out["categories_segments"] = categories_segments;
out["categories_sizes"] = categories_sizes;
out["categories_nodes"] = categories_nodes;
out["categories"] = categories;
}

void RegTree::LoadModel(Json const& in) {
FromJson(in["tree_param"], &param);
auto n_nodes = param.num_nodes;
Expand Down Expand Up @@ -799,13 +889,9 @@ void RegTree::LoadModel(Json const& in) {

bool has_cat = get<Object const>(in).find("split_type") != get<Object const>(in).cend();
std::vector<Json> split_type;
std::vector<Json> categories;
if (has_cat) {
split_type = get<Array const>(in["split_type"]);
categories = get<Array const>(in["categories"]);
}


stats_.clear();
nodes_.clear();

Expand Down Expand Up @@ -833,46 +919,40 @@ void RegTree::LoadModel(Json const& in) {
if (has_cat) {
split_types_[i] =
static_cast<FeatureType>(get<Integer const>(split_type[i]));
auto const& j_categories = get<Array const>(categories[i]);
bst_cat_t max_cat { std::numeric_limits<bst_cat_t>::min() };
for (auto const& j_cat : j_categories) {
auto cat = common::AsCat(get<Integer const>(j_cat));
max_cat = std::max(max_cat, cat);
}
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
std::vector<uint32_t> cat_bits_storage(size);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto const& j_cat : j_categories) {
cat_bits.Set(common::AsCat(get<Integer const>(j_cat)));
}
auto begin = split_categories_.size();
split_categories_.resize(begin + cat_bits_storage.size());
std::copy(cat_bits_storage.begin(), cat_bits_storage.end(),
split_categories_.begin() + begin);
split_categories_segments_[i].beg = begin;
split_categories_segments_[i].size = cat_bits_storage.size();
}
}

if (has_cat) {
this->LoadCategoricalSplit(in);
} else {
this->split_categories_segments_.resize(this->param.num_nodes);
std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical);
}

deleted_nodes_.clear();
for (bst_node_t i = 1; i < param.num_nodes; ++i) {
if (nodes_[i].IsDeleted()) {
deleted_nodes_.push_back(i);
}
}

// easier access to [] operator
auto& self = *this;
for (auto nid = 1; nid < n_nodes; ++nid) {
auto parent = self[nid].Parent();
CHECK_NE(parent, RegTree::kInvalidNodeId);
self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid);
}
CHECK_EQ(static_cast<bst_node_t>(deleted_nodes_.size()), param.num_deleted);
CHECK_EQ(this->split_categories_segments_.size(), param.num_nodes);
}

void RegTree::SaveModel(Json* p_out) const {
/* Here we are treating leaf node and internal node equally. Some information like
* child node id doesn't make sense for leaf node but we will have to save them to
* avoid creating a huge map. One difficulty is XGBoost has deleted node created by
* pruner, and this pruner can be used inside another updater so leaf are not necessary
* at the end of node array.
*/
auto& out = *p_out;
CHECK_EQ(param.num_nodes, static_cast<int>(nodes_.size()));
CHECK_EQ(param.num_nodes, static_cast<int>(stats_.size()));
Expand All @@ -894,8 +974,7 @@ void RegTree::SaveModel(Json* p_out) const {
std::vector<Json> conds(n_nodes);
std::vector<Json> default_left(n_nodes);
std::vector<Json> split_type(n_nodes);

std::vector<Json> categories(n_nodes);
CHECK_EQ(this->split_types_.size(), param.num_nodes);

for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
Expand All @@ -911,25 +990,12 @@ void RegTree::SaveModel(Json* p_out) const {
conds[i] = n.SplitCond();
default_left[i] = n.DefaultLeft();

std::vector<Json> categories_temp;
// This condition is only for being compatibale with older version of XGBoost model
// that doesn't have categorical data support.
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
split_type[i] = static_cast<I>(this->NodeSplitType(i));
auto beg = this->GetSplitCategoriesPtr().at(i).beg;
auto size = this->GetSplitCategoriesPtr().at(i).size;
auto node_categories = this->GetSplitCategories().subspan(beg, size);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories_temp.emplace_back(static_cast<Integer::Int>(i));
}
}
}
categories[i] = Array(categories_temp);
split_type[i] = static_cast<I>(this->NodeSplitType(i));
}

this->SaveCategoricalSplit(&out);

out["split_type"] = std::move(split_type);
out["loss_changes"] = std::move(loss_changes);
out["sum_hessian"] = std::move(sum_hessian);
out["base_weights"] = std::move(base_weights);
Expand All @@ -940,12 +1006,6 @@ void RegTree::SaveModel(Json* p_out) const {
out["split_indices"] = std::move(indices);
out["split_conditions"] = std::move(conds);
out["default_left"] = std::move(default_left);

out["categories"] = categories;

if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
out["split_type"] = std::move(split_type);
}
}

void RegTree::FillNodeMeanValues() {
Expand Down

0 comments on commit 7065779

Please sign in to comment.