Skip to content

Commit

Permalink
applied new driver logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ShvetsKS committed Jun 25, 2022
1 parent 8cd8cd5 commit 7780172
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 37 deletions.
3 changes: 0 additions & 3 deletions src/tree/driver.h
Expand Up @@ -59,9 +59,7 @@ class Driver {
// Can a child of this entry still be expanded?
// Callers can use this to avoid extra work on entries whose children would be
// rejected anyway.
//
// Returns false when either configured growth limit is already reached:
//   - depth:  param_.max_depth > 0 and parent_entry.depth + 1 >= param_.max_depth
//   - leaves: param_.max_leaves > 0 and num_leaves_ >= param_.max_leaves
// A limit that is not positive is not enforced by this check.
bool IsChildValid(ExpandEntryT const& parent_entry) {
  // Deliberately no diagnostic output here: the previous std::cout traces
  // flushed the stream (std::endl) on every call and mislabelled the member
  // num_leaves_ as a field of parent_entry.
  if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
  if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
  return true;
}
Expand All @@ -70,7 +68,6 @@ class Driver {
// This set has no dependencies between entries so they may be expanded in
// parallel or asynchronously
std::vector<ExpandEntryT> Pop() {
std::cout << "queue_.size():" << queue_.size() << std::endl;
if (queue_.empty()) return {};
// Return a single entry for loss guided mode
if (param_.grow_policy == TrainParam::kLossGuide) {
Expand Down
6 changes: 1 addition & 5 deletions src/tree/hist/expand_entry.h
Expand Up @@ -24,20 +24,16 @@ struct CPUExpandEntry {
}

bool IsValid(const TrainParam& param, int num_leaves) const {
std::cout << "split.loss_chg:" << split.loss_chg << " eps:" << kRtEps << std::endl;
if (split.loss_chg <= kRtEps) { std::cout << "NOT VALID!" << std::endl; return false;}
std::cout << "split.left_sum.GetHess():" << split.left_sum.GetHess() << " split.right_sum.GetHess():" << split.right_sum.GetHess() << std::endl;
if (split.loss_chg <= kRtEps) { return false;}
if (split.left_sum.GetHess() == 0 || split.right_sum.GetHess() == 0) {
return false;
}
std::cout << "split.loss_chg:" << split.loss_chg << " param.min_split_loss:" << param.min_split_loss << std::endl;
if (split.loss_chg < param.min_split_loss) {
return false;
}
if (param.max_depth > 0 && depth == param.max_depth) {
return false;
}
std::cout << "param.max_leaves:" << param.max_leaves << " num_leaves:" << num_leaves << std::endl;
if (param.max_leaves > 0 && num_leaves == param.max_leaves) {
return false;
}
Expand Down
29 changes: 1 addition & 28 deletions src/tree/updater_quantile_hist.cc
Expand Up @@ -155,7 +155,7 @@ void QuantileHistMaker::Builder::InitRoot(

void QuantileHistMaker::Builder::AddSplitsToTree(
const std::vector<CPUExpandEntry>& expand,
Driver<CPUExpandEntry>* driver,
Driver<CPUExpandEntry>* driver,
RegTree *p_tree,
int *num_leaves,
std::vector<CPUExpandEntry>* nodes_for_apply_split,
Expand All @@ -165,14 +165,9 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
const bool is_loss_guided = static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy)
!= TrainParam::kDepthWise;
std::vector<uint16_t> complete_node_ids;
std::cout << "expand.size():" << expand.size() << std::endl;
for (auto const& entry : expand) {
if (driver->IsChildValid(entry)) {

// if (entry.IsValid(param_, *num_leaves)) {
nodes_for_apply_split->push_back(entry);
evaluator_->ApplyTreeSplit(entry, p_tree);
// (*num_leaves)++;
complete_node_ids.push_back((*p_tree)[entry.nid].LeftChild());
complete_node_ids.push_back((*p_tree)[entry.nid].RightChild());
*is_left_small = entry.split.left_sum.GetHess() <= entry.split.right_sum.GetHess();
Expand All @@ -183,9 +178,6 @@ void QuantileHistMaker::Builder::AddSplitsToTree(
smalest_nodes_mask[(*p_tree)[entry.nid].RightChild()] = true;
smalest_nodes_mask[ (*p_tree)[entry.nid].LeftChild()] = false;
}
} else {
std::cout << "Not valid!!! entry.nid:" << entry.nid << std::endl;
}
}
child_node_ids_ = complete_node_ids;
}
Expand Down Expand Up @@ -238,14 +230,11 @@ void QuantileHistMaker::Builder::ExpandTree(
RegTree* p_tree,
const std::vector<GradientPair>& gpair_h,
HostDeviceVector<bst_node_t> *p_out_position) {
std::cout << "ExpandTree 1" << std::endl;
monitor_->Start("ExpandTree");
int num_leaves = 0;
split_conditions_.clear();
split_ind_.clear();
Driver<CPUExpandEntry> driver(param_);
std::cout << "ExpandTree 2" << std::endl;
// Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
std::vector<CPUExpandEntry> expand;
size_t page_id{0};
std::vector<size_t>& row_indices = *row_set_collection_.Data();
Expand All @@ -261,29 +250,23 @@ void QuantileHistMaker::Builder::ExpandTree(
TrainParam::kDepthWise ? false : true;

InitRoot<BinIdxType, any_missing>(gmat, p_fmat, p_tree, gpair_h, &num_leaves, &expand);
std::cout << "ExpandTree 3" << std::endl;
driver.Push(expand[0]);
child_node_ids_.clear();
child_node_ids_.emplace_back(0);
int32_t depth = 0;
while (!driver.IsEmpty()) {
std::unordered_map<uint32_t, bool> smalest_nodes_mask;
std::cout << "ExpandTree before POP:" << depth << std::endl;
expand = driver.Pop();
std::cout << "ExpandTree after POP:" << depth << std::endl;
if (expand.size()) {
depth = expand[0].depth + 1;
}
std::cout << "ExpandTree depth:" << depth << std::endl;
std::vector<CPUExpandEntry> nodes_for_apply_split;
std::vector<CPUExpandEntry> nodes_to_evaluate;
nodes_for_explicit_hist_build_.clear();
nodes_for_subtraction_trick_.clear();
bool is_left_small = false;
AddSplitsToTree(expand, &driver, p_tree, &num_leaves, &nodes_for_apply_split,
&smalest_nodes_mask, depth, &is_left_small);
std::cout << "AddSplitsToTree finished" << std::endl;

if (nodes_for_apply_split.size() != 0) {
monitor_->Start("ApplySplit");
size_t page_id{0};
Expand All @@ -305,11 +288,9 @@ void QuantileHistMaker::Builder::ExpandTree(
true);
++page_id;
}
std::cout << "UpdatePositionDispatched finished" << std::endl;

monitor_->Stop("ApplySplit");
SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree);
std::cout << "SplitSiblings finished" << std::endl;
if (param_.max_depth == 0 || depth < param_.max_depth) {
size_t i = 0;
monitor_->Start("BuildHist");
Expand All @@ -327,7 +308,6 @@ void QuantileHistMaker::Builder::ExpandTree(
std::copy(merged_thread_ids_set[nid].begin(),
merged_thread_ids_set[nid].end(), merged_thread_ids[nid].begin());
}
std::cout << "merged_thread_ids_set finished" << std::endl;

for (auto const &gidx : p_fmat->GetBatches<GHistIndexMatrix>(HistBatch(param_))) {
CommonRowPartitioner &partitioner = this->partitioner_.at(i);
Expand All @@ -338,33 +318,26 @@ void QuantileHistMaker::Builder::ExpandTree(
&(partitioner.GetNodeAssignments()), &merged_thread_ids);
++i;
}
std::cout << "BuildHist finished" << std::endl;

monitor_->Stop("BuildHist");
monitor_->Start("EvaluateSplits");
auto ft = p_fmat->Info().feature_types.ConstHostSpan();
evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(),
feature_values_, ft, *p_tree, &nodes_to_evaluate);
monitor_->Stop("EvaluateSplits");
std::cout << "EvaluateSplits finished" << std::endl;
}
std::cout << "nodes_for_apply_split.size():" << nodes_for_apply_split.size() << std::endl;
for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) {
CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0);
CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1);
driver.Push(left_node);
driver.Push(right_node);
}
std::cout << "DRIVERPOP finished" << std::endl;
}
}

auto &h_out_position = p_out_position->HostVector();
std::cout << "LeafPartition started" << std::endl;
this->LeafPartition(*p_tree, &h_out_position);
std::cout << "LeafPartition finished" << std::endl;
monitor_->Stop(__func__);
std::cout << "ExpandTree finished" << std::endl;
}

void QuantileHistMaker::Builder::UpdateTree(HostDeviceVector<GradientPair> *gpair, DMatrix *p_fmat,
Expand Down
2 changes: 1 addition & 1 deletion src/tree/updater_quantile_hist.h
Expand Up @@ -174,7 +174,7 @@ class QuantileHistMaker: public TreeUpdater {
RegTree *p_tree);

void AddSplitsToTree(const std::vector<CPUExpandEntry>& expand,
Driver<CPUExpandEntry>* driver,
Driver<CPUExpandEntry>* driver,
RegTree *p_tree,
int *num_leaves,
std::vector<CPUExpandEntry>* nodes_for_apply_split,
Expand Down

0 comments on commit 7780172

Please sign in to comment.