diff --git a/src/tree/driver.h b/src/tree/driver.h
index abb8afadcb8a..0aef93ccf9cd 100644
--- a/src/tree/driver.h
+++ b/src/tree/driver.h
@@ -33,10 +33,11 @@ class Driver {
                      std::function>;

  public:
-  explicit Driver(TrainParam::TreeGrowPolicy policy)
-      : policy_(policy),
-        queue_(policy == TrainParam::kDepthWise ? DepthWise :
-               LossGuide) {}
+  explicit Driver(TrainParam param, std::size_t max_node_batch_size = 256)
+      : param_(param),
+        max_node_batch_size_(max_node_batch_size),
+        queue_(param.grow_policy == TrainParam::kDepthWise ? DepthWise
+                                                           : LossGuide) {}
   template 
   void Push(EntryIterT begin, EntryIterT end) {
     for (auto it = begin; it != end; ++it) {
@@ -55,24 +56,42 @@ class Driver {
     return queue_.empty();
   }

+  // Can a child of this entry still be expanded?
+  // can be used to avoid extra work
+  bool IsChildValid(ExpandEntryT const& parent_entry) {
+    if (param_.max_depth > 0 && parent_entry.depth + 1 >= param_.max_depth) return false;
+    if (param_.max_leaves > 0 && num_leaves_ >= param_.max_leaves) return false;
+    return true;
+  }
+
   // Return the set of nodes to be expanded
   // This set has no dependencies between entries so they may be expanded in
   // parallel or asynchronously
   std::vector Pop() {
     if (queue_.empty()) return {};
     // Return a single entry for loss guided mode
-    if (policy_ == TrainParam::kLossGuide) {
+    if (param_.grow_policy == TrainParam::kLossGuide) {
       ExpandEntryT e = queue_.top();
       queue_.pop();
-      return {e};
+
+      if (e.IsValid(param_, num_leaves_)) {
+        num_leaves_++;
+        return {e};
+      } else {
+        return {};
+      }
     }
     // Return nodes on same level for depth wise
     std::vector result;
     ExpandEntryT e = queue_.top();
     int level = e.depth;
-    while (e.depth == level && !queue_.empty()) {
+    while (e.depth == level && !queue_.empty() && result.size() < max_node_batch_size_) {
       queue_.pop();
-      result.emplace_back(e);
+      if (e.IsValid(param_, num_leaves_)) {
+        num_leaves_++;
+        result.emplace_back(e);
+      }
+
       if (!queue_.empty()) {
         e = queue_.top();
       }
@@ -81,7 +100,9 @@
   }

  private:
-  TrainParam::TreeGrowPolicy policy_;
+  TrainParam param_;
+  std::size_t num_leaves_ = 1;
+  std::size_t max_node_batch_size_;
   ExpandQueue queue_;
 };
 }  // namespace tree
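
Not part of the patch: a minimal, self-contained sketch of the Push/Pop/IsChildValid contract the reworked Driver above exposes. ToyParam, ToyEntry and ToyDriver are hypothetical stand-ins that only mimic the batching and the driver-owned leaf bookkeeping; the real class is templated on the expand-entry type and uses a priority queue.

// Illustrative only -- mirrors the Driver semantics from the hunk above on toy types.
#include <cstddef>
#include <deque>
#include <iostream>
#include <vector>

struct ToyParam { int max_depth{2}; int max_leaves{0}; };  // stand-in for TrainParam
struct ToyEntry { int nid{0}; int depth{0}; };             // stand-in for ExpandEntryT

class ToyDriver {
 public:
  ToyDriver(ToyParam p, std::size_t max_batch) : param_{p}, max_batch_{max_batch} {}
  void Push(std::vector<ToyEntry> const& entries) {
    for (auto const& e : entries) queue_.push_back(e);
  }
  // Same check as Driver::IsChildValid: children of this entry may still be expanded.
  bool IsChildValid(ToyEntry const& parent) const {
    if (param_.max_depth > 0 && parent.depth + 1 >= param_.max_depth) return false;
    if (param_.max_leaves > 0 &&
        num_leaves_ >= static_cast<std::size_t>(param_.max_leaves)) return false;
    return true;
  }
  // Pop at most max_batch_ entries; the driver, not the caller, tracks the leaf count.
  std::vector<ToyEntry> Pop() {
    std::vector<ToyEntry> batch;
    while (!queue_.empty() && batch.size() < max_batch_) {
      batch.push_back(queue_.front());
      queue_.pop_front();
      ++num_leaves_;  // each applied split turns one leaf into two, net +1
    }
    return batch;
  }

 private:
  ToyParam param_;
  std::size_t max_batch_;
  std::size_t num_leaves_{1};
  std::deque<ToyEntry> queue_;
};

int main() {
  ToyDriver driver{ToyParam{}, /*max_batch=*/2};
  driver.Push({ToyEntry{0, 0}});
  for (auto batch = driver.Pop(); !batch.empty(); batch = driver.Pop()) {
    std::vector<ToyEntry> children;
    for (auto const& e : batch) {
      // ApplySplit(e) would happen here; only schedule children that can still grow.
      if (driver.IsChildValid(e)) {
        children.push_back(ToyEntry{2 * e.nid + 1, e.depth + 1});
        children.push_back(ToyEntry{2 * e.nid + 2, e.depth + 1});
      }
    }
    driver.Push(children);
  }
  std::cout << "expansion finished\n";
}
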
diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh
index 8d5cc809a280..08b0270ee4d7 100644
--- a/src/tree/gpu_hist/evaluate_splits.cuh
+++ b/src/tree/gpu_hist/evaluate_splits.cuh
@@ -103,7 +103,7 @@ class GPUHistEvaluator {
   }

   /**
-   * \brief Get sorted index storage based on the left node of inputs .
+   * \brief Get sorted index storage based on the left node of inputs.
    */
   auto SortedIdx(EvaluateSplitInputs left) {
     if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
diff --git a/src/tree/gpu_hist/histogram.cu b/src/tree/gpu_hist/histogram.cu
index 791363a05cdd..efb08d5e44e2 100644
--- a/src/tree/gpu_hist/histogram.cu
+++ b/src/tree/gpu_hist/histogram.cu
@@ -247,15 +247,6 @@ void BuildGradientHistogram(EllpackDeviceAccessor const& matrix,
   dh::safe_cuda(cudaGetLastError());
 }

-template void BuildGradientHistogram(
-    EllpackDeviceAccessor const& matrix,
-    FeatureGroupsAccessor const& feature_groups,
-    common::Span gpair,
-    common::Span ridx,
-    common::Span histogram,
-    HistRounding rounding,
-    bool force_global_memory);
-
 template void BuildGradientHistogram(
     EllpackDeviceAccessor const& matrix,
     FeatureGroupsAccessor const& feature_groups,
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index 88a2cfadf256..376798ca1577 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -179,10 +179,9 @@ class GloablApproxBuilder {
     p_last_tree_ = p_tree;
     this->InitData(p_fmat, hess);

-    Driver driver(static_cast(param_.grow_policy));
+    Driver driver(param_);
     auto &tree = *p_tree;
     driver.Push({this->InitRoot(p_fmat, gpair, hess, p_tree)});
-    bst_node_t num_leaves{1};
     auto expand_set = driver.Pop();

     /**
@@ -201,14 +200,9 @@
       // candidates that can be applied.
       std::vector applied;
       for (auto const &candidate : expand_set) {
-        if (!candidate.IsValid(param_, num_leaves)) {
-          continue;
-        }
         evaluator_.ApplyTreeSplit(candidate, p_tree);
         applied.push_back(candidate);
-        num_leaves++;
-        int left_child_nidx = tree[candidate.nid].LeftChild();
-        if (CPUExpandEntry::ChildIsValid(param_, p_tree->GetDepth(left_child_nidx), num_leaves)) {
+        if (driver.IsChildValid(candidate)) {
           valid_candidates.emplace_back(candidate);
         }
       }
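
Also not part of the patch: a rough host-side model of the cache-plus-overflow bookkeeping that the DeviceHistogramStorage rewrite below introduces. Only the map logic is modeled; the real histograms live in device_vector buffers and hold gradient sums, and ToyHistogramStorage is a made-up name.

// Illustrative only -- models which node histograms survive AllocateHistograms() calls.
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

class ToyHistogramStorage {
 public:
  ToyHistogramStorage(std::size_t histogram_size, std::size_t stop_growing)
      : histogram_size_{histogram_size}, stop_growing_{stop_growing} {}

  void AllocateHistograms(std::vector<int> const& new_nidxs) {
    std::size_t used = cached_.size() * histogram_size_;
    if (used >= stop_growing_) {
      // Overflow mode: the previous overflow batch is discarded and the new
      // nodes are laid out contiguously in a scratch buffer.
      overflow_.clear();
      for (int nidx : new_nidxs) {
        std::size_t offset = overflow_.size() * histogram_size_;
        overflow_[nidx] = offset;
      }
    } else {
      // Cache mode: keep appending; cached nodes are never evicted.
      for (int nidx : new_nidxs) {
        std::size_t offset = cached_.size() * histogram_size_;
        cached_[nidx] = offset;
      }
    }
  }
  bool HistogramExists(int nidx) const {
    return cached_.count(nidx) != 0 || overflow_.count(nidx) != 0;
  }

 private:
  std::size_t histogram_size_;
  std::size_t stop_growing_;
  std::map<int, std::size_t> cached_;    // long lived, mirrors nidx_map_
  std::map<int, std::size_t> overflow_;  // rewritten on every overflow allocation
};

int main() {
  ToyHistogramStorage storage{/*histogram_size=*/4, /*stop_growing=*/8};
  storage.AllocateHistograms({0, 1});  // cached: 2 * 4 == 8 items used after this call
  storage.AllocateHistograms({2, 3});  // cache is considered full, these go to overflow
  assert(storage.HistogramExists(0) && storage.HistogramExists(3));
  storage.AllocateHistograms({4});     // overflow recycled: 2 and 3 disappear
  assert(!storage.HistogramExists(2) && storage.HistogramExists(4));
  assert(storage.HistogramExists(0));  // cached entries survive
}
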
diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu
index 20db181ef187..88978142ee2e 100644
--- a/src/tree/updater_gpu_hist.cu
+++ b/src/tree/updater_gpu_hist.cu
@@ -62,7 +62,7 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
 #endif  // !defined(GTEST_TEST)

 /**
- * \struct  DeviceHistogram
+ * \struct  DeviceHistogramStorage
  *
  * \summary Data storage for node histograms on device. Automatically expands.
  *
@@ -72,20 +72,27 @@ DMLC_REGISTER_PARAMETER(GPUHistMakerTrainParam);
  * \author  Rory
  * \date    28/07/2018
  */
-template 
-class DeviceHistogram {
+template 
+class DeviceHistogramStorage {
  private:
   /*! \brief Map nidx to starting index of its histogram. */
   std::map nidx_map_;
+  // Large buffer of zeroed memory, caches histograms
   dh::device_vector data_;
+  // If we run out of storage allocate one histogram at a time
+  // in overflow. Not cached, overwritten when a new histogram
+  // is requested
+  dh::device_vector overflow_;
+  std::map overflow_nidx_map_;
   int n_bins_;
   int device_id_;
   static constexpr size_t kNumItemsInGradientSum =
       sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT);
-  static_assert(kNumItemsInGradientSum == 2,
-                "Number of items in gradient type should be 2.");
+  static_assert(kNumItemsInGradientSum == 2, "Number of items in gradient type should be 2.");

  public:
+  // Start with about 16mb
+  DeviceHistogramStorage() { data_.reserve(1 << 22); }
   void Init(int device_id, int n_bins) {
     this->n_bins_ = n_bins;
     this->device_id_ = device_id;
@@ -93,52 +100,47 @@
   void Reset() {
     auto d_data = data_.data().get();
-    dh::LaunchN(data_.size(),
-                [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
+    dh::LaunchN(data_.size(), [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
     nidx_map_.clear();
+    overflow_nidx_map_.clear();
   }
   bool HistogramExists(int nidx) const {
-    return nidx_map_.find(nidx) != nidx_map_.cend();
-  }
-  int Bins() const {
-    return n_bins_;
-  }
-  size_t HistogramSize() const {
-    return n_bins_ * kNumItemsInGradientSum;
-  }
-
-  dh::device_vector& Data() {
-    return data_;
+    return nidx_map_.find(nidx) != nidx_map_.cend() ||
+           overflow_nidx_map_.find(nidx) != overflow_nidx_map_.cend();
   }
+  int Bins() const { return n_bins_; }
+  size_t HistogramSize() const { return n_bins_ * kNumItemsInGradientSum; }
+  dh::device_vector& Data() { return data_; }

-  void AllocateHistogram(int nidx) {
-    if (HistogramExists(nidx)) return;
+  void AllocateHistograms(const std::vector& new_nidxs) {
+    for (int nidx : new_nidxs) {
+      CHECK(!HistogramExists(nidx));
+    }
     // Number of items currently used in data
     const size_t used_size = nidx_map_.size() * HistogramSize();
-    const size_t new_used_size = used_size + HistogramSize();
-    if (data_.size() >= kStopGrowingSize) {
-      // Recycle histogram memory
-      if (new_used_size <= data_.size()) {
-        // no need to remove old node, just insert the new one.
-        nidx_map_[nidx] = used_size;
-        // memset histogram size in bytes
-      } else {
-        std::pair old_entry = *nidx_map_.begin();
-        nidx_map_.erase(old_entry.first);
-        nidx_map_[nidx] = old_entry.second;
+    const size_t new_used_size = used_size + HistogramSize() * new_nidxs.size();
+    if (used_size >= kStopGrowingSize) {
+      // Use overflow
+      // Delete previous entries
+      overflow_nidx_map_.clear();
+      overflow_.resize(HistogramSize() * new_nidxs.size());
+      // Zero memory
+      auto d_data = overflow_.data().get();
+      dh::LaunchN(overflow_.size(),
+                  [=] __device__(size_t idx) { d_data[idx] = 0.0; });
+      // Append new histograms
+      for (int nidx : new_nidxs) {
+        overflow_nidx_map_[nidx] = overflow_nidx_map_.size() * HistogramSize();
       }
-      // Zero recycled memory
-      auto d_data = data_.data().get() + nidx_map_[nidx];
-      dh::LaunchN(n_bins_ * 2,
-                  [=] __device__(size_t idx) { d_data[idx] = 0.0f; });
     } else {
-      // Append new node histogram
-      nidx_map_[nidx] = used_size;
-      // Check there is enough memory for another histogram node
-      if (data_.size() < new_used_size + HistogramSize()) {
-        size_t new_required_memory =
-            std::max(data_.size() * 2, HistogramSize());
-        data_.resize(new_required_memory);
+      CHECK_GE(data_.size(), used_size);
+      // Expand if necessary
+      if (data_.size() < new_used_size) {
+        data_.resize(std::max(data_.size() * 2, new_used_size));
+      }
+      // Append new histograms
+      for (int nidx : new_nidxs) {
+        nidx_map_[nidx] = nidx_map_.size() * HistogramSize();
       }
     }

@@ -152,9 +154,16 @@
    */
   common::Span GetNodeHistogram(int nidx) {
     CHECK(this->HistogramExists(nidx));
-    auto ptr = data_.data().get() + nidx_map_.at(nidx);
-    return common::Span(
-        reinterpret_cast(ptr), n_bins_);
+
+    if (nidx_map_.find(nidx) != nidx_map_.cend()) {
+      // Fetch from normal cache
+      auto ptr = data_.data().get() + nidx_map_.at(nidx);
+      return common::Span(reinterpret_cast(ptr), n_bins_);
+    } else {
+      // Fetch from overflow
+      auto ptr = overflow_.data().get() + overflow_nidx_map_.at(nidx);
+      return common::Span(reinterpret_cast(ptr), n_bins_);
+    }
   }
 };

@@ -171,7 +180,7 @@ struct GPUHistMakerDevice {
   BatchParam batch_param;

   std::unique_ptr row_partitioner;
-  DeviceHistogram hist{};
+  DeviceHistogramStorage hist{};

   dh::caching_device_vector d_gpair;  // storage for gpair;
   common::Span gpair;
@@ -195,6 +204,7 @@ struct GPUHistMakerDevice {

   std::unique_ptr feature_groups;

+
   GPUHistMakerDevice(Context const* ctx, EllpackPageImpl const* _page,
                      common::Span _feature_types, bst_uint _n_rows,
                      TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features,
@@ -322,7 +332,6 @@ struct GPUHistMakerDevice {
   }

   void BuildHist(int nidx) {
-    hist.AllocateHistogram(nidx);
     auto d_node_hist = hist.GetNodeHistogram(nidx);
     auto d_ridx = row_partitioner->GetRows(nidx);
     BuildGradientHistogram(page->GetDeviceAccessor(ctx_->gpu_id),
@@ -330,8 +339,12 @@ struct GPUHistMakerDevice {
                            feature_groups->DeviceAccessor(ctx_->gpu_id), gpair,
                            d_ridx, d_node_hist, histogram_rounding);
   }

-  void SubtractionTrick(int nidx_parent, int nidx_histogram,
-                        int nidx_subtraction) {
+  // Attempt to do subtraction trick
+  // return true if succeeded
+  bool SubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
+    if (!hist.HistogramExists(nidx_histogram) || !hist.HistogramExists(nidx_parent)) {
+      return false;
+    }
     auto d_node_hist_parent = hist.GetNodeHistogram(nidx_parent);
     auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram);
     auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction);
@@ -340,12 +353,7 @@ struct GPUHistMakerDevice {
       d_node_hist_subtraction[idx] =
           d_node_hist_parent[idx] - d_node_hist_histogram[idx];
     });
-  }
-
-  bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) {
-    // Make sure histograms are already allocated
-    hist.AllocateHistogram(nidx_subtraction);
-    return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent);
+    return true;
   }

   void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) {
@@ -505,13 +513,15 @@ struct GPUHistMakerDevice {
     row_partitioner.reset();
   }

-  void AllReduceHist(int nidx, dh::AllReducer* reducer) {
+  // num histograms is the number of contiguous histograms in memory to reduce over
+  void AllReduceHist(int nidx, dh::AllReducer* reducer, int num_histograms) {
     monitor.Start("AllReduce");
     auto d_node_hist = hist.GetNodeHistogram(nidx).data();
-    reducer->AllReduceSum(
-        reinterpret_cast(d_node_hist),
-        reinterpret_cast(d_node_hist),
-        page->Cuts().TotalBins() * (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)));
+    reducer->AllReduceSum(reinterpret_cast(d_node_hist),
+                          reinterpret_cast(d_node_hist),
+                          page->Cuts().TotalBins() *
+                              (sizeof(GradientSumT) / sizeof(typename GradientSumT::ValueT)) *
+                              num_histograms);

     monitor.Stop("AllReduce");
   }
@@ -519,33 +529,50 @@ struct GPUHistMakerDevice {
   /**
    * \brief Build GPU local histograms for the left and right child of some parent node
    */
-  void BuildHistLeftRight(const GPUExpandEntry &candidate, int nidx_left,
-                          int nidx_right, dh::AllReducer* reducer) {
-    auto build_hist_nidx = nidx_left;
-    auto subtraction_trick_nidx = nidx_right;
-
-    // Decide whether to build the left histogram or right histogram
-    // Use sum of Hessian as a heuristic to select node with fewest training instances
-    bool fewer_right = candidate.split.right_sum.GetHess() < candidate.split.left_sum.GetHess();
-    if (fewer_right) {
-      std::swap(build_hist_nidx, subtraction_trick_nidx);
+  void BuildHistLeftRight(std::vector const& candidates, dh::AllReducer* reducer,
+                          const RegTree& tree) {
+    if (candidates.empty()) return;
+    // Some nodes we will manually compute histograms
+    // others we will do by subtraction
+    std::vector hist_nidx;
+    std::vector subtraction_nidx;
+    for (auto& e : candidates) {
+      // Decide whether to build the left histogram or right histogram
+      // Use sum of Hessian as a heuristic to select node with fewest training instances
+      bool fewer_right = e.split.right_sum.GetHess() < e.split.left_sum.GetHess();
+      if (fewer_right) {
+        hist_nidx.emplace_back(tree[e.nid].RightChild());
+        subtraction_nidx.emplace_back(tree[e.nid].LeftChild());
+      } else {
+        hist_nidx.emplace_back(tree[e.nid].LeftChild());
+        subtraction_nidx.emplace_back(tree[e.nid].RightChild());
+      }
+    }
+    std::vector all_new = hist_nidx;
+    all_new.insert(all_new.end(), subtraction_nidx.begin(), subtraction_nidx.end());
+    // Allocate the histograms
+    // Guaranteed contiguous memory
+    hist.AllocateHistograms(all_new);
+
+    for (auto nidx : hist_nidx) {
+      this->BuildHist(nidx);
     }

-    this->BuildHist(build_hist_nidx);
-    this->AllReduceHist(build_hist_nidx, reducer);
+    // Reduce all in one go
+    // This gives much better latency in a distributed setting
+    // when processing a large batch
+    this->AllReduceHist(hist_nidx.at(0), reducer, hist_nidx.size());

-    // Check whether we can use the subtraction trick to calculate the other
-    bool do_subtraction_trick = this->CanDoSubtractionTrick(
-        candidate.nid, build_hist_nidx, subtraction_trick_nidx);
+    for (int i = 0; i < subtraction_nidx.size(); i++) {
+      auto build_hist_nidx = hist_nidx.at(i);
+      auto subtraction_trick_nidx = subtraction_nidx.at(i);
+      auto parent_nidx = candidates.at(i).nid;

-    if (do_subtraction_trick) {
-      // Calculate other histogram using subtraction trick
-      this->SubtractionTrick(candidate.nid, build_hist_nidx,
-                             subtraction_trick_nidx);
-    } else {
-      // Calculate other histogram manually
-      this->BuildHist(subtraction_trick_nidx);
-      this->AllReduceHist(subtraction_trick_nidx, reducer);
+      if (!this->SubtractionTrick(parent_nidx, build_hist_nidx, subtraction_trick_nidx)) {
+        // Calculate other histogram manually
+        this->BuildHist(subtraction_trick_nidx);
+        this->AllReduceHist(subtraction_trick_nidx, reducer, 1);
+      }
     }
   }

@@ -605,8 +632,9 @@ struct GPUHistMakerDevice {
                        GradientPairPrecise{}, thrust::plus{});
     rabit::Allreduce(reinterpret_cast(&root_sum), 2);

+    hist.AllocateHistograms({kRootNIdx});
     this->BuildHist(kRootNIdx);
-    this->AllReduceHist(kRootNIdx, reducer);
+    this->AllReduceHist(kRootNIdx, reducer, 1);

     // Remember root stats
     node_sum_gradients[kRootNIdx] = root_sum;
@@ -624,7 +652,8 @@ struct GPUHistMakerDevice {
                   RegTree* p_tree, dh::AllReducer* reducer,
                   HostDeviceVector* p_out_position) {
     auto& tree = *p_tree;
-    Driver driver(static_cast(param.grow_policy));
+    // Process maximum 32 nodes at a time
+    Driver driver(param, 32);

     monitor.Start("Reset");
     this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_);
@@ -634,48 +663,44 @@ struct GPUHistMakerDevice {
     driver.Push({ this->InitRoot(p_tree, reducer) });
     monitor.Stop("InitRoot");

-    auto num_leaves = 1;
-
     // The set of leaves that can be expanded asynchronously
     auto expand_set = driver.Pop();
     while (!expand_set.empty()) {
-      auto new_candidates =
-          pinned.GetSpan(expand_set.size() * 2, GPUExpandEntry());
-
-      for (auto i = 0ull; i < expand_set.size(); i++) {
-        auto candidate = expand_set.at(i);
-        if (!candidate.IsValid(param, num_leaves)) {
-          continue;
-        }
+      for (auto& candidate : expand_set) {
         this->ApplySplit(candidate, p_tree);
+      }
+      // Get the candidates we are allowed to expand further
+      // e.g. We do not bother further processing nodes whose children are beyond max depth
+      std::vector filtered_expand_set;
+      std::copy_if(expand_set.begin(), expand_set.end(), std::back_inserter(filtered_expand_set),
+                   [&](const auto& e) { return driver.IsChildValid(e); });
+
+
+      auto new_candidates =
+          pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry());
+
+      for (const auto& e : filtered_expand_set) {
+        monitor.Start("UpdatePosition");
+        // Update position is only run when child is valid, instead of right after apply
+        // split (as in approx tree method). Hence we have the finalise position call
+        // in GPU Hist.
+        this->UpdatePosition(e, p_tree);
+        monitor.Stop("UpdatePosition");
+      }

-        num_leaves++;
+      monitor.Start("BuildHist");
+      this->BuildHistLeftRight(filtered_expand_set, reducer, tree);
+      monitor.Stop("BuildHist");

+      for (auto i = 0ull; i < filtered_expand_set.size(); i++) {
+        auto candidate = filtered_expand_set.at(i);
         int left_child_nidx = tree[candidate.nid].LeftChild();
         int right_child_nidx = tree[candidate.nid].RightChild();
-        // Only create child entries if needed
-        if (GPUExpandEntry::ChildIsValid(param, tree.GetDepth(left_child_nidx),
-                                         num_leaves)) {
-          monitor.Start("UpdatePosition");
-          // Update position is only run when child is valid, instead of right after apply
-          // split (as in approx tree method). Hense we have the finalise position call
-          // in GPU Hist.
-          this->UpdatePosition(candidate, p_tree);
-          monitor.Stop("UpdatePosition");
-
-          monitor.Start("BuildHist");
-          this->BuildHistLeftRight(candidate, left_child_nidx, right_child_nidx, reducer);
-          monitor.Stop("BuildHist");
-
-          monitor.Start("EvaluateSplits");
-          this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
-                                        new_candidates.subspan(i * 2, 2));
-          monitor.Stop("EvaluateSplits");
-        } else {
-          // Set default
-          new_candidates[i * 2] = GPUExpandEntry();
-          new_candidates[i * 2 + 1] = GPUExpandEntry();
-        }
+
+        monitor.Start("EvaluateSplits");
+        this->EvaluateLeftRightSplits(candidate, left_child_nidx, right_child_nidx, *p_tree,
+                                      new_candidates.subspan(i * 2, 2));
+        monitor.Stop("EvaluateSplits");
       }
       dh::DefaultStream().Sync();
       driver.Push(new_candidates.begin(), new_candidates.end());
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index af7dad37fe39..ba02983a428a 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -175,10 +175,9 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
                                             HostDeviceVector *p_out_position) {
   monitor_->Start(__func__);

-  Driver driver(static_cast(param_.grow_policy));
+  Driver driver(param_);
   driver.Push(this->InitRoot(p_fmat, p_tree, gpair_h));
   auto const &tree = *p_tree;
-  bst_node_t num_leaves{1};
   auto expand_set = driver.Pop();

   while (!expand_set.empty()) {
@@ -188,13 +187,9 @@ void QuantileHistMaker::Builder::ExpandTree(DMatrix *p_fmat, RegTree *p_tree,
     std::vector applied;
     int32_t depth = expand_set.front().depth + 1;
     for (auto const& candidate : expand_set) {
-      if (!candidate.IsValid(param_, num_leaves)) {
-        continue;
-      }
       evaluator_->ApplyTreeSplit(candidate, p_tree);
       applied.push_back(candidate);
-      num_leaves++;
-      if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) {
+      if (driver.IsChildValid(candidate)) {
         valid_candidates.emplace_back(candidate);
       }
     }
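
For reference, before the test changes: the per-bin arithmetic behind SubtractionTrick() above, reduced to plain host arrays. The bin values are invented; the point is only that the sibling histogram is the parent minus the explicitly built child, so it costs one pass over the bins instead of another scan of the training rows.

// Illustrative only -- the subtraction trick on four made-up bins.
#include <array>
#include <cassert>
#include <cstddef>

int main() {
  // One gradient sum per bin; by construction parent == left + right.
  std::array<double, 4> parent{10.0, 6.0, 4.0, 8.0};
  std::array<double, 4> left{7.0, 2.0, 1.0, 5.0};  // the histogram that was actually built
  std::array<double, 4> right{};                   // the sibling, recovered by subtraction
  for (std::size_t i = 0; i < parent.size(); ++i) {
    right[i] = parent[i] - left[i];
  }
  assert(right[0] == 3.0 && right[1] == 4.0 && right[2] == 3.0 && right[3] == 3.0);
}
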
diff --git a/tests/cpp/tree/gpu_hist/test_driver.cu b/tests/cpp/tree/gpu_hist/test_driver.cu
index d35f3510f628..8e7164e37bec 100644
--- a/tests/cpp/tree/gpu_hist/test_driver.cu
+++ b/tests/cpp/tree/gpu_hist/test_driver.cu
@@ -6,41 +6,58 @@ namespace xgboost {
 namespace tree {

 TEST(GpuHist, DriverDepthWise) {
-  Driver driver(TrainParam::kDepthWise);
+  TrainParam p;
+  p.InitAllowUnknown(Args{});
+  p.grow_policy = TrainParam::kDepthWise;
+  Driver driver(p, 2);
   EXPECT_TRUE(driver.Pop().empty());
   DeviceSplitCandidate split;
   split.loss_chg = 1.0f;
-  GPUExpandEntry root(0, 0, split, .0f, .0f, .0f);
+  split.left_sum = {0.0f, 1.0f};
+  split.right_sum = {0.0f, 1.0f};
+  GPUExpandEntry root(0, 0, split, 2.0f, 1.0f, 1.0f);
   driver.Push({root});
   EXPECT_EQ(driver.Pop().front().nid, 0);
-  driver.Push({GPUExpandEntry{1, 1, split, .0f, .0f, .0f}});
-  driver.Push({GPUExpandEntry{2, 1, split, .0f, .0f, .0f}});
-  driver.Push({GPUExpandEntry{3, 2, split, .0f, .0f, .0f}});
-  // Should return entries from level 1
+  driver.Push({GPUExpandEntry{1, 1, split, 2.0f, 1.0f, 1.0f}});
+  driver.Push({GPUExpandEntry{2, 1, split, 2.0f, 1.0f, 1.0f}});
+  driver.Push({GPUExpandEntry{3, 1, split, 2.0f, 1.0f, 1.0f}});
+  driver.Push({GPUExpandEntry{4, 2, split, 2.0f, 1.0f, 1.0f}});
+  // Should return 2 entries from level 1
+  // as we limited the driver to pop maximum 2 nodes
   auto res = driver.Pop();
   EXPECT_EQ(res.size(), 2);
   for (auto &e : res) {
     EXPECT_EQ(e.depth, 1);
   }
+
+  // Should now return 1 entry from level 1
+  res = driver.Pop();
+  EXPECT_EQ(res.size(), 1);
+  EXPECT_EQ(res.at(0).depth, 1);
+
   res = driver.Pop();
-  EXPECT_EQ(res[0].depth, 2);
+  EXPECT_EQ(res.at(0).depth, 2);
   EXPECT_TRUE(driver.Pop().empty());
 }

 TEST(GpuHist, DriverLossGuided) {
   DeviceSplitCandidate high_gain;
+  high_gain.left_sum = {0.0f, 1.0f};
+  high_gain.right_sum = {0.0f, 1.0f};
   high_gain.loss_chg = 5.0f;
-  DeviceSplitCandidate low_gain;
+  DeviceSplitCandidate low_gain = high_gain;
   low_gain.loss_chg = 1.0f;

-  Driver driver(TrainParam::kLossGuide);
+  TrainParam p;
+  p.grow_policy=TrainParam::kLossGuide;
+  Driver driver(p);
   EXPECT_TRUE(driver.Pop().empty());
-  GPUExpandEntry root(0, 0, high_gain, .0f, .0f, .0f);
+  GPUExpandEntry root(0, 0, high_gain, 2.0f, 1.0f, 1.0f );
   driver.Push({root});
   EXPECT_EQ(driver.Pop().front().nid, 0);

   // Select high gain first
-  driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
-  driver.Push({GPUExpandEntry{2, 2, high_gain, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}});
+  driver.Push({GPUExpandEntry{2, 2, high_gain, 2.0f, 1.0f, 1.0f}});
   auto res = driver.Pop();
   EXPECT_EQ(res.size(), 1);
   EXPECT_EQ(res[0].nid, 2);
@@ -49,8 +66,8 @@ TEST(GpuHist, DriverLossGuided) {
   EXPECT_EQ(res[0].nid, 1);

   // If equal gain, use nid
-  driver.Push({GPUExpandEntry{2, 1, low_gain, .0f, .0f, .0f}});
-  driver.Push({GPUExpandEntry{1, 1, low_gain, .0f, .0f, .0f}});
+  driver.Push({GPUExpandEntry{2, 1, low_gain, 2.0f, 1.0f, 1.0f}});
+  driver.Push({GPUExpandEntry{1, 1, low_gain, 2.0f, 1.0f, 1.0f}});
   res = driver.Pop();
   EXPECT_EQ(res[0].nid, 1);
   res = driver.Pop();
diff --git a/tests/cpp/tree/gpu_hist/test_histogram.cu b/tests/cpp/tree/gpu_hist/test_histogram.cu
index 3b543a48d7cc..75d97b681a61 100644
--- a/tests/cpp/tree/gpu_hist/test_histogram.cu
+++ b/tests/cpp/tree/gpu_hist/test_histogram.cu
@@ -95,7 +95,6 @@ TEST(Histogram, GPUDeterministic) {
   std::vector shm_sizes{48 * 1024, 64 * 1024, 160 * 1024};
   for (bool is_dense : is_dense_array) {
     for (int shm_size : shm_sizes) {
-      TestDeterministicHistogram(is_dense, shm_size);
       TestDeterministicHistogram(is_dense, shm_size);
     }
   }
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index b3c08736c996..e6069cdfdd4d 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -27,31 +27,40 @@ TEST(GpuHist, DeviceHistogram) {
   // Ensures that node allocates correctly after reaching `kStopGrowingSize`.
   dh::safe_cuda(cudaSetDevice(0));
   constexpr size_t kNBins = 128;
-  constexpr size_t kNNodes = 4;
+  constexpr int kNNodes = 4;
   constexpr size_t kStopGrowing = kNNodes * kNBins * 2u;
-  DeviceHistogram histogram;
+  DeviceHistogramStorage histogram;
   histogram.Init(0, kNBins);
-  for (size_t i = 0; i < kNNodes; ++i) {
-    histogram.AllocateHistogram(i);
+  for (int i = 0; i < kNNodes; ++i) {
+    histogram.AllocateHistograms({i});
   }
   histogram.Reset();
   ASSERT_EQ(histogram.Data().size(), kStopGrowing);

   // Use allocated memory but do not erase nidx_map.
-  for (size_t i = 0; i < kNNodes; ++i) {
-    histogram.AllocateHistogram(i);
+  for (int i = 0; i < kNNodes; ++i) {
+    histogram.AllocateHistograms({i});
   }
-  for (size_t i = 0; i < kNNodes; ++i) {
+  for (int i = 0; i < kNNodes; ++i) {
     ASSERT_TRUE(histogram.HistogramExists(i));
   }

-  // Erase existing nidx_map.
-  for (size_t i = kNNodes; i < kNNodes * 2; ++i) {
-    histogram.AllocateHistogram(i);
-  }
-  for (size_t i = 0; i < kNNodes; ++i) {
-    ASSERT_FALSE(histogram.HistogramExists(i));
+  // Add two new nodes
+  histogram.AllocateHistograms({kNNodes});
+  histogram.AllocateHistograms({kNNodes + 1});
+
+  // Old cached nodes should still exist
+  for (int i = 0; i < kNNodes; ++i) {
+    ASSERT_TRUE(histogram.HistogramExists(i));
   }
+
+  // Should be deleted
+  ASSERT_FALSE(histogram.HistogramExists(kNNodes));
+  // Most recent node should exist
+  ASSERT_TRUE(histogram.HistogramExists(kNNodes + 1));
+
+  // Add same node again - should fail
+  EXPECT_ANY_THROW(histogram.AllocateHistograms({kNNodes + 1}););
 }

 std::vector GetHostHistGpair() {
@@ -96,9 +105,9 @@ void TestBuildHist(bool use_shared_memory_histograms) {
   thrust::host_vector h_gidx_buffer (page->gidx_buffer.HostVector());

   maker.row_partitioner.reset(new RowPartitioner(0, kNRows));
-  maker.hist.AllocateHistogram(0);
+  maker.hist.AllocateHistograms({0});
   maker.gpair = gpair.DeviceSpan();
-  maker.histogram_rounding = CreateRoundingFactor(maker.gpair);;
+  maker.histogram_rounding = CreateRoundingFactor(maker.gpair);

   BuildGradientHistogram(
       page->GetDeviceAccessor(0), maker.feature_groups->DeviceAccessor(0),
       gpair.DeviceSpan(), maker.row_partitioner->GetRows(0),
       maker.hist.GetNodeHistogram(0), maker.histogram_rounding,
@@ -106,7 +115,7 @@ void TestBuildHist(bool use_shared_memory_histograms) {
       !use_shared_memory_histograms);

-  DeviceHistogram& d_hist = maker.hist;
+  DeviceHistogramStorage& d_hist = maker.hist;

   auto node_histogram = d_hist.GetNodeHistogram(0);
   // d_hist.data stored in float, not gradient pair
@@ -129,12 +138,10 @@ void TestBuildHist(bool use_shared_memory_histograms) {

 TEST(GpuHist, BuildHistGlobalMem) {
   TestBuildHist(false);
-  TestBuildHist(false);
 }

 TEST(GpuHist, BuildHistSharedMem) {
   TestBuildHist(true);
-  TestBuildHist(true);
 }

 HistogramCutsWrapper GetHostCutMatrix () {
@@ -198,7 +205,7 @@ TEST(GpuHist, EvaluateRootSplit) {

   // Initialize GPUHistMakerDevice::hist
   maker.hist.Init(0, (max_bins - 1) * kNCols);
-  maker.hist.AllocateHistogram(0);
+  maker.hist.AllocateHistograms({0});
   // Each row of hist_gpair represents gpairs for one feature.
   // Each entry represents a bin.
   std::vector hist_gpair = GetHostHistGpair();
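
A closing note, not part of the patch: because AllocateHistograms() lays the freshly built child histograms out contiguously, AllReduceHist() can cover a whole batch with a single reduction whose length is simply scaled by the number of histograms. A toy calculation of that length, with invented sizes:

// Illustrative only -- the length passed to the single batched AllReduceSum call.
#include <cstddef>
#include <iostream>

int main() {
  std::size_t total_bins = 256;     // page->Cuts().TotalBins() in the real code
  std::size_t values_per_bin = 2;   // gradient and hessian components per bin
  std::size_t num_histograms = 32;  // one histogram per node built in this batch
  std::size_t reduce_len = total_bins * values_per_bin * num_histograms;
  std::cout << "one reduction of " << reduce_len << " values instead of "
            << num_histograms << " separate reductions\n";
}
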