From 7065779afa6c1a94c736cc4d1a3166c75a206fae Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Mon, 21 Sep 2020 15:35:05 +0800 Subject: [PATCH] Improve JSON format for categorical features. (#6128) * Gather categories for all nodes. --- include/xgboost/tree_model.h | 2 + src/tree/tree_model.cc | 162 ++++++++++++++++++++++++----------- 2 files changed, 113 insertions(+), 51 deletions(-) diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index faec6b118a3a..78e53a6344ec 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -620,6 +620,8 @@ class RegTree : public Model { }; private: + void LoadCategoricalSplit(Json const& in); + void SaveCategoricalSplit(Json* p_out) const; // vector of nodes std::vector nodes_; // free node space, used during training process diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 6e1ff15039b4..0447ed6925c4 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2015-2019 by Contributors + * Copyright 2015-2020 by Contributors * \file tree_model.cc * \brief model structure for tree */ @@ -740,7 +740,11 @@ void RegTree::Load(dmlc::Stream* fi) { } } CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); + + split_types_.resize(param.num_nodes, FeatureType::kNumerical); + split_categories_segments_.resize(param.num_nodes); } + void RegTree::Save(dmlc::Stream* fo) const { CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast(stats_.size())); @@ -772,6 +776,92 @@ void RegTree::Save(dmlc::Stream* fo) const { } } +void RegTree::LoadCategoricalSplit(Json const& in) { + auto const& categories_segments = get(in["categories_segments"]); + auto const& categories_sizes = get(in["categories_sizes"]); + auto const& categories_nodes = get(in["categories_nodes"]); + auto const& categories = get(in["categories"]); + + size_t cnt = 0; + bst_node_t last_cat_node = -1; + if (!categories_nodes.empty()) { + 
last_cat_node = get(categories_nodes[cnt]); + } + for (size_t nidx = 0; nidx < param.num_nodes; ++nidx) { + if (nidx == last_cat_node) { + auto j_begin = get(categories_segments[cnt]); + auto j_end = get(categories_sizes[cnt]) + j_begin; + bst_cat_t max_cat{std::numeric_limits::min()}; + + for (auto j = j_begin; j < j_end; ++j) { + auto const &category = get(categories[j]); + auto cat = common::AsCat(category); + max_cat = std::max(max_cat, cat); + } + size_t size = max_cat == std::numeric_limits::min() + ? 0 + : common::KCatBitField::ComputeStorageSize(max_cat); + std::vector cat_bits_storage(size); + common::CatBitField cat_bits{common::Span(cat_bits_storage)}; + for (auto j = j_begin; j < j_end; ++j) { + cat_bits.Set(common::AsCat(get(categories[j]))); + } + + auto begin = split_categories_.size(); + split_categories_.resize(begin + cat_bits_storage.size()); + std::copy(cat_bits_storage.begin(), cat_bits_storage.end(), + split_categories_.begin() + begin); + split_categories_segments_[nidx].beg = begin; + split_categories_segments_[nidx].size = cat_bits_storage.size(); + + ++cnt; + if (cnt == categories_nodes.size()) { + last_cat_node = -1; + } else { + last_cat_node = get(categories_nodes[cnt]); /* cnt already advanced above; do not increment twice */ + } + } else { + split_categories_segments_[nidx].beg = categories.size(); + split_categories_segments_[nidx].size = 0; + } + } +} + +void RegTree::SaveCategoricalSplit(Json* p_out) const { + auto& out = *p_out; + CHECK_EQ(this->split_types_.size(), param.num_nodes); + CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); + + std::vector categories_segments; + std::vector categories_sizes; + std::vector categories; + std::vector categories_nodes; + + for (size_t i = 0; i < nodes_.size(); ++i) { + if (this->split_types_[i] == FeatureType::kCategorical) { + categories_nodes.emplace_back(i); + auto begin = categories.size(); + categories_segments.emplace_back(static_cast(begin)); + auto segment = split_categories_segments_[i]; + auto node_categories =
this->GetSplitCategories().subspan(segment.beg, segment.size); + common::KCatBitField const cat_bits(node_categories); + for (size_t i = 0; i < cat_bits.Size(); ++i) { + if (cat_bits.Check(i)) { + categories.emplace_back(static_cast(i)); + } + } + size_t size = categories.size() - begin; + categories_sizes.emplace_back(static_cast(size)); + } + } + + out["categories_segments"] = categories_segments; + out["categories_sizes"] = categories_sizes; + out["categories_nodes"] = categories_nodes; + out["categories"] = categories; +} + void RegTree::LoadModel(Json const& in) { FromJson(in["tree_param"], ¶m); auto n_nodes = param.num_nodes; @@ -799,13 +889,9 @@ void RegTree::LoadModel(Json const& in) { bool has_cat = get(in).find("split_type") != get(in).cend(); std::vector split_type; - std::vector categories; if (has_cat) { split_type = get(in["split_type"]); - categories = get(in["categories"]); } - - stats_.clear(); nodes_.clear(); @@ -833,36 +919,23 @@ void RegTree::LoadModel(Json const& in) { if (has_cat) { split_types_[i] = static_cast(get(split_type[i])); - auto const& j_categories = get(categories[i]); - bst_cat_t max_cat { std::numeric_limits::min() }; - for (auto const& j_cat : j_categories) { - auto cat = common::AsCat(get(j_cat)); - max_cat = std::max(max_cat, cat); - } - size_t size = max_cat == std::numeric_limits::min() - ? 
0 - : common::KCatBitField::ComputeStorageSize(max_cat); - std::vector cat_bits_storage(size); - common::CatBitField cat_bits{common::Span(cat_bits_storage)}; - for (auto const& j_cat : j_categories) { - cat_bits.Set(common::AsCat(get(j_cat))); - } - auto begin = split_categories_.size(); - split_categories_.resize(begin + cat_bits_storage.size()); - std::copy(cat_bits_storage.begin(), cat_bits_storage.end(), - split_categories_.begin() + begin); - split_categories_segments_[i].beg = begin; - split_categories_segments_[i].size = cat_bits_storage.size(); } } + if (has_cat) { + this->LoadCategoricalSplit(in); + } else { + this->split_categories_segments_.resize(this->param.num_nodes); + std::fill(split_types_.begin(), split_types_.end(), FeatureType::kNumerical); + } + deleted_nodes_.clear(); for (bst_node_t i = 1; i < param.num_nodes; ++i) { if (nodes_[i].IsDeleted()) { deleted_nodes_.push_back(i); } } - + // easier access to [] operator auto& self = *this; for (auto nid = 1; nid < n_nodes; ++nid) { auto parent = self[nid].Parent(); @@ -870,9 +943,16 @@ void RegTree::LoadModel(Json const& in) { self[nid].SetParent(self[nid].Parent(), self[parent].LeftChild() == nid); } CHECK_EQ(static_cast(deleted_nodes_.size()), param.num_deleted); + CHECK_EQ(this->split_categories_segments_.size(), param.num_nodes); } void RegTree::SaveModel(Json* p_out) const { + /* Here we are treating leaf node and internal node equally. Some information like + * child node id doesn't make sense for leaf node but we will have to save them to + * avoid creating a huge map. One difficulty is XGBoost has deleted node created by + * pruner, and this pruner can be used inside another updater so leaf are not necessary + * at the end of node array. 
+ */ auto& out = *p_out; CHECK_EQ(param.num_nodes, static_cast(nodes_.size())); CHECK_EQ(param.num_nodes, static_cast(stats_.size())); @@ -894,8 +974,7 @@ void RegTree::SaveModel(Json* p_out) const { std::vector conds(n_nodes); std::vector default_left(n_nodes); std::vector split_type(n_nodes); - - std::vector categories(n_nodes); + CHECK_EQ(this->split_types_.size(), param.num_nodes); for (bst_node_t i = 0; i < n_nodes; ++i) { auto const& s = stats_[i]; @@ -911,25 +990,12 @@ void RegTree::SaveModel(Json* p_out) const { conds[i] = n.SplitCond(); default_left[i] = n.DefaultLeft(); - std::vector categories_temp; - // This condition is only for being compatibale with older version of XGBoost model - // that doesn't have categorical data support. - if (this->GetSplitTypes().size() == static_cast(n_nodes)) { - CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes); - split_type[i] = static_cast(this->NodeSplitType(i)); - auto beg = this->GetSplitCategoriesPtr().at(i).beg; - auto size = this->GetSplitCategoriesPtr().at(i).size; - auto node_categories = this->GetSplitCategories().subspan(beg, size); - common::KCatBitField const cat_bits(node_categories); - for (size_t i = 0; i < cat_bits.Size(); ++i) { - if (cat_bits.Check(i)) { - categories_temp.emplace_back(static_cast(i)); - } - } - } - categories[i] = Array(categories_temp); + split_type[i] = static_cast(this->NodeSplitType(i)); } + this->SaveCategoricalSplit(&out); + + out["split_type"] = std::move(split_type); out["loss_changes"] = std::move(loss_changes); out["sum_hessian"] = std::move(sum_hessian); out["base_weights"] = std::move(base_weights); @@ -940,12 +1006,6 @@ void RegTree::SaveModel(Json* p_out) const { out["split_indices"] = std::move(indices); out["split_conditions"] = std::move(conds); out["default_left"] = std::move(default_left); - - out["categories"] = categories; - - if (this->GetSplitTypes().size() == static_cast(n_nodes)) { - out["split_type"] = std::move(split_type); - } } void 
RegTree::FillNodeMeanValues() {