Skip to content

Commit

Permalink
dispatched call for update position
Browse files Browse the repository at this point in the history
  • Loading branch information
ShvetsKS committed May 14, 2022
1 parent 912105c commit d155edc
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 53 deletions.
3 changes: 0 additions & 3 deletions src/tree/updater_approx.cc
Expand Up @@ -122,8 +122,6 @@ class GloablApproxBuilder {
CHECK_EQ(out_preds.Size(), data->Info().num_row_);
CHECK(p_last_tree_);

size_t n_nodes = p_last_tree_->GetNodes().size();

auto evaluator = evaluator_.Evaluator();
auto const &tree = *p_last_tree_;
auto const &snode = evaluator_.Stats();
Expand Down Expand Up @@ -275,7 +273,6 @@ class GloablApproxBuilder {
evaluator_.ApplyTreeSplit(candidate, p_tree);
applied[candidate.nid] = candidate;
applied_vec.push_back(candidate);
is_applied = true;
CHECK_EQ(applied[candidate.nid].nid, candidate.nid);
num_leaves++;
int left_child_nidx = tree[candidate.nid].LeftChild();
Expand Down
71 changes: 21 additions & 50 deletions src/tree/updater_quantile_hist.cc
Expand Up @@ -245,15 +245,13 @@ void QuantileHistMaker::Builder::SplitSiblings(
const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0);
nodes_to_evaluate->push_back(left_node);
nodes_to_evaluate->push_back(right_node);
bool is_loss_guide = static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy) ==
TrainParam::kDepthWise ? false : true;
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess()) {
nodes_for_explicit_hist_build_.push_back(left_node);
nodes_for_subtraction_trick_.push_back(right_node);
} else {
nodes_for_explicit_hist_build_.push_back(right_node);
nodes_for_subtraction_trick_.push_back(left_node);
}
if (entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess()) {
nodes_for_explicit_hist_build_.push_back(left_node);
nodes_for_subtraction_trick_.push_back(right_node);
} else {
nodes_for_explicit_hist_build_.push_back(right_node);
nodes_for_subtraction_trick_.push_back(left_node);
}
}
monitor_->Stop("SplitSiblings");
}
Expand Down Expand Up @@ -308,47 +306,20 @@ void QuantileHistMaker::Builder::ExpandTree(
size_t page_id{0};
for (auto const &page : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
CommonRowPartitioner &partitioner = this->partitioner_.at(page_id);
if (is_loss_guide) {
if (page.cut.HasCategorical()) {
partitioner.UpdatePosition<any_missing, BinIdxType, true, true>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
} else {
partitioner.UpdatePosition<any_missing, BinIdxType, true, false>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
}
} else {
if (page.cut.HasCategorical()) {
partitioner.UpdatePosition<any_missing, BinIdxType, false, true>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
} else {
partitioner.UpdatePosition<any_missing, BinIdxType, false, false>(this->ctx_, page,
nodes_for_apply_split, p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small, true);
}
}
partitioner.UpdatePositionDispatched({any_missing,
static_cast<common::BinTypeSize>(sizeof(BinIdxType)),
is_loss_guide, page.cut.HasCategorical()},
this->ctx_,
page,
nodes_for_apply_split,
p_tree,
depth,
&smalest_nodes_mask,
is_loss_guide,
&split_conditions_,
&split_ind_, param_.max_depth,
&child_node_ids_, is_left_small,
true);
++page_id;
}

Expand Down

0 comments on commit d155edc

Please sign in to comment.