Skip to content

Commit

Permalink
Initialise memory in case zero training rows.
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Jun 30, 2022
1 parent 3cd5e41 commit aad0d8e
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 10 deletions.
5 changes: 1 addition & 4 deletions src/tree/gpu_hist/row_partitioner.cu
Expand Up @@ -14,10 +14,7 @@ namespace xgboost {
namespace tree {

RowPartitioner::RowPartitioner(int device_idx, size_t num_rows)
: device_idx_(device_idx),
ridx_(num_rows),
ridx_tmp_(num_rows),
d_counts_(kMaxUpdatePositionBatchSize) {
: device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) {
dh::safe_cuda(cudaSetDevice(device_idx_));
ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)});
thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size());
Expand Down
10 changes: 4 additions & 6 deletions src/tree/gpu_hist/row_partitioner.cuh
Expand Up @@ -218,7 +218,6 @@ class RowPartitioner {
dh::TemporaryArray<RowIndexT> ridx_;
// Staging area for sorting ridx
dh::TemporaryArray<RowIndexT> ridx_tmp_;
dh::TemporaryArray<bst_uint> d_counts_;
dh::device_vector<int8_t> tmp_;
dh::PinnedMemory pinned_;
dh::PinnedMemory pinned2_;
Expand Down Expand Up @@ -283,13 +282,13 @@ class RowPartitioner {

// Temporary arrays
auto h_counts = pinned_.GetSpan<bst_uint>(nidx.size(), 0);
dh::TemporaryArray<bst_uint> d_counts(nidx.size(), 0);

// Partition the rows according to the operator
SortPositionBatch<RowIndexT, UpdatePositionOpT, OpDataT>(
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts_),
dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts),
total_rows, op, &tmp_, stream_);
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts_.data().get(),
sizeof(decltype(d_counts_)::value_type) * h_counts.size(),
dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(),
cudaMemcpyDefault, stream_));
// TODO(Rory): this synchronisation hurts performance a lot
// Future optimisation should find a way to skip this
Expand All @@ -299,8 +298,7 @@ class RowPartitioner {
for (int i = 0; i < nidx.size(); i++) {
auto segment = ridx_segments_.at(nidx[i]).segment;
auto left_count = h_counts[i];
CHECK_LE(left_count, segment.Size());
CHECK_GE(left_count, 0);
CHECK_LE(left_count, segment.Size()) << nidx[i];
ridx_segments_.resize(std::max(static_cast<bst_node_t>(ridx_segments_.size()),
std::max(left_nidx[i], right_nidx[i]) + 1));
ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]};
Expand Down

0 comments on commit aad0d8e

Please sign in to comment.