Fuse gpu_hist all-reduce calls where possible (#7867)
RAMitchell committed May 17, 2022
1 parent b41cf92 commit 71d3b2e
Showing 9 changed files with 236 additions and 187 deletions.
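
This change batches same-level nodes inside the Driver (new max_node_batch_size parameter) so that gpu_hist can reduce the histograms for a whole batch of nodes with one all-reduce instead of one call per node. Below is a minimal sketch of that fused layout, under stated assumptions: allreduce_sum is a generic stand-in callback, not the actual xgboost communicator API, and the GradientPairPrecise struct here is a simplified local stand-in for xgboost's type.

// Sketch only: per-node histograms laid out contiguously so a single
// collective call covers the whole node batch.
#include <cstddef>
#include <functional>
#include <vector>

// Simplified stand-in for xgboost's GradientPairPrecise (gradient/hessian pair).
struct GradientPairPrecise {
  double grad{0.0};
  double hess{0.0};
};

// allreduce_sum is an assumed callback that sums `count` doubles in place
// across all workers; it is not part of this commit.
void ReduceNodeBatchHistograms(
    std::vector<GradientPairPrecise>* batch_histograms,  // nodes_in_batch * bins_per_node entries
    std::size_t nodes_in_batch, std::size_t bins_per_node,
    const std::function<void(double*, std::size_t)>& allreduce_sum) {
  // One fused call instead of nodes_in_batch separate all-reduces.
  allreduce_sum(reinterpret_cast<double*>(batch_histograms->data()),
                nodes_in_batch * bins_per_node * 2);  // 2 doubles per GradientPair
}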
39 changes: 30 additions & 9 deletions src/tree/driver.h
@@ -33,10 +33,11 @@ class Driver {
       std::function<bool(ExpandEntryT, ExpandEntryT)>>;

  public:
-  explicit Driver(TrainParam::TreeGrowPolicy policy)
-      : policy_(policy),
-        queue_(policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT> :
-               LossGuide<ExpandEntryT>) {}
+  explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
+      : param_(param),
+        max_node_batch_size_(max_node_batch_size),
+        queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise<ExpandEntryT>
+                                                           : LossGuide<ExpandEntryT>) {}
   template <typename EntryIterT>
   void Push(EntryIterT begin, EntryIterT end) {
     for (auto it = begin; it != end; ++it) {
@@ -55,24 +56,42 @@ class Driver {
     return queue_.empty();
   }

+  // Can a child of this entry still be expanded?
+  // can be used to avoid extra work
+  bool IsChildValid(ExpandEntryT const& parent_entry) {
+    if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
+    if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
+    return true;
+  }
+
   // Return the set of nodes to be expanded
   // This set has no dependencies between entries so they may be expanded in
   // parallel or asynchronously
   std::vector<ExpandEntryT> Pop() {
     if (queue_.empty()) return {};
     // Return a single entry for loss guided mode
-    if (policy_ == TrainParam::kLossGuide) {
+    if (param_.grow_policy == TrainParam::kLossGuide) {
       ExpandEntryT e = queue_.top();
       queue_.pop();
-      return {e};
+
+      if (e.IsValid(param_, num_leaves_)) {
+        num_leaves_++;
+        return {e};
+      } else {
+        return {};
+      }
     }
     // Return nodes on same level for depth wise
     std::vector<ExpandEntryT> result;
     ExpandEntryT e = queue_.top();
     int level = e.depth;
-    while (e.depth == level && !queue_.empty()) {
+    while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
       queue_.pop();
-      result.emplace_back(e);
+      if (e.IsValid(param_, num_leaves_)) {
+        num_leaves_++;
+        result.emplace_back(e);
+      }
+
       if (!queue_.empty()) {
         e = queue_.top();
       }
@@ -81,7 +100,9 @@ class Driver {
   }

  private:
-  TrainParam::TreeGrowPolicy policy_;
+  TrainParam param_;
+  std::size_t num_leaves_ = 1;
+  std::size_t max_node_batch_size_;
   ExpandQueue queue_;
 };
 }  // namespace tree
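For orientation, here is a sketch of how an updater might consume the batched Driver interface after this change; it mirrors the updater_approx.cc hunk further down. ExpandEntry, param, MakeRootEntry, EvaluateAndApplySplit, and BuildChildEntries are illustrative placeholders, not symbols introduced by this commit, and the single-argument Push overload is the one used in updater_approx.cc.

// Sketch only: expansion loop over node batches returned by Driver::Pop().
// The Driver now owns the leaf count and the max_depth/max_leaves checks,
// so the updater no longer tracks num_leaves itself.
Driver<ExpandEntry> driver{param};          // param: TrainParam carrying grow_policy etc.
driver.Push({MakeRootEntry()});             // hypothetical root-entry helper
for (auto expand_set = driver.Pop(); !expand_set.empty(); expand_set = driver.Pop()) {
  std::vector<ExpandEntry> children;
  for (auto const& candidate : expand_set) {
    EvaluateAndApplySplit(candidate);       // hypothetical: materialise the split in the tree
    if (driver.IsChildValid(candidate)) {   // cheap pre-check before building child histograms
      auto kids = BuildChildEntries(candidate);  // hypothetical: histograms + best splits for children
      children.insert(children.end(), kids.begin(), kids.end());
    }
  }
  driver.Push(children.begin(), children.end());
}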
2 changes: 1 addition & 1 deletion src/tree/gpu_hist/evaluate_splits.cuh
@@ -103,7 +103,7 @@ class GPUHistEvaluator {
   }

   /**
-   * \brief Get sorted index storage based on the left node of inputs .
+   * \brief Get sorted index storage based on the left node of inputs.
    */
   auto SortedIdx(EvaluateSplitInputs<GradientSumT> left) {
     if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
9 changes: 0 additions & 9 deletions src/tree/gpu_hist/histogram.cu
@@ -247,15 +247,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
   dh::safe_cuda(cudaGetLastError());
 }

-template void BuildGradientHistogram<GradientPair>(
-    EllpackDeviceAccessor const& matrix,
-    FeatureGroupsAccessor const& feature_groups,
-    common::Span<GradientPair const> gpair,
-    common::Span<const uint32_t> ridx,
-    common::Span<GradientPair> histogram,
-    HistRounding<GradientPair> rounding,
-    bool force_global_memory);
-
 template void BuildGradientHistogram<GradientPairPrecise>(
     EllpackDeviceAccessor const& matrix,
     FeatureGroupsAccessor const& feature_groups,
10 changes: 2 additions & 8 deletions src/tree/updater_approx.cc
@@ -179,10 +179,9 @@ class GloablApproxBuilder {
     p_last_tree_ = p_tree;
     this->InitData(p_fmat, hess);

-    Driver<CPUExpandEntry> driver(static_cast<TrainParam::TreeGrowPolicy>(param_.grow_policy));
+    Driver<CPUExpandEntry> driver(param_);
     auto &tree = *p_tree;
     driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
-    bst_node_t num_leaves{1};
     auto expand_set = driver.Pop();

     /**
@@ -201,14 +200,9 @@
       // candidates that can be applied.
       std::vector<CPUExpandEntry> applied;
       for (auto const &candidate : expand_set) {
-        if (!candidate.IsValid(param_, num_leaves)) {
-          continue;
-        }
         evaluator_.ApplyTreeSplit(candidate, p_tree);
         applied.push_back(candidate);
-        num_leaves++;
-        int left_child_nidx = tree[candidate.nid].LeftChild();
-        if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
+        if (driver.IsChildValid(candidate)) {
           valid_candidates.emplace_back(candidate);
         }
       }
