Skip to content

Commit

Permalink
Fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 19, 2020
1 parent c40f197 commit 96433a6
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/tree/tree_model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -842,8 +842,10 @@ void RegTree::LoadModel(Json const& in) {
auto cat = common::AsCat(get<Integer const>(j_cat));
max_cat = std::max(max_cat, cat);
}
std::vector<uint32_t> cat_bits_storage(
common::KCatBitField::ComputeStorageSize(max_cat));
size_t size = max_cat == std::numeric_limits<bst_cat_t>::min()
? 0
: common::KCatBitField::ComputeStorageSize(max_cat);
std::vector<uint32_t> cat_bits_storage(size);
common::CatBitField cat_bits{common::Span<uint32_t>(cat_bits_storage)};
for (auto const& j_cat : j_categories) {
cat_bits.Set(common::AsCat(get<Integer const>(j_cat)));
Expand Down Expand Up @@ -915,6 +917,7 @@ void RegTree::SaveModel(Json* p_out) const {
conds[i] = n.SplitCond();
default_left[i] = n.DefaultLeft();

std::vector<Json> categories_temp;
// This condition is only for being compatibale with older version of XGBoost model
// that doesn't have categorical data support.
if (self.GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
Expand All @@ -924,14 +927,13 @@ void RegTree::SaveModel(Json* p_out) const {
auto size = self.split_categories_segments_.at(i).size;
auto node_categories = self.GetSplitCategories().subspan(beg, size);
common::KCatBitField const cat_bits(node_categories);
std::vector<Json> categories_temp;
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
categories_temp.emplace_back(static_cast<Integer::Int>(i));
}
}
categories[i] = Array(categories_temp);
}
categories[i] = Array(categories_temp);
}

out["loss_changes"] = std::move(loss_changes);
Expand Down

0 comments on commit 96433a6

Please sign in to comment.