Skip to content

Commit

Permalink
Remove use of device_idx in dh::LaunchN. (#7063)
Browse files Browse the repository at this point in the history
It's an unused parameter, removing it can make the CI log more readable.
  • Loading branch information
trivialfis committed Jun 29, 2021
1 parent dd4db34 commit 1c8fdf2
Show file tree
Hide file tree
Showing 25 changed files with 105 additions and 107 deletions.
12 changes: 6 additions & 6 deletions src/common/device_helpers.cuh
Expand Up @@ -279,7 +279,7 @@ class LaunchKernel {
};

template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {
inline void LaunchN(size_t n, cudaStream_t stream, L lambda) {
if (n == 0) {
return;
}
Expand All @@ -291,13 +291,13 @@ inline void LaunchN(int device_idx, size_t n, cudaStream_t stream, L lambda) {

// Default stream version
template <int ITEMS_PER_THREAD = 8, int BLOCK_THREADS = 256, typename L>
inline void LaunchN(int device_idx, size_t n, L lambda) {
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(device_idx, n, nullptr, lambda);
inline void LaunchN(size_t n, L lambda) {
LaunchN<ITEMS_PER_THREAD, BLOCK_THREADS>(n, nullptr, lambda);
}

template <typename Container>
void Iota(Container array, int32_t device = CurrentDevice()) {
LaunchN(device, array.size(), [=] __device__(size_t i) { array[i] = i; });
void Iota(Container array) {
LaunchN(array.size(), [=] __device__(size_t i) { array[i] = i; });
}

namespace detail {
Expand Down Expand Up @@ -539,7 +539,7 @@ class TemporaryArray {
int device = 0;
dh::safe_cuda(cudaGetDevice(&device));
auto d_data = ptr_.get();
LaunchN(device, this->size(), [=] __device__(size_t idx) { d_data[idx] = val; });
LaunchN(this->size(), [=] __device__(size_t idx) { d_data[idx] = val; });
}
thrust::device_ptr<T> data() { return ptr_; } // NOLINT
size_t size() { return size_; } // NOLINT
Expand Down
6 changes: 3 additions & 3 deletions src/common/hist_util.cu
Expand Up @@ -159,7 +159,7 @@ void RemoveDuplicatedCategories(
auto d_new_cuts_size = dh::ToSpan(new_cuts_size);
auto d_new_columns_ptr = dh::ToSpan(new_column_scan);
CHECK_EQ(new_column_scan.size(), new_cuts_size.size());
dh::LaunchN(device, new_column_scan.size(), [=] __device__(size_t idx) {
dh::LaunchN(new_column_scan.size(), [=] __device__(size_t idx) {
d_old_column_sizes_scan[idx] = d_new_columns_ptr[idx];
if (idx == d_new_columns_ptr.size() - 1) {
return;
Expand Down Expand Up @@ -248,14 +248,14 @@ void ProcessWeightedBatch(int device, const SparsePage& page,
<< "Must have at least 1 group for ranking.";
CHECK_EQ(weights.size(), d_group_ptr.size() - 1)
<< "Weight size should equal to number of groups.";
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
bst_group_t group_idx = dh::SegmentId(d_group_ptr, ridx + base_rowid);
d_temp_weights[idx] = weights[group_idx];
});
} else {
dh::LaunchN(device, temp_weights.size(), [=] __device__(size_t idx) {
dh::LaunchN(temp_weights.size(), [=] __device__(size_t idx) {
size_t element_idx = idx + begin;
size_t ridx = dh::SegmentId(row_ptrs, element_idx);
d_temp_weights[idx] = weights[ridx + base_rowid];
Expand Down
2 changes: 1 addition & 1 deletion src/common/hist_util.cuh
Expand Up @@ -41,7 +41,7 @@ void GetColumnSizesScan(int device, size_t num_columns, size_t num_cuts_per_feat

dh::XGBCachingDeviceAllocator<char> alloc;
auto d_column_sizes_scan = column_sizes_scan->data().get();
dh::LaunchN(device, end - begin, [=] __device__(size_t idx) {
dh::LaunchN(end - begin, [=] __device__(size_t idx) {
auto e = batch_iter[begin + idx];
if (is_valid(e)) {
atomicAdd(&d_column_sizes_scan[e.column_idx], static_cast<size_t>(1));
Expand Down
5 changes: 2 additions & 3 deletions src/common/host_device_vector.cu
Expand Up @@ -93,9 +93,8 @@ class HostDeviceVectorImpl {
gpu_access_ = GPUAccess::kWrite;
SetDevice();
auto s_data = dh::ToSpan(*data_d_);
dh::LaunchN(device_, data_d_->size(), [=]XGBOOST_DEVICE(size_t i) {
s_data[i] = v;
});
dh::LaunchN(data_d_->size(),
[=] XGBOOST_DEVICE(size_t i) { s_data[i] = v; });
}
}

Expand Down
8 changes: 4 additions & 4 deletions src/common/quantile.cu
Expand Up @@ -61,7 +61,7 @@ void PruneImpl(int device,
Span<FeatureType const> feature_types,
Span<SketchEntry> out_cuts,
ToSketchEntry to_sketch_entry) {
dh::LaunchN(device, out_cuts.size(), [=] __device__(size_t idx) {
dh::LaunchN(out_cuts.size(), [=] __device__(size_t idx) {
size_t column_id = dh::SegmentId(cuts_ptr, idx);
auto out_column = out_cuts.subspan(
cuts_ptr[column_id], cuts_ptr[column_id + 1] - cuts_ptr[column_id]);
Expand Down Expand Up @@ -221,7 +221,7 @@ void MergeImpl(int32_t device, Span<SketchEntry const> const &d_x,
auto d_merge_path = MergePath(d_x, x_ptr, d_y, y_ptr, out, out_ptr);
auto d_out = out;

dh::LaunchN(device, d_out.size(), [=] __device__(size_t idx) {
dh::LaunchN(d_out.size(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(out_ptr, idx);
idx -= out_ptr[column_id];

Expand Down Expand Up @@ -487,7 +487,7 @@ void SketchContainer::FixError() {
dh::safe_cuda(cudaSetDevice(device_));
auto d_columns_ptr = this->columns_ptr_.ConstDeviceSpan();
auto in = dh::ToSpan(this->Current());
dh::LaunchN(device_, in.size(), [=] __device__(size_t idx) {
dh::LaunchN(in.size(), [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_columns_ptr, idx);
auto in_column = in.subspan(d_columns_ptr[column_id],
d_columns_ptr[column_id + 1] -
Expand Down Expand Up @@ -627,7 +627,7 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) {
auto out_cut_values = p_cuts->cut_values_.DeviceSpan();
auto d_ft = feature_types_.ConstDeviceSpan();

dh::LaunchN(0, total_bins, [=] __device__(size_t idx) {
dh::LaunchN(total_bins, [=] __device__(size_t idx) {
auto column_id = dh::SegmentId(d_out_columns_ptr, idx);
auto in_column = in_cut_values.subspan(d_in_columns_ptr[column_id],
d_in_columns_ptr[column_id + 1] -
Expand Down
2 changes: 1 addition & 1 deletion src/common/ranking_utils.cuh
Expand Up @@ -44,7 +44,7 @@ SegmentedTrapezoidThreads(xgboost::common::Span<U> group_ptr,
CHECK_GE(group_ptr.size(), 1);
CHECK_EQ(group_ptr.size(), out_group_threads_ptr.size());
dh::LaunchN(
dh::CurrentDevice(), group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) {
group_ptr.size(), [=] XGBOOST_DEVICE(size_t idx) {
if (idx == 0) {
out_group_threads_ptr[0] = 0;
return;
Expand Down
9 changes: 5 additions & 4 deletions src/data/data.cu
Expand Up @@ -29,7 +29,7 @@ void CopyInfoImpl(ArrayInterface column, HostDeviceVector<float>* out) {

auto p_dst = thrust::device_pointer_cast(out->DevicePointer());

dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
dh::LaunchN(column.num_rows, [=] __device__(size_t idx) {
p_dst[idx] = column.GetElement(idx, 0);
});
}
Expand All @@ -49,10 +49,11 @@ void CopyGroupInfoImpl(ArrayInterface column, std::vector<bst_group_t>* out) {
<< "Expected integer for group info.";

auto ptr_device = SetDeviceToPtr(column.data);
CHECK_EQ(ptr_device, dh::CurrentDevice());
dh::TemporaryArray<bst_group_t> temp(column.num_rows);
auto d_tmp = temp.data();

dh::LaunchN(ptr_device, column.num_rows, [=] __device__(size_t idx) {
dh::LaunchN(column.num_rows, [=] __device__(size_t idx) {
d_tmp[idx] = column.GetElement<size_t>(idx, 0);
});
auto length = column.num_rows;
Expand All @@ -73,8 +74,8 @@ void CopyQidImpl(ArrayInterface array_interface,
dh::caching_device_vector<bool> flag(1);
auto d_flag = dh::ToSpan(flag);
auto d = SetDeviceToPtr(array_interface.data);
dh::LaunchN(d, 1, [=] __device__(size_t) { d_flag[0] = true; });
dh::LaunchN(d, array_interface.num_rows - 1, [=] __device__(size_t i) {
dh::LaunchN(1, [=] __device__(size_t) { d_flag[0] = true; });
dh::LaunchN(array_interface.num_rows - 1, [=] __device__(size_t i) {
if (array_interface.GetElement<uint32_t>(i, 0) >
array_interface.GetElement<uint32_t>(i + 1, 0)) {
d_flag[0] = false;
Expand Down
2 changes: 1 addition & 1 deletion src/data/device_adapter.cuh
Expand Up @@ -216,7 +216,7 @@ size_t GetRowCounts(const AdapterBatchT batch, common::Span<size_t> offset,
int device_idx, float missing) {
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
auto element = batch.GetElement(idx);
if (is_valid(element)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
Expand Down
24 changes: 12 additions & 12 deletions src/data/ellpack_page.cu
Expand Up @@ -257,16 +257,16 @@ void WriteNullValues(EllpackPageImpl* dst, int device_idx,
common::CompressedBufferWriter writer(device_accessor.NumSymbols());
auto d_compressed_buffer = dst->gidx_buffer.DevicePointer();
auto row_stride = dst->row_stride;
dh::LaunchN(device_idx, row_stride * dst->n_rows, [=] __device__(size_t idx) {
auto writer_non_const =
writer; // For some reason this variable gets captured as const
size_t row_idx = idx / row_stride;
size_t row_offset = idx % row_stride;
if (row_offset >= row_counts[row_idx]) {
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
device_accessor.NullValue(), idx);
}
});
dh::LaunchN(row_stride * dst->n_rows, [=] __device__(size_t idx) {
// For some reason this variable got captured as const
auto writer_non_const = writer;
size_t row_idx = idx / row_stride;
size_t row_offset = idx % row_stride;
if (row_offset >= row_counts[row_idx]) {
writer_non_const.AtomicWriteSymbol(d_compressed_buffer,
device_accessor.NullValue(), idx);
}
});
}

template <typename AdapterBatch>
Expand Down Expand Up @@ -326,7 +326,7 @@ size_t EllpackPageImpl::Copy(int device, EllpackPageImpl* page, size_t offset) {
}
gidx_buffer.SetDevice(device);
page->gidx_buffer.SetDevice(device);
dh::LaunchN(device, num_elements, CopyPage(this, page, offset));
dh::LaunchN(num_elements, CopyPage(this, page, offset));
monitor_.Stop("Copy");
return num_elements;
}
Expand Down Expand Up @@ -382,7 +382,7 @@ void EllpackPageImpl::Compact(int device, EllpackPageImpl* page,
CHECK_LE(page->base_rowid + page->n_rows, row_indexes.size());
gidx_buffer.SetDevice(device);
page->gidx_buffer.SetDevice(device);
dh::LaunchN(device, page->n_rows, CompactPage(this, page, row_indexes));
dh::LaunchN(page->n_rows, CompactPage(this, page, row_indexes));
monitor_.Stop("Compact");
}

Expand Down
2 changes: 1 addition & 1 deletion src/data/simple_dmatrix.cu
Expand Up @@ -19,7 +19,7 @@ void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
int device_idx, float missing) {
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(device_idx, batch.Size(), [=] __device__(size_t idx) {
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
auto element = batch.GetElement(idx);
if (is_valid(element)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
Expand Down
7 changes: 3 additions & 4 deletions src/gbm/gbtree.cu
Expand Up @@ -20,14 +20,13 @@ void GPUCopyGradient(HostDeviceVector<GradientPair> const *in_gpair,
auto v_in = VectorView<GradientPair const>{in, group_id};
out_gpair->Resize(v_in.Size());
auto d_out = out_gpair->DeviceSpan();
dh::LaunchN(dh::CurrentDevice(), v_in.Size(),
[=] __device__(size_t i) { d_out[i] = v_in[i]; });
dh::LaunchN(v_in.Size(), [=] __device__(size_t i) { d_out[i] = v_in[i]; });
}

void GPUDartPredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w, size_t n_rows,
bst_group_t n_groups, bst_group_t group) {
dh::LaunchN(dh::CurrentDevice(), n_rows, [=]XGBOOST_DEVICE(size_t ridx) {
dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
const size_t offset = ridx * n_groups + group;
out_predts[offset] += (predts[offset] * tree_w);
});
Expand All @@ -37,7 +36,7 @@ void GPUDartInplacePredictInc(common::Span<float> out_predts,
common::Span<float> predts, float tree_w,
size_t n_rows, float base_score,
bst_group_t n_groups, bst_group_t group) {
dh::LaunchN(dh::CurrentDevice(), n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
const size_t offset = ridx * n_groups + group;
out_predts[offset] += (predts[offset] - base_score) * tree_w;
});
Expand Down
4 changes: 2 additions & 2 deletions src/linear/updater_gpu_coordinate.cu
Expand Up @@ -193,7 +193,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
void UpdateBiasResidual(float dbias, int group_idx, int num_groups) {
if (dbias == 0.0f) return;
auto d_gpair = dh::ToSpan(gpair_);
dh::LaunchN(learner_param_->gpu_id, num_row_, [=] __device__(size_t idx) {
dh::LaunchN(num_row_, [=] __device__(size_t idx) {
auto &g = d_gpair[idx * num_groups + group_idx];
g += GradientPair(g.GetHess() * dbias, 0);
});
Expand Down Expand Up @@ -222,7 +222,7 @@ class GPUCoordinateUpdater : public LinearUpdater { // NOLINT
common::Span<GradientPair> d_gpair = dh::ToSpan(gpair_);
common::Span<Entry> d_col = dh::ToSpan(data_).subspan(row_ptr_[fidx]);
size_t col_size = row_ptr_[fidx + 1] - row_ptr_[fidx];
dh::LaunchN(learner_param_->gpu_id, col_size, [=] __device__(size_t idx) {
dh::LaunchN(col_size, [=] __device__(size_t idx) {
auto entry = d_col[idx];
auto &g = d_gpair[entry.index * num_groups + group_idx];
g += GradientPair(g.GetHess() * dw * entry.fvalue, 0);
Expand Down
26 changes: 12 additions & 14 deletions src/metric/auc.cu
Expand Up @@ -118,12 +118,12 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
return thrust::make_pair(fp, tp);
}; // NOLINT
auto d_fptp = dh::ToSpan(cache->fptp);
dh::LaunchN(device, d_sorted_idx.size(),
dh::LaunchN(d_sorted_idx.size(),
[=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });

dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx, device);
dh::Iota(d_unique_idx);

auto uni_key = dh::MakeTransformIterator<float>(
thrust::make_counting_iterator(0),
Expand All @@ -144,7 +144,7 @@ GPUBinaryAUC(common::Span<float const> predts, MetaInfo const &info,
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// scatter unique negaive/positive values
// shift to right by 1 with initial value being 0
dh::LaunchN(device, d_unique_idx.size(), [=] __device__(size_t i) {
dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) {
if (d_unique_idx[i] == 0) { // first unique index is 0
assert(i == 0);
d_neg_pos[0] = {0, 0};
Expand Down Expand Up @@ -183,7 +183,7 @@ void Transpose(common::Span<float const> in, common::Span<float> out, size_t m,
size_t n, int32_t device) {
CHECK_EQ(in.size(), out.size());
CHECK_EQ(in.size(), m * n);
dh::LaunchN(device, in.size(), [=] __device__(size_t i) {
dh::LaunchN(in.size(), [=] __device__(size_t i) {
size_t col = i / m;
size_t row = i % m;
size_t idx = row * n + col;
Expand Down Expand Up @@ -255,9 +255,8 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
if (n_samples == 0) {
dh::TemporaryArray<float> resutls(n_classes * 4, 0.0f);
auto d_results = dh::ToSpan(resutls);
dh::LaunchN(device, n_classes * 4, [=]__device__(size_t i) {
d_results[i] = 0.0f;
});
dh::LaunchN(n_classes * 4,
[=] __device__(size_t i) { d_results[i] = 0.0f; });
auto local_area = d_results.subspan(0, n_classes);
auto fp = d_results.subspan(n_classes, n_classes);
auto tp = d_results.subspan(2 * n_classes, n_classes);
Expand All @@ -273,9 +272,8 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info

dh::TemporaryArray<uint32_t> class_ptr(n_classes + 1, 0);
auto d_class_ptr = dh::ToSpan(class_ptr);
dh::LaunchN(device, n_classes + 1, [=]__device__(size_t i) {
d_class_ptr[i] = i * n_samples;
});
dh::LaunchN(n_classes + 1,
[=] __device__(size_t i) { d_class_ptr[i] = i * n_samples; });
// no out-of-place sort for thrust, cub sort doesn't accept general iterator. So can't
// use transform iterator in sorting.
auto d_sorted_idx = dh::ToSpan(cache->sorted_idx);
Expand All @@ -301,15 +299,15 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
float tp = label * w;
return thrust::make_pair(fp, tp);
}; // NOLINT
dh::LaunchN(device, d_sorted_idx.size(),
dh::LaunchN(d_sorted_idx.size(),
[=] __device__(size_t i) { d_fptp[i] = get_fp_tp(i); });

/**
* Handle duplicated predictions
*/
dh::XGBDeviceAllocator<char> alloc;
auto d_unique_idx = dh::ToSpan(cache->unique_idx);
dh::Iota(d_unique_idx, device);
dh::Iota(d_unique_idx);
auto uni_key = dh::MakeTransformIterator<thrust::pair<uint32_t, float>>(
thrust::make_counting_iterator(0), [=] __device__(size_t i) {
uint32_t class_id = i / n_samples;
Expand Down Expand Up @@ -363,7 +361,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
auto d_neg_pos = dh::ToSpan(cache->neg_pos);
// When dataset is not empty, each class must have at least 1 (unique) sample
// prediction, so no need to handle special case.
dh::LaunchN(device, d_unique_idx.size(), [=]__device__(size_t i) {
dh::LaunchN(d_unique_idx.size(), [=] __device__(size_t i) {
if (d_unique_idx[i] % n_samples == 0) { // first unique index is 0
assert(d_unique_idx[i] % n_samples == 0);
d_neg_pos[d_unique_idx[i]] = {0, 0}; // class_id * n_samples = i
Expand Down Expand Up @@ -419,7 +417,7 @@ float GPUMultiClassAUCOVR(common::Span<float const> predts, MetaInfo const &info
auto tp = d_results.subspan(2 * n_classes, n_classes);
auto auc = d_results.subspan(3 * n_classes, n_classes);

dh::LaunchN(device, n_classes, [=] __device__(size_t c) {
dh::LaunchN(n_classes, [=] __device__(size_t c) {
auc[c] = s_d_auc[c];
auto last = d_fptp[n_samples * c + (n_samples - 1)];
fp[c] = last.first;
Expand Down

0 comments on commit 1c8fdf2

Please sign in to comment.