diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index 540c07a6fe64..015d817f3640 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -14,10 +14,7 @@ namespace xgboost { namespace tree { RowPartitioner::RowPartitioner(int device_idx, size_t num_rows) - : device_idx_(device_idx), - ridx_(num_rows), - ridx_tmp_(num_rows), - d_counts_(kMaxUpdatePositionBatchSize) { + : device_idx_(device_idx), ridx_(num_rows), ridx_tmp_(num_rows) { dh::safe_cuda(cudaSetDevice(device_idx_)); ridx_segments_.emplace_back(NodePositionInfo{Segment(0, num_rows)}); thrust::sequence(thrust::device, ridx_.data(), ridx_.data() + ridx_.size()); diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index e9fb7e86add7..4ba0bd27fe2f 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -218,7 +218,6 @@ class RowPartitioner { dh::TemporaryArray ridx_; // Staging area for sorting ridx dh::TemporaryArray ridx_tmp_; - dh::TemporaryArray d_counts_; dh::device_vector tmp_; dh::PinnedMemory pinned_; dh::PinnedMemory pinned2_; @@ -283,13 +282,13 @@ class RowPartitioner { // Temporary arrays auto h_counts = pinned_.GetSpan(nidx.size(), 0); + dh::TemporaryArray d_counts(nidx.size(), 0); // Partition the rows according to the operator SortPositionBatch( - dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts_), + dh::ToSpan(d_batch_info), dh::ToSpan(ridx_), dh::ToSpan(ridx_tmp_), dh::ToSpan(d_counts), total_rows, op, &tmp_, stream_); - dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts_.data().get(), - sizeof(decltype(d_counts_)::value_type) * h_counts.size(), + dh::safe_cuda(cudaMemcpyAsync(h_counts.data(), d_counts.data().get(), h_counts.size_bytes(), cudaMemcpyDefault, stream_)); // TODO(Rory): this synchronisation hurts performance a lot // Future optimisation should find a way to skip this @@ -300,7 +299,6 @@ class RowPartitioner { auto segment = ridx_segments_.at(nidx[i]).segment; auto left_count = h_counts[i]; CHECK_LE(left_count, segment.Size()); - CHECK_GE(left_count, 0); ridx_segments_.resize(std::max(static_cast(ridx_segments_.size()), std::max(left_nidx[i], right_nidx[i]) + 1)); ridx_segments_[nidx[i]] = NodePositionInfo{segment, left_nidx[i], right_nidx[i]};