Skip to content

Commit

Permalink
Cleanup.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 15, 2021
1 parent 07dc121 commit 63f4bc3
Showing 1 changed file with 25 additions and 50 deletions.
75 changes: 25 additions & 50 deletions src/gbm/gbtree.h
Original file line number Diff line number Diff line change
Expand Up @@ -305,73 +305,48 @@ class GBTree : public GradientBooster {
// Because feature with no importance doesn't appear in the return value so
// we need to set up another pair of vectors to store the values during
// computation.
std::vector<size_t> split_counts(this->model_.learner_model_param->num_feature, 0);
std::vector<float> gain_map(this->model_.learner_model_param->num_feature, 0);
features->clear();
scores->clear();

if (importance_type == "weight") {
auto add_score = [&](auto fn) {
for (auto const &p_tree : model_.trees) {
p_tree->WalkTree([&](bst_node_t nidx) {
auto const &node = (*p_tree)[nidx];
if (!node.IsLeaf()) {
gain_map[node.SplitIndex()]++;
split_counts[node.SplitIndex()]++;
fn(p_tree, nidx, node.SplitIndex());
}
return true;
});
}
for (size_t i = 0; i < gain_map.size(); ++i) {
if (gain_map[i] != 0.0f) {
features->push_back(i);
scores->push_back(gain_map[i]);
}
}
return;
}
};

bool average_over_splits = true;
if (importance_type == "total_gain" || importance_type == "total_cover") {
average_over_splits = false;
if (importance_type == "weight") {
add_score([&](auto const &p_tree, bst_node_t, bst_feature_t split) {
gain_map[split] = split_counts[split];
});
}

std::vector<size_t> split_counts(this->model_.learner_model_param->num_feature, 0);
if (importance_type == "gain" || importance_type == "total_gain") {
for (auto const &p_tree : model_.trees) {
p_tree->WalkTree([&](bst_node_t nidx) {
auto const &node = (*p_tree)[nidx];
if (!node.IsLeaf()) {
split_counts[node.SplitIndex()]++;
gain_map[node.SplitIndex()] += p_tree->Stat(nidx).loss_chg;
}
return true;
});
}
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += p_tree->Stat(nidx).loss_chg;
});
}
if (importance_type == "cover" || importance_type == "total_cover") {
for (auto const &p_tree : model_.trees) {
p_tree->WalkTree([&](bst_node_t nidx) {
auto const &node = (*p_tree)[nidx];
if (!node.IsLeaf()) {
split_counts[node.SplitIndex()]++;
gain_map[node.SplitIndex()] += p_tree->Stat(nidx).sum_hess;
}
return true;
});
add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
gain_map[split] += p_tree->Stat(nidx).sum_hess;
});
}
if (importance_type == "gain" || importance_type == "cover") {
for (size_t i = 0; i < gain_map.size(); ++i) {
gain_map[i] /= std::max(1.0f, static_cast<float>(split_counts[i]));
}
}

if (average_over_splits) {
for (size_t i = 0; i < split_counts.size(); ++i) {
if (split_counts[i] != 0) {
features->push_back(i);
scores->push_back(gain_map[i] / static_cast<double>(split_counts[i]));
}
}
} else {
for (size_t i = 0; i < split_counts.size(); ++i) {
if (split_counts[i] != 0) {
features->push_back(i);
scores->push_back(gain_map[i]);
}
features->clear();
scores->clear();
for (size_t i = 0; i < split_counts.size(); ++i) {
if (split_counts[i] != 0) {
features->push_back(i);
scores->push_back(gain_map[i]);
}
}
}
Expand Down

0 comments on commit 63f4bc3

Please sign in to comment.