Skip to content

Commit

Permalink
Reviewer's comment.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 19, 2020
1 parent 693f334 commit cc26f9d
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 10 deletions.
2 changes: 0 additions & 2 deletions include/xgboost/span.h
Expand Up @@ -528,7 +528,6 @@ class Span {

XGBOOST_DEVICE reference operator[](index_type _idx) const {
SPAN_LT(_idx, size());
SPAN_CHECK(_idx < size());
return data()[_idx];
}

Expand Down Expand Up @@ -588,7 +587,6 @@ class Span {
detail::ExtentValue<Extent, Offset, Count>::value> {
SPAN_CHECK((Count == dynamic_extent) ?
(Offset <= size()) : (Offset + Count <= size()));

return {data() + Offset, Count == dynamic_extent ? size() - Offset : Count};
}

Expand Down
15 changes: 7 additions & 8 deletions src/tree/tree_model.cc
Expand Up @@ -867,7 +867,6 @@ void RegTree::SaveModel(Json* p_out) const {

std::vector<Json> categories(n_nodes);

auto& self = *this;
for (bst_node_t i = 0; i < n_nodes; ++i) {
auto const& s = stats_[i];
loss_changes[i] = s.loss_chg;
Expand All @@ -886,12 +885,12 @@ void RegTree::SaveModel(Json* p_out) const {
std::vector<Json> categories_temp;
// This condition is only for being compatibale with older version of XGBoost model
// that doesn't have categorical data support.
if (self.GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
CHECK_EQ(self.split_categories_segments_.size(), param.num_nodes);
split_type[i] = static_cast<I>(self.NodeSplitType(i));
auto beg = self.split_categories_segments_.at(i).beg;
auto size = self.split_categories_segments_.at(i).size;
auto node_categories = self.GetSplitCategories().subspan(beg, size);
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
CHECK_EQ(this->GetSplitCategoriesPtr().size(), param.num_nodes);
split_type[i] = static_cast<I>(this->NodeSplitType(i));
auto beg = this->GetSplitCategoriesPtr().at(i).beg;
auto size = this->GetSplitCategoriesPtr().at(i).size;
auto node_categories = this->GetSplitCategories().subspan(beg, size);
common::KCatBitField const cat_bits(node_categories);
for (size_t i = 0; i < cat_bits.Size(); ++i) {
if (cat_bits.Check(i)) {
Expand All @@ -916,7 +915,7 @@ void RegTree::SaveModel(Json* p_out) const {

out["categories"] = categories;

if (self.GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
if (this->GetSplitTypes().size() == static_cast<size_t>(n_nodes)) {
out["split_type"] = std::move(split_type);
}
}
Expand Down

0 comments on commit cc26f9d

Please sign in to comment.