diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh
index 123dc14e57be..ccec859a286c 100644
--- a/src/common/device_helpers.cuh
+++ b/src/common/device_helpers.cuh
@@ -1939,4 +1939,25 @@ class CUDAStream {
   CUDAStreamView View() const { return CUDAStreamView{stream_}; }
   void Sync() { this->View().Sync(); }
 };
+
+// Force nvcc to load data as constant
+template <typename T>
+class LDGIterator {
+  using DeviceWordT = typename cub::UnitWord<T>::DeviceWord;
+  static constexpr std::size_t kNumWords = sizeof(T) / sizeof(DeviceWordT);
+
+  const T *ptr_;
+
+ public:
+  explicit LDGIterator(const T *ptr) : ptr_(ptr) {}
+  __device__ T operator[](std::size_t idx) const {
+    DeviceWordT tmp[kNumWords];
+    static_assert(sizeof(tmp) == sizeof(T), "Expect sizes to be equal.");
+#pragma unroll
+    for (int i = 0; i < kNumWords; i++) {
+      tmp[i] = __ldg(reinterpret_cast<const DeviceWordT *>(ptr_ + idx) + i);
+    }
+    return *reinterpret_cast<const T *>(tmp);
+  }
+};
 }  // namespace dh
diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu
index 44b962a96e5a..015d817f3640 100644
--- a/src/tree/gpu_hist/row_partitioner.cu
+++ b/src/tree/gpu_hist/row_partitioner.cu
@@ -1,174 +1,46 @@
 /*!
- * Copyright 2017-2021 XGBoost contributors
+ * Copyright 2017-2022 XGBoost contributors
  */
 #include
 #include
 #include
+
 #include
+
 #include "../../common/device_helpers.cuh"
 #include "row_partitioner.cuh"
 namespace xgboost {
 namespace tree {
-struct IndexFlagTuple {
-  size_t idx;
-  size_t flag;
-};
-
-struct IndexFlagOp {
-  __device__ IndexFlagTuple operator()(const IndexFlagTuple& a,
-                                       const IndexFlagTuple& b) const {
-    return {b.idx, a.flag + b.flag};
-  }
-};
-
-struct WriteResultsFunctor {
-  bst_node_t left_nidx;
-  common::Span position_in;
-  common::Span position_out;
-  common::Span ridx_in;
-  common::Span ridx_out;
-  int64_t* d_left_count;
-
-  __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) {
-    // the ex_scan_result represents how many rows have been assigned to left
-    // node so far during scan.
- int scatter_address; - if (position_in[x.idx] == left_nidx) { - scatter_address = x.flag - 1; // -1 because inclusive scan - } else { - // current number of rows belong to right node + total number of rows - // belong to left node - scatter_address = (x.idx - x.flag) + *d_left_count; - } - // copy the node id to output - position_out[scatter_address] = position_in[x.idx]; - ridx_out[scatter_address] = ridx_in[x.idx]; - - // Discard - return {}; - } -}; - -// Implement partitioning via single scan operation using transform output to -// write the result -void RowPartitioner::SortPosition(common::Span position, - common::Span position_out, - common::Span ridx, - common::Span ridx_out, - bst_node_t left_nidx, bst_node_t, - int64_t* d_left_count, cudaStream_t stream) { - WriteResultsFunctor write_results{left_nidx, position, position_out, - ridx, ridx_out, d_left_count}; - auto discard_write_iterator = - thrust::make_transform_output_iterator(dh::TypedDiscard(), write_results); - auto counting = thrust::make_counting_iterator(0llu); - auto input_iterator = dh::MakeTransformIterator( - counting, [=] __device__(size_t idx) { - return IndexFlagTuple{idx, static_cast(position[idx] == left_nidx)}; - }); - size_t temp_bytes = 0; - cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, - discard_write_iterator, IndexFlagOp(), - position.size(), stream); - dh::TemporaryArray temp(temp_bytes); - cub::DeviceScan::InclusiveScan(temp.data().get(), temp_bytes, input_iterator, - discard_write_iterator, IndexFlagOp(), - position.size(), stream); -} - -void Reset(int device_idx, common::Span ridx, - common::Span position) { - dh::safe_cuda(cudaSetDevice(device_idx)); - CHECK_EQ(ridx.size(), position.size()); - dh::LaunchN(ridx.size(), [=] __device__(size_t idx) { - ridx[idx] = idx; - position[idx] = 0; - }); -} RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) - : device_idx_(device_idx), ridx_a_(num_rows), position_a_(num_rows), - ridx_b_(num_rows), position_b_(num_rows) { + : device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) { dh::safe_cuda(cudaSetDevice(device_idx_)); - ridx_ = dh::DoubleBuffer{&ridx_a_, &ridx_b_}; - position_ = dh::DoubleBuffer{&position_a_, &position_b_}; - ridx_segments_.emplace_back(static_cast(0), num_rows); - - Reset(device_idx, ridx_.CurrentSpan(), position_.CurrentSpan()); - left_counts_.resize(256); - thrust::fill(left_counts_.begin(), left_counts_.end(), 0); - streams_.resize(2); - for (auto& stream : streams_) { - dh::safe_cuda(cudaStreamCreate(&stream)); - } + ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)}); + thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size()); + dh::safe_cuda(cudaStreamCreate(&stream_)); } + RowPartitioner::~RowPartitioner() { dh::safe_cuda(cudaSetDevice(device_idx_)); - for (auto& stream : streams_) { - dh::safe_cuda(cudaStreamDestroy(stream)); - } + dh::safe_cuda(cudaStreamDestroy(stream_)); } -common::Span RowPartitioner::GetRows( - bst_node_t nidx) { - auto segment = ridx_segments_.at(nidx); - // Return empty span here as a valid result - // Will error if we try to construct a span from a pointer with size 0 - if (segment.Size() == 0) { - return {}; - } - return ridx_.CurrentSpan().subspan(segment.begin, segment.Size()); +common::Span RowPartitioner::GetRows(bst_node_t nidx) { + auto segment = ridx_segments_.at(nidx).segment; + return dh::ToSpan(ridx_).subspan(segment.begin, segment.Size()); } common::Span RowPartitioner::GetRows() { - return ridx_.CurrentSpan(); + return 
dh::ToSpan(ridx_); } -common::Span RowPartitioner::GetPosition() { - return position_.CurrentSpan(); -} -std::vector RowPartitioner::GetRowsHost( - bst_node_t nidx) { +std::vector RowPartitioner::GetRowsHost(bst_node_t nidx) { auto span = GetRows(nidx); std::vector rows(span.size()); dh::CopyDeviceSpanToVector(&rows, span); return rows; } -std::vector RowPartitioner::GetPositionHost() { - auto span = GetPosition(); - std::vector position(span.size()); - dh::CopyDeviceSpanToVector(&position, span); - return position; -} - -void RowPartitioner::SortPositionAndCopy(const Segment& segment, - bst_node_t left_nidx, - bst_node_t right_nidx, - int64_t* d_left_count, - cudaStream_t stream) { - SortPosition( - // position_in - common::Span(position_.Current() + segment.begin, - segment.Size()), - // position_out - common::Span(position_.Other() + segment.begin, - segment.Size()), - // row index in - common::Span(ridx_.Current() + segment.begin, segment.Size()), - // row index out - common::Span(ridx_.Other() + segment.begin, segment.Size()), - left_nidx, right_nidx, d_left_count, stream); - // Copy back key/value - const auto d_position_current = position_.Current() + segment.begin; - const auto d_position_other = position_.Other() + segment.begin; - const auto d_ridx_current = ridx_.Current() + segment.begin; - const auto d_ridx_other = ridx_.Other() + segment.begin; - dh::LaunchN(segment.Size(), stream, [=] __device__(size_t idx) { - d_position_current[idx] = d_position_other[idx]; - d_ridx_current[idx] = d_ridx_other[idx]; - }); -} }; // namespace tree }; // namespace xgboost diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index f46fcfcd38a2..4ba0bd27fe2f 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -2,33 +2,193 @@ * Copyright 2017-2022 XGBoost contributors */ #pragma once +#include + #include #include -#include "xgboost/base.h" + #include "../../common/device_helpers.cuh" +#include "xgboost/base.h" #include "xgboost/generic_parameters.h" #include "xgboost/task.h" #include "xgboost/tree_model.h" namespace xgboost { namespace tree { -/*! \brief Count how many rows are assigned to left node. */ -__forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment) { -#if __CUDACC_VER_MAJOR__ > 8 - int mask = __activemask(); - unsigned ballot = __ballot_sync(mask, increment); - int leader = __ffs(mask) - 1; - if (threadIdx.x % 32 == leader) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT + +/** \brief Used to demarcate a contiguous set of row indices associated with + * some tree node. */ +struct Segment { + bst_uint begin{0}; + bst_uint end{0}; + + Segment() = default; + + Segment(bst_uint begin, bst_uint end) : begin(begin), end(end) { CHECK_GE(end, begin); } + __host__ __device__ size_t Size() const { return end - begin; } +}; + +// TODO(Rory): Can be larger. To be tuned alongside other batch operations. 
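+// Upper bound on the number of node segments handled by a single
+// UpdatePositionBatch call; AssignBatch below walks at most this many
+// per-node entries to map a global thread index to its (node, row) pair.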
+static const int kMaxUpdatePositionBatchSize = 32; +template +struct PerNodeData { + Segment segment; + OpDataT data; +}; + +template +__device__ __forceinline__ void AssignBatch(BatchIterT batch_info, std::size_t global_thread_idx, + int* batch_idx, std::size_t* item_idx) { + bst_uint sum = 0; + for (int i = 0; i < kMaxUpdatePositionBatchSize; i++) { + if (sum + batch_info[i].segment.Size() > global_thread_idx) { + *batch_idx = i; + *item_idx = (global_thread_idx - sum) + batch_info[i].segment.begin; + break; + } + sum += batch_info[i].segment.Size(); + } +} + +template +__global__ __launch_bounds__(kBlockSize) void SortPositionCopyKernel( + dh::LDGIterator> batch_info, common::Span d_ridx, + const common::Span ridx_tmp, std::size_t total_rows) { + for (auto idx : dh::GridStrideRange(0, total_rows)) { + int batch_idx; + std::size_t item_idx; + AssignBatch(batch_info, idx, &batch_idx, &item_idx); + d_ridx[item_idx] = ridx_tmp[item_idx]; + } +} + +// We can scan over this tuple, where the scan gives us information on how to partition inputs +// according to the flag +struct IndexFlagTuple { + bst_uint idx; // The location of the item we are working on in ridx_ + bst_uint flag_scan; // This gets populated after scanning + int batch_idx; // Which node in the batch does this item belong to + bool flag; // Result of op (is this item going left?) +}; + +struct IndexFlagOp { + __device__ IndexFlagTuple operator()(const IndexFlagTuple& a, const IndexFlagTuple& b) const { + // Segmented scan - resets if we cross batch boundaries + if (a.batch_idx == b.batch_idx) { + // Accumulate the flags, everything else stays the same + return {b.idx, a.flag_scan + b.flag_scan, b.batch_idx, b.flag}; + } else { + return b; + } + } +}; + +template +struct WriteResultsFunctor { + dh::LDGIterator> batch_info; + const bst_uint* ridx_in; + bst_uint* ridx_out; + bst_uint* counts; + + __device__ IndexFlagTuple operator()(const IndexFlagTuple& x) { + std::size_t scatter_address; + const Segment& segment = batch_info[x.batch_idx].segment; + if (x.flag) { + bst_uint num_previous_flagged = x.flag_scan - 1; // -1 because inclusive scan + scatter_address = segment.begin + num_previous_flagged; + } else { + bst_uint num_previous_unflagged = (x.idx - segment.begin) - x.flag_scan; + scatter_address = segment.end - num_previous_unflagged - 1; + } + ridx_out[scatter_address] = ridx_in[x.idx]; + + if (x.idx == (segment.end - 1)) { + // Write out counts + counts[x.batch_idx] = x.flag_scan; + } + + // Discard + return {}; + } +}; + +template +void SortPositionBatch(common::Span> d_batch_info, + common::Span ridx, common::Span ridx_tmp, + common::Span d_counts, std::size_t total_rows, OpT op, + dh::device_vector* tmp, cudaStream_t stream) { + dh::LDGIterator> batch_info_itr(d_batch_info.data()); + WriteResultsFunctor write_results{batch_info_itr, ridx.data(), ridx_tmp.data(), + d_counts.data()}; + + auto discard_write_iterator = + thrust::make_transform_output_iterator(dh::TypedDiscard(), write_results); + auto counting = thrust::make_counting_iterator(0llu); + auto input_iterator = + dh::MakeTransformIterator(counting, [=] __device__(size_t idx) { + int batch_idx; + std::size_t item_idx; + AssignBatch(batch_info_itr, idx, &batch_idx, &item_idx); + auto op_res = op(ridx[item_idx], batch_info_itr[batch_idx].data); + return IndexFlagTuple{bst_uint(item_idx), op_res, batch_idx, op_res}; + }); + size_t temp_bytes = 0; + if (tmp->empty()) { + cub::DeviceScan::InclusiveScan(nullptr, temp_bytes, input_iterator, discard_write_iterator, + 
IndexFlagOp(), total_rows, stream); + tmp->resize(temp_bytes); } -#else - unsigned ballot = __ballot(increment); - if (threadIdx.x % 32 == 0) { - atomicAdd(reinterpret_cast(d_count), // NOLINT - static_cast(__popc(ballot))); // NOLINT + temp_bytes = tmp->size(); + cub::DeviceScan::InclusiveScan(tmp->data().get(), temp_bytes, input_iterator, + discard_write_iterator, IndexFlagOp(), total_rows, stream); + + constexpr int kBlockSize = 256; + + // Value found by experimentation + const int kItemsThread = 12; + const int grid_size = xgboost::common::DivRoundUp(total_rows, kBlockSize * kItemsThread); + + SortPositionCopyKernel + <<>>(batch_info_itr, ridx, ridx_tmp, total_rows); +} + +struct NodePositionInfo { + Segment segment; + bst_node_t left_child = -1; + bst_node_t right_child = -1; + __device__ bool IsLeaf() { return left_child == -1; } +}; + +__device__ __forceinline__ int GetPositionFromSegments(std::size_t idx, + const NodePositionInfo* d_node_info) { + int position = 0; + NodePositionInfo node = d_node_info[position]; + while (!node.IsLeaf()) { + NodePositionInfo left = d_node_info[node.left_child]; + NodePositionInfo right = d_node_info[node.right_child]; + if (idx >= left.segment.begin && idx < left.segment.end) { + position = node.left_child; + node = left; + } else if (idx >= right.segment.begin && idx < right.segment.end) { + position = node.right_child; + node = right; + } else { + KERNEL_CHECK(false); + } + } + return position; +} + +template +__global__ __launch_bounds__(kBlockSize) void FinalisePositionKernel( + const common::Span d_node_info, + const common::Span d_ridx, common::Span d_out_position, OpT op) { + for (auto idx : dh::GridStrideRange(0, d_ridx.size())) { + auto position = GetPositionFromSegments(idx, d_node_info.data()); + RowIndexT ridx = d_ridx[idx]; + bst_node_t new_position = op(ridx, position); + d_out_position[ridx] = new_position; } -#endif } /** \brief Class responsible for tracking subsets of rows as we add splits and @@ -36,7 +196,6 @@ __forceinline__ __device__ void AtomicIncrement(int64_t* d_count, bool increment class RowPartitioner { public: using RowIndexT = bst_uint; - struct Segment; static constexpr bst_node_t kIgnoredTreePosition = -1; private: @@ -49,23 +208,20 @@ class RowPartitioner { * node id -> segment -> indices of rows belonging to node */ /*! \brief Range of row index for each node, pointers into ridx below. */ - std::vector ridx_segments_; - dh::TemporaryArray ridx_a_; - dh::TemporaryArray ridx_b_; - dh::TemporaryArray position_a_; - dh::TemporaryArray position_b_; + + std::vector ridx_segments_; /*! \brief mapping for node id -> rows. * This looks like: * node id | 1 | 2 | * rows idx | 3, 5, 1 | 13, 31 | */ - dh::DoubleBuffer ridx_; - /*! \brief mapping for row -> node id. */ - dh::DoubleBuffer position_; - dh::caching_device_vector - left_counts_; // Useful to keep a bunch of zeroed memory for sort position - std::vector streams_; + dh::TemporaryArray ridx_; + // Staging area for sorting ridx + dh::TemporaryArray ridx_tmp_; + dh::device_vector tmp_; dh::PinnedMemory pinned_; + dh::PinnedMemory pinned2_; + cudaStream_t stream_; public: RowPartitioner(int device_idx, size_t num_rows); @@ -83,73 +239,74 @@ class RowPartitioner { */ common::Span GetRows(); - /** - * \brief Gets the tree position of all training instances. 
- */ - common::Span GetPosition(); - /** * \brief Convenience method for testing */ std::vector GetRowsHost(bst_node_t nidx); - /** - * \brief Convenience method for testing - */ - std::vector GetPositionHost(); - /** * \brief Updates the tree position for set of training instances being split * into left and right child nodes. Accepts a user-defined lambda specifying * which branch each training instance should go down. * * \tparam UpdatePositionOpT - * \param nidx The index of the node being split. - * \param left_nidx The left child index. - * \param right_nidx The right child index. - * \param op Device lambda. Should provide the row index as an - * argument and return the new position for this training instance. + * \tparam OpDataT + * \param nidx The index of the nodes being split. + * \param left_nidx The left child indices. + * \param right_nidx The right child indices. + * \param op_data User-defined data provided as the second argument to op + * \param op Device lambda with the row index as the first argument and op_data as the + * second. Returns true if this training instance goes on the left partition. */ - template - void UpdatePosition(bst_node_t nidx, bst_node_t left_nidx, - bst_node_t right_nidx, UpdatePositionOpT op) { - Segment segment = ridx_segments_.at(nidx); // rows belongs to node nidx - auto d_ridx = ridx_.CurrentSpan(); - auto d_position = position_.CurrentSpan(); - if (left_counts_.size() <= static_cast(nidx)) { - left_counts_.resize((nidx * 2) + 1); - thrust::fill(left_counts_.begin(), left_counts_.end(), 0); + template + void UpdatePositionBatch(const std::vector& nidx, + const std::vector& left_nidx, + const std::vector& right_nidx, + const std::vector& op_data, UpdatePositionOpT op) { + if (nidx.empty()) return; + CHECK_EQ(nidx.size(), left_nidx.size()); + CHECK_EQ(nidx.size(), right_nidx.size()); + CHECK_EQ(nidx.size(), op_data.size()); + + auto h_batch_info = pinned2_.GetSpan>(nidx.size()); + dh::TemporaryArray> d_batch_info(nidx.size()); + + std::size_t total_rows = 0; + for (int i = 0; i < nidx.size(); i++) { + h_batch_info[i] = {ridx_segments_.at(nidx.at(i)).segment, op_data.at(i)}; + total_rows += ridx_segments_.at(nidx.at(i)).segment.Size(); + } + dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(), + h_batch_info.size() * sizeof(PerNodeData), + cudaMemcpyDefault, stream_)); + + // Temporary arrays + auto h_counts = pinned_.GetSpan(nidx.size(), 0); + dh::TemporaryArray d_counts(nidx.size(), 0); + + // Partition the rows according to the operator + SortPositionBatch( + dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts), + total_rows, op, &tmp_, stream_); + dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(), + cudaMemcpyDefault, stream_)); + // TODO(Rory): this synchronisation hurts performance a lot + // Future optimisation should find a way to skip this + dh::safe_cuda(cudaStreamSynchronize(stream_)); + + // Update segments + for (int i = 0; i < nidx.size(); i++) { + auto segment = ridx_segments_.at(nidx[i]).segment; + auto left_count = h_counts[i]; + CHECK_LE(left_count, segment.Size()); + ridx_segments_.resize(std::max(static_cast(ridx_segments_.size()), + std::max(left_nidx[i], right_nidx[i]) + 1)); + ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]}; + ridx_segments_[left_nidx[i]] = + NodePositionInfo{Segment(segment.begin, segment.begin + left_count)}; + ridx_segments_[right_nidx[i]] = + 
NodePositionInfo{Segment(segment.begin + left_count, segment.end)}; } - // Now we divide the row segment into left and right node. - - int64_t* d_left_count = left_counts_.data().get() + nidx; - // Launch 1 thread for each row - dh::LaunchN<1, 128>(segment.Size(), [segment, op, left_nidx, right_nidx, d_ridx, d_left_count, - d_position] __device__(size_t idx) { - // LaunchN starts from zero, so we restore the row index by adding segment.begin - idx += segment.begin; - RowIndexT ridx = d_ridx[idx]; - bst_node_t new_position = op(ridx); // new node id - KERNEL_CHECK(new_position == left_nidx || new_position == right_nidx); - AtomicIncrement(d_left_count, new_position == left_nidx); - d_position[idx] = new_position; - }); - // Overlap device to host memory copy (left_count) with sort - int64_t &left_count = pinned_.GetSpan(1)[0]; - dh::safe_cuda(cudaMemcpyAsync(&left_count, d_left_count, sizeof(int64_t), - cudaMemcpyDeviceToHost, streams_[0])); - - SortPositionAndCopy(segment, left_nidx, right_nidx, d_left_count, streams_[1]); - - dh::safe_cuda(cudaStreamSynchronize(streams_[0])); - CHECK_LE(left_count, segment.Size()); - CHECK_GE(left_count, 0); - ridx_segments_.resize(std::max(static_cast(ridx_segments_.size()), - std::max(left_nidx, right_nidx) + 1)); - ridx_segments_[left_nidx] = - Segment(segment.begin, segment.begin + left_count); - ridx_segments_[right_nidx] = - Segment(segment.begin + left_count, segment.end); } /** @@ -165,69 +322,21 @@ class RowPartitioner { * argument and return the new position for this training instance. * \param sampled A device lambda to inform the partitioner whether a row is sampled. */ - template - void FinalisePosition(Context const* ctx, ObjInfo task, - HostDeviceVector* p_out_position, FinalisePositionOpT op, - Sampledp sampledp) { - auto d_position = position_.Current(); - const auto d_ridx = ridx_.Current(); - if (!task.UpdateTreeLeaf()) { - dh::LaunchN(position_.Size(), [=] __device__(size_t idx) { - auto position = d_position[idx]; - RowIndexT ridx = d_ridx[idx]; - bst_node_t new_position = op(ridx, position); - if (new_position == kIgnoredTreePosition) { - return; - } - d_position[idx] = new_position; - }); - return; - } + template + void FinalisePosition(common::Span d_out_position, FinalisePositionOpT op) { + dh::TemporaryArray d_node_info_storage(ridx_segments_.size()); + dh::safe_cuda(cudaMemcpyAsync(d_node_info_storage.data().get(), ridx_segments_.data(), + sizeof(NodePositionInfo) * ridx_segments_.size(), + cudaMemcpyDefault, stream_)); - p_out_position->SetDevice(ctx->gpu_id); - p_out_position->Resize(position_.Size()); - auto sorted_position = p_out_position->DevicePointer(); - dh::LaunchN(position_.Size(), [=] __device__(size_t idx) { - auto position = d_position[idx]; - RowIndexT ridx = d_ridx[idx]; - bst_node_t new_position = op(ridx, position); - sorted_position[ridx] = sampledp(ridx) ? ~new_position : new_position; - if (new_position == kIgnoredTreePosition) { - return; - } - d_position[idx] = new_position; - }); + constexpr int kBlockSize = 512; + const int kItemsThread = 8; + const int grid_size = xgboost::common::DivRoundUp(ridx_.size(), kBlockSize * kItemsThread); + common::Span d_ridx(ridx_.data().get(), ridx_.size()); + FinalisePositionKernel<<>>( + dh::ToSpan(d_node_info_storage), d_ridx, d_out_position, op); } - - /** - * \brief Optimised routine for sorting key value pairs into left and right - * segments. Based on a single pass of exclusive scan, uses iterators to - * redirect inputs and outputs. 
- */ - void SortPosition(common::Span position, - common::Span position_out, - common::Span ridx, - common::Span ridx_out, bst_node_t left_nidx, - bst_node_t right_nidx, int64_t* d_left_count, - cudaStream_t stream = nullptr); - - /*! \brief Sort row indices according to position. */ - void SortPositionAndCopy(const Segment& segment, bst_node_t left_nidx, - bst_node_t right_nidx, int64_t* d_left_count, - cudaStream_t stream); - /** \brief Used to demarcate a contiguous set of row indices associated with - * some tree node. */ - struct Segment { - size_t begin { 0 }; - size_t end { 0 }; - - Segment() = default; - - Segment(size_t begin, size_t end) : begin(begin), end(end) { - CHECK_GE(end, begin); - } - size_t Size() const { return end - begin; } - }; }; + }; // namespace tree }; // namespace xgboost diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index ff5899b21607..5eaaeecbadf6 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -182,10 +182,11 @@ struct GPUHistMakerDevice { std::unique_ptr row_partitioner; DeviceHistogramStorage hist{}; - dh::caching_device_vector d_gpair; // storage for gpair; + dh::device_vector d_gpair; // storage for gpair; common::Span gpair; - dh::caching_device_vector monotone_constraints; + dh::device_vector monotone_constraints; + dh::device_vector update_predictions; /*! \brief Sum gradient for each node. */ std::vector node_sum_gradients; @@ -356,36 +357,49 @@ struct GPUHistMakerDevice { return true; } - void UpdatePosition(const GPUExpandEntry &e, RegTree* p_tree) { - RegTree::Node split_node = (*p_tree)[e.nid]; - auto split_type = p_tree->NodeSplitType(e.nid); - auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); - auto node_cats = e.split.split_cats.Bits(); + // Extra data for each node that is passed + // to the update position function + struct NodeSplitData { + RegTree::Node split_node; + FeatureType split_type; + common::CatBitField node_cats; + }; - row_partitioner->UpdatePosition( - e.nid, split_node.LeftChild(), split_node.RightChild(), - [=] __device__(bst_uint ridx) { + void UpdatePosition(const std::vector& candidates, RegTree* p_tree) { + if (candidates.empty()) return; + std::vector nidx(candidates.size()); + std::vector left_nidx(candidates.size()); + std::vector right_nidx(candidates.size()); + std::vector split_data(candidates.size()); + for (int i = 0; i < candidates.size(); i++) { + auto& e = candidates[i]; + RegTree::Node split_node = (*p_tree)[e.nid]; + auto split_type = p_tree->NodeSplitType(e.nid); + nidx.at(i) = e.nid; + left_nidx.at(i) = split_node.LeftChild(); + right_nidx.at(i) = split_node.RightChild(); + split_data.at(i) = NodeSplitData{split_node, split_type, e.split.split_cats}; + } + + auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); + row_partitioner->UpdatePositionBatch( + nidx, left_nidx, right_nidx, split_data, + [=] __device__(bst_uint ridx, const NodeSplitData& data) { // given a row index, returns the node id it belongs to - bst_float cut_value = - d_matrix.GetFvalue(ridx, split_node.SplitIndex()); + bst_float cut_value = d_matrix.GetFvalue(ridx, data.split_node.SplitIndex()); // Missing value - bst_node_t new_position = 0; + bool go_left = true; if (isnan(cut_value)) { - new_position = split_node.DefaultChild(); + go_left = data.split_node.DefaultLeft(); } else { - bool go_left = true; - if (split_type == FeatureType::kCategorical) { - go_left = common::Decision(node_cats, cut_value, split_node.DefaultLeft()); + if (data.split_type == 
FeatureType::kCategorical) { + go_left = common::Decision(data.node_cats.Bits(), cut_value, + data.split_node.DefaultLeft()); } else { - go_left = cut_value <= split_node.SplitCond(); - } - if (go_left) { - new_position = split_node.LeftChild(); - } else { - new_position = split_node.RightChild(); + go_left = cut_value <= data.split_node.SplitCond(); } } - return new_position; + return go_left; }); } @@ -394,6 +408,16 @@ struct GPUHistMakerDevice { // prediction cache void FinalisePosition(RegTree const* p_tree, DMatrix* p_fmat, ObjInfo task, HostDeviceVector* p_out_position) { + // Prediction cache will not be used with external memory + if (!p_fmat->SingleColBlock()) { + if (task.UpdateTreeLeaf()) { + LOG(FATAL) << "Current objective function can not be used with external memory."; + } + p_out_position->Resize(0); + update_predictions.clear(); + return; + } + dh::TemporaryArray d_nodes(p_tree->GetNodes().size()); dh::safe_cuda(cudaMemcpyAsync(d_nodes.data().get(), p_tree->GetNodes().data(), d_nodes.size() * sizeof(RegTree::Node), @@ -412,25 +436,9 @@ struct GPUHistMakerDevice { dh::CopyToD(categories_segments, &d_categories_segments); } - if (row_partitioner->GetRows().size() != p_fmat->Info().num_row_) { - row_partitioner.reset(); // Release the device memory first before reallocating - row_partitioner.reset(new RowPartitioner(ctx_->gpu_id, p_fmat->Info().num_row_)); - } - if (task.UpdateTreeLeaf() && !p_fmat->SingleColBlock() && param.subsample != 1.0) { - // see comment in the `FinalisePositionInPage`. - LOG(FATAL) << "Current objective function can not be used with subsampled external memory."; - } - if (page->n_rows == p_fmat->Info().num_row_) { - FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types), - dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task, - p_out_position); - } else { - for (auto const& batch : p_fmat->GetBatches(batch_param)) { - FinalisePositionInPage(batch.Impl(), dh::ToSpan(d_nodes), dh::ToSpan(d_split_types), - dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), task, - p_out_position); - } - } + FinalisePositionInPage(page, dh::ToSpan(d_nodes), dh::ToSpan(d_split_types), + dh::ToSpan(d_categories), dh::ToSpan(d_categories_segments), + p_out_position); } void FinalisePositionInPage(EllpackPageImpl const *page, @@ -438,79 +446,73 @@ struct GPUHistMakerDevice { common::Span d_feature_types, common::Span categories, common::Span categories_segments, - ObjInfo task, HostDeviceVector* p_out_position) { auto d_matrix = page->GetDeviceAccessor(ctx_->gpu_id); auto d_gpair = this->gpair; - row_partitioner->FinalisePosition( - ctx_, task, p_out_position, - [=] __device__(size_t row_id, int position) { - // What happens if user prune the tree? - if (!d_matrix.IsInRange(row_id)) { - return RowPartitioner::kIgnoredTreePosition; + update_predictions.resize(row_partitioner->GetRows().size()); + auto d_update_predictions = dh::ToSpan(update_predictions); + p_out_position->SetDevice(ctx_->gpu_id); + p_out_position->Resize(row_partitioner->GetRows().size()); + + auto new_position_op = [=] __device__(size_t row_id, int position) { + // What happens if user prune the tree? 
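+      // Rows not covered by this ELLPACK page simply report
+      // RowPartitioner::kIgnoredTreePosition.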
+ if (!d_matrix.IsInRange(row_id)) { + return RowPartitioner::kIgnoredTreePosition; + } + auto node = d_nodes[position]; + + while (!node.IsLeaf()) { + bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex()); + // Missing value + if (isnan(element)) { + position = node.DefaultChild(); + } else { + bool go_left = true; + if (common::IsCat(d_feature_types, position)) { + auto node_cats = categories.subspan(categories_segments[position].beg, + categories_segments[position].size); + go_left = common::Decision(node_cats, element, node.DefaultLeft()); + } else { + go_left = element <= node.SplitCond(); } - auto node = d_nodes[position]; - - while (!node.IsLeaf()) { - bst_float element = d_matrix.GetFvalue(row_id, node.SplitIndex()); - // Missing value - if (isnan(element)) { - position = node.DefaultChild(); - } else { - bool go_left = true; - if (common::IsCat(d_feature_types, position)) { - auto node_cats = - categories.subspan(categories_segments[position].beg, - categories_segments[position].size); - go_left = common::Decision(node_cats, element, node.DefaultLeft()); - } else { - go_left = element <= node.SplitCond(); - } - if (go_left) { - position = node.LeftChild(); - } else { - position = node.RightChild(); - } - } - node = d_nodes[position]; + if (go_left) { + position = node.LeftChild(); + } else { + position = node.RightChild(); } + } - return position; - }, - [d_gpair] __device__(size_t ridx) { - // FIXME(jiamingy): Doesn't work when sampling is used with external memory as - // the sampler compacts the gradient vector. - return d_gpair[ridx].GetHess() - .0f == 0.f; - }); + node = d_nodes[position]; + } + + d_update_predictions[row_id] = node.LeafValue(); + return position; + }; // NOLINT + + auto d_out_position = p_out_position->DeviceSpan(); + row_partitioner->FinalisePosition(d_out_position, new_position_op); + + dh::LaunchN(row_partitioner->GetRows().size(), [=] __device__(size_t idx) { + bst_node_t position = d_out_position[idx]; + d_update_predictions[idx] = d_nodes[position].LeafValue(); + bool is_row_sampled = d_gpair[idx].GetHess() - .0f == 0.f; + d_out_position[idx] = is_row_sampled ? 
~position : position; + }); } - void UpdatePredictionCache(linalg::VectorView out_preds_d, RegTree const* p_tree) { + bool UpdatePredictionCache(linalg::VectorView out_preds_d, RegTree const* p_tree) { + if (update_predictions.empty()) { + return false; + } CHECK(p_tree); dh::safe_cuda(cudaSetDevice(ctx_->gpu_id)); CHECK_EQ(out_preds_d.DeviceIdx(), ctx_->gpu_id); - auto d_ridx = row_partitioner->GetRows(); - - GPUTrainingParam param_d(param); - dh::TemporaryArray device_node_sum_gradients(node_sum_gradients.size()); - - dh::safe_cuda(cudaMemcpyAsync(device_node_sum_gradients.data().get(), node_sum_gradients.data(), - sizeof(GradientPairPrecise) * node_sum_gradients.size(), - cudaMemcpyHostToDevice)); - auto d_position = row_partitioner->GetPosition(); - auto d_node_sum_gradients = device_node_sum_gradients.data().get(); - auto tree_evaluator = evaluator_.GetEvaluator(); - - auto const& h_nodes = p_tree->GetNodes(); - dh::caching_device_vector nodes(h_nodes.size()); - dh::safe_cuda(cudaMemcpyAsync(nodes.data().get(), h_nodes.data(), - h_nodes.size() * sizeof(RegTree::Node), cudaMemcpyHostToDevice)); - auto d_nodes = dh::ToSpan(nodes); - dh::LaunchN(d_ridx.size(), [=] XGBOOST_DEVICE(size_t idx) mutable { - bst_node_t nidx = d_position[idx]; - auto weight = d_nodes[nidx].LeafValue(); - out_preds_d(d_ridx[idx]) += weight; + auto d_update_predictions = dh::ToSpan(update_predictions); + CHECK_EQ(out_preds_d.Size(), d_update_predictions.size()); + dh::LaunchN(out_preds_d.Size(), [=] XGBOOST_DEVICE(size_t idx) mutable { + out_preds_d(idx) += d_update_predictions[idx]; }); - row_partitioner.reset(); + return true; } // num histograms is the number of contiguous histograms in memory to reduce over @@ -684,14 +686,12 @@ struct GPUHistMakerDevice { auto new_candidates = pinned.GetSpan(filtered_expand_set.size() * 2, GPUExpandEntry()); - for (const auto& e : filtered_expand_set) { - monitor.Start("UpdatePosition"); - // Update position is only run when child is valid, instead of right after apply - // split (as in approx tree method). Hense we have the finalise position call - // in GPU Hist. - this->UpdatePosition(e, p_tree); - monitor.Stop("UpdatePosition"); - } + monitor.Start("UpdatePosition"); + // Update position is only run when child is valid, instead of right after apply + // split (as in approx tree method). Hense we have the finalise position call + // in GPU Hist. 
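+    // The full batch of valid candidates is now partitioned in a single
+    // UpdatePositionBatch call instead of one kernel launch per node.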
+ this->UpdatePosition(filtered_expand_set, p_tree); + monitor.Stop("UpdatePosition"); monitor.Start("BuildHist"); this->BuildHistLeftRight(filtered_expand_set, reducer, tree); @@ -844,9 +844,9 @@ class GPUHistMaker : public TreeUpdater { return false; } monitor_.Start("UpdatePredictionCache"); - maker->UpdatePredictionCache(p_out_preds, p_last_tree_); + bool result = maker->UpdatePredictionCache(p_out_preds, p_last_tree_); monitor_.Stop("UpdatePredictionCache"); - return true; + return result; } TrainParam param_; // NOLINT diff --git a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu index cf1710699b5e..520cc3cd0b81 100644 --- a/tests/cpp/tree/gpu_hist/test_row_partitioner.cu +++ b/tests/cpp/tree/gpu_hist/test_row_partitioner.cu @@ -19,49 +19,7 @@ namespace xgboost { namespace tree { -void TestSortPosition(const std::vector& position_in, int left_idx, - int right_idx) { - dh::safe_cuda(cudaSetDevice(0)); - std::vector left_count = { - std::count(position_in.begin(), position_in.end(), left_idx)}; - dh::caching_device_vector d_left_count = left_count; - dh::caching_device_vector position = position_in; - dh::caching_device_vector position_out(position.size()); - - dh::caching_device_vector ridx(position.size()); - thrust::sequence(ridx.begin(), ridx.end()); - dh::caching_device_vector ridx_out(ridx.size()); - RowPartitioner rp(0,10); - rp.SortPosition( - common::Span(position.data().get(), position.size()), - common::Span(position_out.data().get(), position_out.size()), - common::Span(ridx.data().get(), ridx.size()), - common::Span(ridx_out.data().get(), ridx_out.size()), left_idx, - right_idx, d_left_count.data().get(), nullptr); - thrust::host_vector position_result = position_out; - thrust::host_vector ridx_result = ridx_out; - - // Check position is sorted - EXPECT_TRUE(std::is_sorted(position_result.begin(), position_result.end())); - // Check row indices are sorted inside left and right segment - EXPECT_TRUE( - std::is_sorted(ridx_result.begin(), ridx_result.begin() + left_count[0])); - EXPECT_TRUE( - std::is_sorted(ridx_result.begin() + left_count[0], ridx_result.end())); - - // Check key value pairs are the same - for (auto i = 0ull; i < ridx_result.size(); i++) { - EXPECT_EQ(position_result[i], position_in[ridx_result[i]]); - } -} -TEST(GpuHist, SortPosition) { - TestSortPosition({1, 2, 1, 2, 1}, 1, 2); - TestSortPosition({1, 1, 1, 1}, 1, 2); - TestSortPosition({2, 2, 2, 2}, 1, 2); - TestSortPosition({1, 2, 1, 2, 3}, 1, 2); -} - -void TestUpdatePosition() { +void TestUpdatePositionBatch() { const int kNumRows = 10; RowPartitioner rp(0, kNumRows); auto rows = rp.GetRowsHost(0); @@ -69,16 +27,11 @@ void TestUpdatePosition() { for (auto i = 0ull; i < kNumRows; i++) { EXPECT_EQ(rows[i], i); } + std::vector extra_data = {0}; // Send the first five training instances to the right node // and the second 5 to the left node - rp.UpdatePosition(0, 1, 2, - [=] __device__(RowPartitioner::RowIndexT ridx) { - if (ridx > 4) { - return 1; - } - else { - return 2; - } + rp.UpdatePositionBatch({0}, {1}, {2}, extra_data, [=] __device__(RowPartitioner::RowIndexT ridx, int) { + return ridx > 4; }); rows = rp.GetRowsHost(1); for (auto r : rows) { @@ -90,88 +43,58 @@ void TestUpdatePosition() { } // Split the left node again - rp.UpdatePosition(1, 3, 4, [=]__device__(RowPartitioner::RowIndexT ridx) - { - if (ridx < 7) { - return 3 - ; - } - return 4; + rp.UpdatePositionBatch({1}, {3}, {4}, extra_data,[=] __device__(RowPartitioner::RowIndexT ridx, int) { 
+ return ridx < 7; }); EXPECT_EQ(rp.GetRows(3).size(), 2); EXPECT_EQ(rp.GetRows(4).size(), 3); - // Check position is as expected - EXPECT_EQ(rp.GetPositionHost(), std::vector({3,3,4,4,4,2,2,2,2,2})); } -TEST(RowPartitioner, Basic) { TestUpdatePosition(); } - -void TestFinalise() { - const int kNumRows = 10; - - ObjInfo task{ObjInfo::kRegression, false, false}; - HostDeviceVector position; - Context ctx; - ctx.gpu_id = 0; +TEST(RowPartitioner, Batch) { TestUpdatePositionBatch(); } - { - RowPartitioner rp(0, kNumRows); - rp.FinalisePosition( - &ctx, task, &position, - [=] __device__(RowPartitioner::RowIndexT ridx, int position) { return 7; }, - [] XGBOOST_DEVICE(size_t) { return false; }); +void TestSortPositionBatch(const std::vector& ridx_in, const std::vector& segments) { + thrust::device_vector ridx = ridx_in; + thrust::device_vector ridx_tmp(ridx_in.size()); + thrust::device_vector counts(segments.size()); - auto position = rp.GetPositionHost(); - for (auto p : position) { - EXPECT_EQ(p, 7); - } - } + auto op = [=] __device__(auto ridx, int data) { return ridx % 2 == 0; }; + std::vector op_data(segments.size()); + std::vector> h_batch_info(segments.size()); + dh::TemporaryArray> d_batch_info(segments.size()); - /** - * Test for sampling. - */ - dh::device_vector hess(kNumRows); - for (size_t i = 0; i < hess.size(); ++i) { - // removed rows, 0, 3, 6, 9 - if (i % 3 == 0) { - hess[i] = 0; - } else { - hess[i] = i; - } + std::size_t total_rows = 0; + for (int i = 0; i < segments.size(); i++) { + h_batch_info[i] = {segments.at(i), 0}; + total_rows += segments.at(i).Size(); } - - auto d_hess = dh::ToSpan(hess); - - RowPartitioner rp(0, kNumRows); - rp.FinalisePosition( - &ctx, task, &position, - [] __device__(RowPartitioner::RowIndexT ridx, bst_node_t position) { - return ridx % 2 == 0 ? 1 : 2; - }, - [d_hess] __device__(size_t ridx) { return d_hess[ridx] - 0.f == 0.f; }); - - auto const& h_position = position.ConstHostVector(); - for (size_t ridx = 0; ridx < h_position.size(); ++ridx) { - if (ridx % 3 == 0) { - ASSERT_LT(h_position[ridx], 0); - } else { - ASSERT_EQ(h_position[ridx], ridx % 2 == 0 ? 
1 : 2); - } + dh::safe_cuda(cudaMemcpyAsync(d_batch_info.data().get(), h_batch_info.data(), + h_batch_info.size() * sizeof(PerNodeData), cudaMemcpyDefault, + nullptr)); + dh::device_vector tmp; + SortPositionBatch(dh::ToSpan(d_batch_info), dh::ToSpan(ridx), + dh::ToSpan(ridx_tmp), dh::ToSpan(counts), + total_rows, op, &tmp, nullptr); + + auto op_without_data = [=] __device__(auto ridx) { return ridx % 2 == 0; }; + for (int i = 0; i < segments.size(); i++) { + auto begin = ridx.begin() + segments[i].begin; + auto end = ridx.begin() + segments[i].end; + bst_uint count = counts[i]; + auto left_partition_count = + thrust::count_if(thrust::device, begin, begin + count, op_without_data); + EXPECT_EQ(left_partition_count, count); + auto right_partition_count = + thrust::count_if(thrust::device, begin + count, end, op_without_data); + EXPECT_EQ(right_partition_count, 0); } } -TEST(RowPartitioner, Finalise) { TestFinalise(); } - -void TestIncorrectRow() { - RowPartitioner rp(0, 1); - rp.UpdatePosition(0, 1, 2, [=]__device__ (RowPartitioner::RowIndexT ridx) - { - return 4; // This is not the left branch or the right branch - }); +TEST(GpuHist, SortPositionBatch) { + TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 3}, {3, 6}}); + TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 1}, {3, 6}}); + TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{0, 6}}); + TestSortPositionBatch({0, 1, 2, 3, 4, 5}, {{3, 6}, {0, 2}}); } -TEST(RowPartitionerDeathTest, IncorrectRow) { - ASSERT_DEATH({ TestIncorrectRow(); },".*"); -} } // namespace tree } // namespace xgboost