diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index 6065f28e2147..5bb9129a4dac 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -279,7 +279,7 @@ class LaunchKernel { }; template -inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) { +inline void LaunchN(size_t n, cudaStream_t stream, L lambda) { if (n == 0) { return; } @@ -291,13 +291,13 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) { // Default stream version template -inline void LaunchN(int device_idx, size_t n, L lambda) { - LaunchN(device_idx, n, nullptr, lambda); +inline void LaunchN(size_t n, L lambda) { + LaunchN(n, nullptr, lambda); } template -void Iota(Container array, int32_t device = CurrentDevice()) { - LaunchN(device, array.size(), [=] __device__(size_t i) { array[i] = i; }); +void Iota(Container array) { + LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; }); } namespace detail { @@ -539,7 +539,7 @@ class TemporaryArray { int device = 0; dh::safe_cuda(cudaGetDevice(&device)); auto d_data = ptr_.get(); - LaunchN(device, this->size(), [=] __device__(size_t idx) { d_data[idx] = val; }); + LaunchN(this->size(), [=] __device__(size_t idx) { d_data[idx] = val; }); } thrust::device_ptr data() { return ptr_; } // NOLINT size_t size() { return size_; } // NOLINT diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index f85fea945275..1c0961a0e875 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -159,7 +159,7 @@ void RemoveDuplicatedCategories( auto d_new_cuts_size = dh::ToSpan(new_cuts_size); auto d_new_columns_ptr = dh::ToSpan(new_column_scan); CHECK_EQ(new_column_scan.size(), new_cuts_size.size()); - dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) { + dh::LaunchN(new_column_scan.size(), [=] __device__(size_t idx) { d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx]; if (idx == d_new_columns_ptr.size() - 1) { return; @@ -248,14 +248,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page, << "Must have at least 1 group for ranking."; CHECK_EQ(weights.size(), d_group_ptr.size() - 1) << "Weight size should equal to number of groups."; - dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { + dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { size_t element_idx = idx + begin; size_t ridx = dh::SegmentId(row_ptrs, element_idx); bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid); d_temp_weights[idx] = weights[group_idx]; }); } else { - dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) { + dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) { size_t element_idx = idx + begin; size_t ridx = dh::SegmentId(row_ptrs, element_idx); d_temp_weights[idx] = weights[ridx + base_rowid]; diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 898b198f8f2c..5f2e2add6bdc 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -41,7 +41,7 @@ void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feat dh::XGBCachingDeviceAllocator alloc; auto d_column_sizes_scan = column_sizes_scan->data().get(); - dh::LaunchN(device, end - begin, [=] __device__(size_t idx) { + dh::LaunchN(end - begin, [=] __device__(size_t idx) { auto e = batch_iter[begin + idx]; if (is_valid(e)) { atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast(1)); diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index f50819322bbb..8287cb24a1bd 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -93,9 +93,8 @@ class HostDeviceVectorImpl { gpu_access_ = GPUAccess::kWrite; SetDevice(); auto s_data = dh::ToSpan(*data_d_); - dh::LaunchN(device_, data_d_->size(), [=]XGBOOST_DEVICE(size_t i) { - s_data[i] = v; - }); + dh::LaunchN(data_d_->size(), + [=] XGBOOST_DEVICE(size_t i) { s_data[i] = v; }); } } diff --git a/src/common/quantile.cu b/src/common/quantile.cu index 607f8fa9fce0..e9f3e93d00ee 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -61,7 +61,7 @@ void PruneImpl(int device, Span feature_types, Span out_cuts, ToSketchEntry to_sketch_entry) { - dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) { + dh::LaunchN(out_cuts.size(), [=] __device__(size_t idx) { size_t column_id = dh::SegmentId(cuts_ptr, idx); auto out_column = out_cuts.subspan( cuts_ptr[column_id], cuts_ptr[column_id + 1] - cuts_ptr[column_id]); @@ -221,7 +221,7 @@ void MergeImpl(int32_t device, Span const &d_x, auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr); auto d_out = out; - dh::LaunchN(device, d_out.size(), [=] __device__(size_t idx) { + dh::LaunchN(d_out.size(), [=] __device__(size_t idx) { auto column_id = dh::SegmentId(out_ptr, idx); idx -= out_ptr[column_id]; @@ -487,7 +487,7 @@ void SketchContainer::FixError() { dh::safe_cuda(cudaSetDevice(device_)); auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan(); auto in = dh::ToSpan(this->Current()); - dh::LaunchN(device_, in.size(), [=] __device__(size_t idx) { + dh::LaunchN(in.size(), [=] __device__(size_t idx) { auto column_id = dh::SegmentId(d_columns_ptr, idx); auto in_column = in.subspan(d_columns_ptr[column_id], d_columns_ptr[column_id + 1] - @@ -627,7 +627,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { auto out_cut_values = p_cuts->cut_values_.DeviceSpan(); auto d_ft = feature_types_.ConstDeviceSpan(); - dh::LaunchN(0, total_bins, [=] __device__(size_t idx) { + dh::LaunchN(total_bins, [=] __device__(size_t idx) { auto column_id = dh::SegmentId(d_out_columns_ptr, idx); auto in_column = in_cut_values.subspan(d_in_columns_ptr[column_id], d_in_columns_ptr[column_id + 1] - diff --git a/src/common/ranking_utils.cuh b/src/common/ranking_utils.cuh index c9b71c154919..f63e38cba58c 100644 --- a/src/common/ranking_utils.cuh +++ b/src/common/ranking_utils.cuh @@ -44,7 +44,7 @@ SegmentedTrapezoidThreads(xgboost::common::Span group_ptr, CHECK_GE(group_ptr.size(), 1); CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size()); dh::LaunchN( - dh::CurrentDevice(), group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) { + group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) { if (idx == 0) { out_group_threads_ptr[0] = 0; return; diff --git a/src/data/data.cu b/src/data/data.cu index b0aba701d1f1..dffe19d668e2 100644 --- a/src/data/data.cu +++ b/src/data/data.cu @@ -29,7 +29,7 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector* out) { auto p_dst = thrust::device_pointer_cast(out->DevicePointer()); - dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) { + dh::LaunchN(column.num_rows, [=] __device__(size_t idx) { p_dst[idx] = column.GetElement(idx, 0); }); } @@ -49,10 +49,11 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector* out) { << "Expected integer for group info."; auto ptr_device = SetDeviceToPtr(column.data); + CHECK_EQ(ptr_device, dh::CurrentDevice()); dh::TemporaryArray temp(column.num_rows); auto d_tmp = temp.data(); - dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) { + dh::LaunchN(column.num_rows, [=] __device__(size_t idx) { d_tmp[idx] = column.GetElement(idx, 0); }); auto length = column.num_rows; @@ -73,8 +74,8 @@ void CopyQidImpl(ArrayInterface array_interface, dh::caching_device_vector flag(1); auto d_flag = dh::ToSpan(flag); auto d = SetDeviceToPtr(array_interface.data); - dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; }); - dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) { + dh::LaunchN(1, [=] __device__(size_t) { d_flag[0] = true; }); + dh::LaunchN(array_interface.num_rows - 1, [=] __device__(size_t i) { if (array_interface.GetElement(i, 0) > array_interface.GetElement(i + 1, 0)) { d_flag[0] = false; diff --git a/src/data/device_adapter.cuh b/src/data/device_adapter.cuh index ee94b3dec7bd..a772a064f9b7 100644 --- a/src/data/device_adapter.cuh +++ b/src/data/device_adapter.cuh @@ -216,7 +216,7 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span offset, int device_idx, float missing) { IsValidFunctor is_valid(missing); // Count elements per row - dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) { + dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { auto element = batch.GetElement(idx); if (is_valid(element)) { atomicAdd(reinterpret_cast( // NOLINT diff --git a/src/data/ellpack_page.cu b/src/data/ellpack_page.cu index ffae315cd026..3574f7d339b6 100644 --- a/src/data/ellpack_page.cu +++ b/src/data/ellpack_page.cu @@ -257,16 +257,16 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx, common::CompressedBufferWriter writer(device_accessor.NumSymbols()); auto d_compressed_buffer = dst->gidx_buffer.DevicePointer(); auto row_stride = dst->row_stride; - dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) { - auto writer_non_const = - writer; // For some reason this variable gets captured as const - size_t row_idx = idx / row_stride; - size_t row_offset = idx % row_stride; - if (row_offset >= row_counts[row_idx]) { - writer_non_const.AtomicWriteSymbol(d_compressed_buffer, - device_accessor.NullValue(), idx); - } - }); + dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) { + // For some reason this variable got captured as const + auto writer_non_const = writer; + size_t row_idx = idx / row_stride; + size_t row_offset = idx % row_stride; + if (row_offset >= row_counts[row_idx]) { + writer_non_const.AtomicWriteSymbol(d_compressed_buffer, + device_accessor.NullValue(), idx); + } + }); } template @@ -326,7 +326,7 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) { } gidx_buffer.SetDevice(device); page->gidx_buffer.SetDevice(device); - dh::LaunchN(device, num_elements, CopyPage(this, page, offset)); + dh::LaunchN(num_elements, CopyPage(this, page, offset)); monitor_.Stop("Copy"); return num_elements; } @@ -382,7 +382,7 @@ void EllpackPageImpl::Compact(int device, EllpackPageImpl* page, CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size()); gidx_buffer.SetDevice(device); page->gidx_buffer.SetDevice(device); - dh::LaunchN(device, page->n_rows, CompactPage(this, page, row_indexes)); + dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes)); monitor_.Stop("Compact"); } diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index ff58c6bad6f1..9b1db6f44054 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -19,7 +19,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, int device_idx, float missing) { IsValidFunctor is_valid(missing); // Count elements per row - dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) { + dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { auto element = batch.GetElement(idx); if (is_valid(element)) { atomicAdd(reinterpret_cast( // NOLINT diff --git a/src/gbm/gbtree.cu b/src/gbm/gbtree.cu index 33bca68c3e49..eca8f4dbd269 100644 --- a/src/gbm/gbtree.cu +++ b/src/gbm/gbtree.cu @@ -20,14 +20,13 @@ void GPUCopyGradient(HostDeviceVector const *in_gpair, auto v_in = VectorView{in, group_id}; out_gpair->Resize(v_in.Size()); auto d_out = out_gpair->DeviceSpan(); - dh::LaunchN(dh::CurrentDevice(), v_in.Size(), - [=] __device__(size_t i) { d_out[i] = v_in[i]; }); + dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in[i]; }); } void GPUDartPredictInc(common::Span out_predts, common::Span predts, float tree_w, size_t n_rows, bst_group_t n_groups, bst_group_t group) { - dh::LaunchN(dh::CurrentDevice(), n_rows, [=]XGBOOST_DEVICE(size_t ridx) { + dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) { const size_t offset = ridx * n_groups + group; out_predts[offset] += (predts[offset] * tree_w); }); @@ -37,7 +36,7 @@ void GPUDartInplacePredictInc(common::Span out_predts, common::Span predts, float tree_w, size_t n_rows, float base_score, bst_group_t n_groups, bst_group_t group) { - dh::LaunchN(dh::CurrentDevice(), n_rows, [=] XGBOOST_DEVICE(size_t ridx) { + dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) { const size_t offset = ridx * n_groups + group; out_predts[offset] += (predts[offset] - base_score) * tree_w; }); diff --git a/src/linear/updater_gpu_coordinate.cu b/src/linear/updater_gpu_coordinate.cu index 685ec85f9612..5d83cddb7bb9 100644 --- a/src/linear/updater_gpu_coordinate.cu +++ b/src/linear/updater_gpu_coordinate.cu @@ -193,7 +193,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT void UpdateBiasResidual(float dbias, int group_idx, int num_groups) { if (dbias == 0.0f) return; auto d_gpair = dh::ToSpan(gpair_); - dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) { + dh::LaunchN(num_row_, [=] __device__(size_t idx) { auto &g = d_gpair[idx * num_groups + group_idx]; g += GradientPair(g.GetHess() * dbias, 0); }); @@ -222,7 +222,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT common::Span d_gpair = dh::ToSpan(gpair_); common::Span d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]); size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx]; - dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) { + dh::LaunchN(col_size, [=] __device__(size_t idx) { auto entry = d_col[idx]; auto &g = d_gpair[entry.index * num_groups + group_idx]; g += GradientPair(g.GetHess() * dw * entry.fvalue, 0); diff --git a/src/metric/auc.cu b/src/metric/auc.cu index f333ecf14527..708b424f9203 100644 --- a/src/metric/auc.cu +++ b/src/metric/auc.cu @@ -118,12 +118,12 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, return thrust::make_pair(fp, tp); }; // NOLINT auto d_fptp = dh::ToSpan(cache->fptp); - dh::LaunchN(device, d_sorted_idx.size(), + dh::LaunchN(d_sorted_idx.size(), [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); }); dh::XGBDeviceAllocator alloc; auto d_unique_idx = dh::ToSpan(cache->unique_idx); - dh::Iota(d_unique_idx, device); + dh::Iota(d_unique_idx); auto uni_key = dh::MakeTransformIterator( thrust::make_counting_iterator(0), @@ -144,7 +144,7 @@ GPUBinaryAUC(common::Span predts, MetaInfo const &info, auto d_neg_pos = dh::ToSpan(cache->neg_pos); // scatter unique negaive/positive values // shift to right by 1 with initial value being 0 - dh::LaunchN(device, d_unique_idx.size(), [=] __device__(size_t i) { + dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) { if (d_unique_idx[i] == 0) { // first unique index is 0 assert(i == 0); d_neg_pos[0] = {0, 0}; @@ -183,7 +183,7 @@ void Transpose(common::Span in, common::Span out, size_t m, size_t n, int32_t device) { CHECK_EQ(in.size(), out.size()); CHECK_EQ(in.size(), m * n); - dh::LaunchN(device, in.size(), [=] __device__(size_t i) { + dh::LaunchN(in.size(), [=] __device__(size_t i) { size_t col = i / m; size_t row = i % m; size_t idx = row * n + col; @@ -255,9 +255,8 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info if (n_samples == 0) { dh::TemporaryArray resutls(n_classes * 4, 0.0f); auto d_results = dh::ToSpan(resutls); - dh::LaunchN(device, n_classes * 4, [=]__device__(size_t i) { - d_results[i] = 0.0f; - }); + dh::LaunchN(n_classes * 4, + [=] __device__(size_t i) { d_results[i] = 0.0f; }); auto local_area = d_results.subspan(0, n_classes); auto fp = d_results.subspan(n_classes, n_classes); auto tp = d_results.subspan(2 * n_classes, n_classes); @@ -273,9 +272,8 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info dh::TemporaryArray class_ptr(n_classes + 1, 0); auto d_class_ptr = dh::ToSpan(class_ptr); - dh::LaunchN(device, n_classes + 1, [=]__device__(size_t i) { - d_class_ptr[i] = i * n_samples; - }); + dh::LaunchN(n_classes + 1, + [=] __device__(size_t i) { d_class_ptr[i] = i * n_samples; }); // no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't // use transform iterator in sorting. auto d_sorted_idx = dh::ToSpan(cache->sorted_idx); @@ -301,7 +299,7 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info float tp = label * w; return thrust::make_pair(fp, tp); }; // NOLINT - dh::LaunchN(device, d_sorted_idx.size(), + dh::LaunchN(d_sorted_idx.size(), [=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); }); /** @@ -309,7 +307,7 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info */ dh::XGBDeviceAllocator alloc; auto d_unique_idx = dh::ToSpan(cache->unique_idx); - dh::Iota(d_unique_idx, device); + dh::Iota(d_unique_idx); auto uni_key = dh::MakeTransformIterator>( thrust::make_counting_iterator(0), [=] __device__(size_t i) { uint32_t class_id = i / n_samples; @@ -363,7 +361,7 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info auto d_neg_pos = dh::ToSpan(cache->neg_pos); // When dataset is not empty, each class must have at least 1 (unique) sample // prediction, so no need to handle special case. - dh::LaunchN(device, d_unique_idx.size(), [=]__device__(size_t i) { + dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) { if (d_unique_idx[i] % n_samples == 0) { // first unique index is 0 assert(d_unique_idx[i] % n_samples == 0); d_neg_pos[d_unique_idx[i]] = {0, 0}; // class_id * n_samples = i @@ -419,7 +417,7 @@ float GPUMultiClassAUCOVR(common::Span predts, MetaInfo const &info auto tp = d_results.subspan(2 * n_classes, n_classes); auto auc = d_results.subspan(3 * n_classes, n_classes); - dh::LaunchN(device, n_classes, [=] __device__(size_t c) { + dh::LaunchN(n_classes, [=] __device__(size_t c) { auc[c] = s_d_auc[c]; auto last = d_fptp[n_samples * c + (n_samples - 1)]; fp[c] = last.first; diff --git a/src/metric/rank_metric.cu b/src/metric/rank_metric.cu index 70e4a808a04e..834ca3078e0b 100644 --- a/src/metric/rank_metric.cu +++ b/src/metric/rank_metric.cu @@ -107,7 +107,7 @@ struct EvalPrecisionGpu { int device_id = -1; dh::safe_cuda(cudaGetDevice(&device_id)); // For each group item compute the aggregated precision - dh::LaunchN(device_id, nitems, nullptr, [=] __device__(uint32_t idx) { + dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { const auto group_idx = dgroup_idx[idx]; const auto group_begin = dgroups[group_idx]; const auto ridx = idx - group_begin; @@ -151,7 +151,7 @@ struct EvalNDCGGpu { dh::safe_cuda(cudaGetDevice(&device_id)); // For each group item compute the aggregated precision - dh::LaunchN(device_id, nitems, nullptr, [=] __device__(uint32_t idx) { + dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { const auto group_idx = dgroup_idx[idx]; const auto group_begin = dgroups[group_idx]; const auto ridx = idx - group_begin; @@ -185,7 +185,7 @@ struct EvalNDCGGpu { int device_id = -1; dh::safe_cuda(cudaGetDevice(&device_id)); // Compute the group's DCG and reduce it across all groups - dh::LaunchN(device_id, ngroups, nullptr, [=] __device__(uint32_t gidx) { + dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) { if (didcg[gidx] == 0.0f) { ddcg[gidx] = (ecfg.minus) ? 0.0f : 1.0f; } else { @@ -244,7 +244,7 @@ struct EvalMAPGpu { int device_id = -1; dh::safe_cuda(cudaGetDevice(&device_id)); // For each group item compute the aggregated precision - dh::LaunchN(device_id, nitems, nullptr, [=] __device__(uint32_t idx) { + dh::LaunchN(nitems, nullptr, [=] __device__(uint32_t idx) { if (DetermineNonTrivialLabelLambda(idx)) { const auto group_idx = dgroup_idx[idx]; const auto group_begin = dgroups[group_idx]; @@ -257,7 +257,7 @@ struct EvalMAPGpu { }); // Aggregate the group's item precisions - dh::LaunchN(device_id, ngroups, nullptr, [=] __device__(uint32_t gidx) { + dh::LaunchN(ngroups, nullptr, [=] __device__(uint32_t gidx) { auto nhits = dgroups[gidx + 1] ? dhits[dgroups[gidx + 1] - 1] : 0; if (nhits != 0) { dsumap[gidx] /= nhits; @@ -391,7 +391,7 @@ struct EvalAucPRGpu : public Metric { int device_id = -1; dh::safe_cuda(cudaGetDevice(&device_id)); // For each group item compute the aggregated precision - dh::LaunchN<1, 32>(device_id, ngroups, nullptr, [=] __device__(uint32_t gidx) { + dh::LaunchN<1, 32>(ngroups, nullptr, [=] __device__(uint32_t gidx) { // We need pos > 0 && neg > 0 if (dtotal_pos[gidx] <= 0.0 || dtotal_neg[gidx] <= 0.0) { atomicAdd(dauc_error, 1); diff --git a/src/objective/rank_obj.cu b/src/objective/rank_obj.cu index 1fa584930c07..164b60611ef3 100644 --- a/src/objective/rank_obj.cu +++ b/src/objective/rank_obj.cu @@ -672,7 +672,7 @@ class SortedLabelList : dh::SegmentSorter { int device_id = -1; dh::safe_cuda(cudaGetDevice(&device_id)); // For each instance in the group, compute the gradient pair concurrently - dh::LaunchN(device_id, niter, nullptr, [=] __device__(uint32_t idx) { + dh::LaunchN(niter, nullptr, [=] __device__(uint32_t idx) { // First, determine the group 'idx' belongs to uint32_t item_idx = idx % total_items; uint32_t group_idx = diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 5b91fc1bf508..fb9a10588c41 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -488,6 +488,7 @@ void ExtractPaths( dh::device_vector> *paths, DeviceModel *model, dh::device_vector *path_categories, int gpu_id) { + dh::safe_cuda(cudaSetDevice(gpu_id)); auto& device_model = *model; dh::caching_device_vector info(device_model.nodes.Size()); @@ -558,7 +559,7 @@ void ExtractPaths( auto d_model_categories = device_model.categories.DeviceSpan(); common::Span d_path_categories = dh::ToSpan(*path_categories); - dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) { + dh::LaunchN(info.size(), [=] __device__(size_t idx) { auto path_info = d_info[idx]; size_t tree_offset = d_tree_segments[path_info.tree_idx]; TreeView tree{0, path_info.tree_idx, d_nodes, @@ -856,7 +857,6 @@ class GPUPredictor : public xgboost::Predictor { const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan(); float base_score = model.learner_model_param->base_score; dh::LaunchN( - generic_param_->gpu_id, p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, [=] __device__(size_t idx) { phis[(idx + 1) * contributions_columns - 1] += @@ -917,7 +917,6 @@ class GPUPredictor : public xgboost::Predictor { float base_score = model.learner_model_param->base_score; size_t n_features = model.learner_model_param->num_feature; dh::LaunchN( - generic_param_->gpu_id, p_fmat->Info().num_row_ * model.learner_model_param->num_output_group, [=] __device__(size_t idx) { size_t group = idx % ngroup; diff --git a/src/tree/gpu_hist/row_partitioner.cu b/src/tree/gpu_hist/row_partitioner.cu index cf9e3f769c5b..ebb3666e05cb 100644 --- a/src/tree/gpu_hist/row_partitioner.cu +++ b/src/tree/gpu_hist/row_partitioner.cu @@ -96,7 +96,7 @@ void RowPartitioner::SortPosition(common::Span position, void Reset(int device_idx, common::Span ridx, common::Span position) { CHECK_EQ(ridx.size(), position.size()); - dh::LaunchN(device_idx, ridx.size(), [=] __device__(size_t idx) { + dh::LaunchN(ridx.size(), [=] __device__(size_t idx) { ridx[idx] = idx; position[idx] = 0; }); @@ -131,7 +131,7 @@ common::Span RowPartitioner::GetRows( // Return empty span here as a valid result // Will error if we try to construct a span from a pointer with size 0 if (segment.Size() == 0) { - return common::Span(); + return {}; } return ridx_.CurrentSpan().subspan(segment.begin, segment.Size()); } @@ -180,7 +180,7 @@ void RowPartitioner::SortPositionAndCopy(const Segment& segment, const auto d_position_other = position_.Other() + segment.begin; const auto d_ridx_current = ridx_.Current() + segment.begin; const auto d_ridx_other = ridx_.Other() + segment.begin; - dh::LaunchN(device_idx_, segment.Size(), stream, [=] __device__(size_t idx) { + dh::LaunchN(segment.Size(), stream, [=] __device__(size_t idx) { d_position_current[idx] = d_position_other[idx]; d_ridx_current[idx] = d_ridx_other[idx]; }); diff --git a/src/tree/gpu_hist/row_partitioner.cuh b/src/tree/gpu_hist/row_partitioner.cuh index 96f327fb9698..c236b90090b6 100644 --- a/src/tree/gpu_hist/row_partitioner.cuh +++ b/src/tree/gpu_hist/row_partitioner.cuh @@ -120,7 +120,7 @@ class RowPartitioner { int64_t* d_left_count = left_counts_.data().get() + nidx; // Launch 1 thread for each row - dh::LaunchN<1, 128>(device_idx_, segment.Size(), [=] __device__(size_t idx) { + dh::LaunchN<1, 128>(segment.Size(), [=] __device__(size_t idx) { // LaunchN starts from zero, so we restore the row index by adding segment.begin idx += segment.begin; RowIndexT ridx = d_ridx[idx]; @@ -160,7 +160,7 @@ class RowPartitioner { void FinalisePosition(FinalisePositionOpT op) { auto d_position = position_.Current(); const auto d_ridx = ridx_.Current(); - dh::LaunchN(device_idx_, position_.Size(), [=] __device__(size_t idx) { + dh::LaunchN(position_.Size(), [=] __device__(size_t idx) { auto position = d_position[idx]; RowIndexT ridx = d_ridx[idx]; bst_node_t new_position = op(ridx, position); diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 563930f74544..b1bf0af56d0e 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -94,8 +94,8 @@ class DeviceHistogram { void Reset() { auto d_data = data_.data().get(); - dh::LaunchN(device_id_, data_.size(), - [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); + dh::LaunchN(data_.size(), + [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); nidx_map_.clear(); } bool HistogramExists(int nidx) const { @@ -130,7 +130,7 @@ class DeviceHistogram { } // Zero recycled memory auto d_data = data_.data().get() + nidx_map_[nidx]; - dh::LaunchN(device_id_, n_bins_ * 2, + dh::LaunchN(n_bins_ * 2, [=] __device__(size_t idx) { d_data[idx] = 0.0f; }); } else { // Append new node histogram @@ -367,7 +367,7 @@ struct GPUHistMakerDevice { dh::TemporaryArray entries(2); auto evaluator = tree_evaluator.GetEvaluator(); auto d_entries = entries.data().get(); - dh::LaunchN(device_id, 2, [=] __device__(size_t idx) { + dh::LaunchN(2, [=] __device__(size_t idx) { auto split = d_splits_out[idx]; auto nidx = idx == 0 ? left_nidx : right_nidx; @@ -402,7 +402,7 @@ struct GPUHistMakerDevice { auto d_node_hist_histogram = hist.GetNodeHistogram(nidx_histogram); auto d_node_hist_subtraction = hist.GetNodeHistogram(nidx_subtraction); - dh::LaunchN(device_id, page->Cuts().TotalBins(), [=] __device__(size_t idx) { + dh::LaunchN(page->Cuts().TotalBins(), [=] __device__(size_t idx) { d_node_hist_subtraction[idx] = d_node_hist_parent[idx] - d_node_hist_histogram[idx]; }); @@ -545,7 +545,7 @@ struct GPUHistMakerDevice { auto d_node_sum_gradients = device_node_sum_gradients.data().get(); auto evaluator = tree_evaluator.GetEvaluator(); - dh::LaunchN(device_id, d_ridx.size(), [=] __device__(int local_idx) { + dh::LaunchN(d_ridx.size(), [=] __device__(int local_idx) { int pos = d_position[local_idx]; bst_float weight = evaluator.CalcWeight( pos, param_d, GradStats{d_node_sum_gradients[pos]}); @@ -676,7 +676,7 @@ struct GPUHistMakerDevice { auto evaluator = tree_evaluator.GetEvaluator(); GPUTrainingParam gpu_param(param); auto depth = p_tree->GetDepth(kRootNIdx); - dh::LaunchN(device_id, 1, [=] __device__(size_t idx) { + dh::LaunchN(1, [=] __device__(size_t idx) { float left_weight = evaluator.CalcWeight(kRootNIdx, gpu_param, GradStats{split.left_sum}); float right_weight = evaluator.CalcWeight( diff --git a/tests/cpp/common/test_device_helpers.cu b/tests/cpp/common/test_device_helpers.cu index 09349be14843..cb7176c00758 100644 --- a/tests/cpp/common/test_device_helpers.cu +++ b/tests/cpp/common/test_device_helpers.cu @@ -18,8 +18,8 @@ void TestAtomicSizeT() { size_t constexpr kThreads = 235; dh::device_vector out(1, 0); auto d_out = dh::ToSpan(out); - dh::LaunchN(0, kThreads, [=]__device__(size_t idx){ - atomicAdd(&d_out[0], static_cast(1)); + dh::LaunchN(kThreads, [=] __device__(size_t idx) { + atomicAdd(&d_out[0], static_cast(1)); }); ASSERT_EQ(out[0], kThreads); } @@ -32,7 +32,7 @@ void TestSegmentID() { std::vector segments{0, 1, 3}; thrust::device_vector d_segments(segments); auto s_segments = dh::ToSpan(d_segments); - dh::LaunchN(0, 1, [=]__device__(size_t idx) { + dh::LaunchN(1, [=]__device__(size_t idx) { auto id = dh::SegmentId(s_segments, 0); SPAN_CHECK(id == 0); id = dh::SegmentId(s_segments, 1); diff --git a/tests/cpp/common/test_gpu_compressed_iterator.cu b/tests/cpp/common/test_gpu_compressed_iterator.cu index d9eb5c9a3bbb..779202a62002 100644 --- a/tests/cpp/common/test_gpu_compressed_iterator.cu +++ b/tests/cpp/common/test_gpu_compressed_iterator.cu @@ -53,14 +53,14 @@ TEST(CompressedIterator, TestGPU) { // write the data on device auto input_data_d = input_d.data().get(); auto buffer_data_d = buffer_d.data().get(); - dh::LaunchN(0, input_d.size(), - WriteSymbolFunction(cbw, buffer_data_d, input_data_d)); + dh::LaunchN(input_d.size(), + WriteSymbolFunction(cbw, buffer_data_d, input_data_d)); // read the data on device CompressedIterator ci(buffer_d.data().get(), alphabet_size); thrust::device_vector output_d(input.size()); auto output_data_d = output_d.data().get(); - dh::LaunchN(0, output_d.size(), ReadSymbolFunction(ci, output_data_d)); + dh::LaunchN(output_d.size(), ReadSymbolFunction(ci, output_data_d)); std::vector output(output_d.size()); thrust::copy(output_d.begin(), output_d.end(), output.begin()); diff --git a/tests/cpp/common/test_span.cu b/tests/cpp/common/test_span.cu index 901c640233e0..539a9beb1833 100644 --- a/tests/cpp/common/test_span.cu +++ b/tests/cpp/common/test_span.cu @@ -91,14 +91,14 @@ TEST(GPUSpan, FromOther) { TEST(GPUSpan, Assignment) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestAssignment{status.Data()}); + dh::LaunchN(16, TestAssignment{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpan, TestStatus) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestTestStatus{status.Data()}); + dh::LaunchN(16, TestTestStatus{status.Data()}); ASSERT_EQ(status.Get(), -1); } @@ -143,7 +143,7 @@ TEST(GPUSpan, WithTrust) { thrust::copy(thrust::device, d_vec.begin(), d_vec.end(), d_vec1.begin()); Span s (d_vec1.data().get(), d_vec.size()); - dh::LaunchN(0, 16, TestEqual{ + dh::LaunchN(16, TestEqual{ thrust::raw_pointer_cast(d_vec1.data()), s.data(), status.Data()}); ASSERT_EQ(status.Get(), 1); @@ -158,14 +158,14 @@ TEST(GPUSpan, WithTrust) { TEST(GPUSpan, BeginEnd) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestBeginEnd{status.Data()}); + dh::LaunchN(16, TestBeginEnd{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpan, RBeginREnd) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestRBeginREnd{status.Data()}); + dh::LaunchN(16, TestRBeginREnd{status.Data()}); ASSERT_EQ(status.Get(), 1); } @@ -197,14 +197,14 @@ TEST(GPUSpan, Modify) { TEST(GPUSpan, Observers) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestObservers{status.Data()}); + dh::LaunchN(16, TestObservers{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpan, Compare) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestIterCompare{status.Data()}); + dh::LaunchN(16, TestIterCompare{status.Data()}); ASSERT_EQ(status.Get(), 1); } @@ -231,7 +231,7 @@ TEST(GPUSpanDeathTest, ElementAccess) { thrust::copy(h_vec.begin(), h_vec.end(), d_vec.begin()); Span span (d_vec.data().get(), d_vec.size()); - dh::LaunchN(0, 17, TestElementAccess{span}); + dh::LaunchN(17, TestElementAccess{span}); }; testing::internal::CaptureStdout(); @@ -387,42 +387,42 @@ TEST(GPUSpan, Subspan) { TEST(GPUSpanIter, Construct) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestIterConstruct{status.Data()}); + dh::LaunchN(16, TestIterConstruct{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpanIter, Ref) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestIterRef{status.Data()}); + dh::LaunchN(16, TestIterRef{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpanIter, Calculate) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestIterCalculate{status.Data()}); + dh::LaunchN(16, TestIterCalculate{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpanIter, Compare) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestIterCompare{status.Data()}); + dh::LaunchN(16, TestIterCompare{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpan, AsBytes) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestAsBytes{status.Data()}); + dh::LaunchN(16, TestAsBytes{status.Data()}); ASSERT_EQ(status.Get(), 1); } TEST(GPUSpan, AsWritableBytes) { dh::safe_cuda(cudaSetDevice(0)); TestStatus status; - dh::LaunchN(0, 16, TestAsWritableBytes{status.Data()}); + dh::LaunchN(16, TestAsWritableBytes{status.Data()}); ASSERT_EQ(status.Get(), 1); } diff --git a/tests/cpp/data/test_device_adapter.cu b/tests/cpp/data/test_device_adapter.cu index 34c8e93b7822..f62b3dd80d03 100644 --- a/tests/cpp/data/test_device_adapter.cu +++ b/tests/cpp/data/test_device_adapter.cu @@ -33,7 +33,7 @@ void TestCudfAdapter() EXPECT_EQ(batch.Size(), kRowsA + kRowsB); EXPECT_NO_THROW({ - dh::LaunchN(0, batch.Size(), [=] __device__(size_t idx) { + dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { auto element = batch.GetElement(idx); KERNEL_CHECK(element.row_idx == idx / 2); if (idx % 2 == 0) { diff --git a/tests/cpp/data/test_ellpack_page.cu b/tests/cpp/data/test_ellpack_page.cu index 2ea89331a2e3..6b5b708147e2 100644 --- a/tests/cpp/data/test_ellpack_page.cu +++ b/tests/cpp/data/test_ellpack_page.cu @@ -166,10 +166,10 @@ TEST(EllpackPage, Copy) { EXPECT_EQ(impl->base_rowid, current_row); for (size_t i = 0; i < impl->Size(); i++) { - dh::LaunchN(0, kCols, ReadRowFunction(impl->GetDeviceAccessor(0), current_row, row_d.data().get())); + dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(0), current_row, row_d.data().get())); thrust::copy(row_d.begin(), row_d.end(), row.begin()); - dh::LaunchN(0, kCols, ReadRowFunction(result.GetDeviceAccessor(0), current_row, row_result_d.data().get())); + dh::LaunchN(kCols, ReadRowFunction(result.GetDeviceAccessor(0), current_row, row_result_d.data().get())); thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin()); EXPECT_EQ(row, row_result); @@ -221,12 +221,14 @@ TEST(EllpackPage, Compact) { continue; } - dh::LaunchN(0, kCols, ReadRowFunction(impl->GetDeviceAccessor(0), current_row, row_d.data().get())); - dh::safe_cuda (cudaDeviceSynchronize()); + dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(0), + current_row, row_d.data().get())); + dh::safe_cuda(cudaDeviceSynchronize()); thrust::copy(row_d.begin(), row_d.end(), row.begin()); - dh::LaunchN(0, kCols, - ReadRowFunction(result.GetDeviceAccessor(0), compacted_row, row_result_d.data().get())); + dh::LaunchN(kCols, + ReadRowFunction(result.GetDeviceAccessor(0), compacted_row, + row_result_d.data().get())); thrust::copy(row_result_d.begin(), row_result_d.end(), row_result.begin()); EXPECT_EQ(row, row_result); diff --git a/tests/cpp/data/test_sparse_page_dmatrix.cu b/tests/cpp/data/test_sparse_page_dmatrix.cu index 29ac8904016c..26058a8366c5 100644 --- a/tests/cpp/data/test_sparse_page_dmatrix.cu +++ b/tests/cpp/data/test_sparse_page_dmatrix.cu @@ -124,10 +124,10 @@ TEST(SparsePageDMatrix, MultipleEllpackPageContent) { EXPECT_EQ(impl_ext->base_rowid, current_row); for (size_t i = 0; i < impl_ext->Size(); i++) { - dh::LaunchN(0, kCols, ReadRowFunction(impl->GetDeviceAccessor(0), current_row, row_d.data().get())); + dh::LaunchN(kCols, ReadRowFunction(impl->GetDeviceAccessor(0), current_row, row_d.data().get())); thrust::copy(row_d.begin(), row_d.end(), row.begin()); - dh::LaunchN(0, kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(0), current_row, row_ext_d.data().get())); + dh::LaunchN(kCols, ReadRowFunction(impl_ext->GetDeviceAccessor(0), current_row, row_ext_d.data().get())); thrust::copy(row_ext_d.begin(), row_ext_d.end(), row_ext.begin()); EXPECT_EQ(row, row_ext);