From 116d71181586f0d22e295e369e4730a611e502c0 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 6 Jul 2021 13:38:24 +0800 Subject: [PATCH] Make `SimpleDMatrix` ctor reusable. (#7075) --- src/data/simple_dmatrix.cu | 67 +++---------------------------- src/data/simple_dmatrix.cuh | 78 +++++++++++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 61 deletions(-) create mode 100644 src/data/simple_dmatrix.cuh diff --git a/src/data/simple_dmatrix.cu b/src/data/simple_dmatrix.cu index 9b1db6f44054..87f7fa2a031a 100644 --- a/src/data/simple_dmatrix.cu +++ b/src/data/simple_dmatrix.cu @@ -1,89 +1,34 @@ /*! - * Copyright 2019 by Contributors + * Copyright 2019-2021 by XGBoost Contributors * \file simple_dmatrix.cu */ #include -#include -#include #include -#include "../common/random.h" -#include "./simple_dmatrix.h" +#include "simple_dmatrix.cuh" +#include "simple_dmatrix.h" #include "device_adapter.cuh" namespace xgboost { namespace data { - -template -void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, - int device_idx, float missing) { - IsValidFunctor is_valid(missing); - // Count elements per row - dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { - auto element = batch.GetElement(idx); - if (is_valid(element)) { - atomicAdd(reinterpret_cast( // NOLINT - &offset[element.row_idx]), - static_cast(1)); // NOLINT - } - }); - - dh::XGBCachingDeviceAllocator alloc; - thrust::exclusive_scan(thrust::cuda::par(alloc), - thrust::device_pointer_cast(offset.data()), - thrust::device_pointer_cast(offset.data() + offset.size()), - thrust::device_pointer_cast(offset.data())); -} - -template -struct COOToEntryOp { - AdapterBatchT batch; - __device__ Entry operator()(size_t idx) { - const auto& e = batch.GetElement(idx); - return Entry(e.column_idx, e.value); - } -}; - -// Here the data is already correctly ordered and simply needs to be compacted -// to remove missing data -template -void CopyDataToDMatrix(AdapterT* adapter, common::Span data, - float missing) { - auto batch = adapter->Value(); - auto counting = thrust::make_counting_iterator(0llu); - dh::XGBCachingDeviceAllocator alloc; - COOToEntryOp transform_op{batch}; - thrust::transform_iterator - transform_iter(counting, transform_op); - auto begin_output = thrust::device_pointer_cast(data.data()); - dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output, - IsValidFunctor(missing)); -} - // Does not currently support metainfo as no on-device data source contains this // Current implementation assumes a single batch. More batches can // be supported in future. Does not currently support inferring row/column size template SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) { dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx())); + CHECK(adapter->NumRows() != kAdapterUnknownSize); CHECK(adapter->NumColumns() != kAdapterUnknownSize); adapter->BeforeFirst(); adapter->Next(); - auto& batch = adapter->Value(); - sparse_page_.offset.SetDevice(adapter->DeviceIdx()); - sparse_page_.data.SetDevice(adapter->DeviceIdx()); // Enforce single batch CHECK(!adapter->Next()); - sparse_page_.offset.Resize(adapter->NumRows() + 1); - auto s_offset = sparse_page_.offset.DeviceSpan(); - CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing); - info_.num_nonzero_ = sparse_page_.offset.HostVector().back(); - sparse_page_.data.Resize(info_.num_nonzero_); - CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(), missing); + info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), adapter->DeviceIdx(), + missing, &sparse_page_); info_.num_col_ = adapter->NumColumns(); info_.num_row_ = adapter->NumRows(); // Synchronise worker columns diff --git a/src/data/simple_dmatrix.cuh b/src/data/simple_dmatrix.cuh new file mode 100644 index 000000000000..c71a52b6746e --- /dev/null +++ b/src/data/simple_dmatrix.cuh @@ -0,0 +1,78 @@ +/*! + * Copyright 2019-2021 by XGBoost Contributors + * \file simple_dmatrix.cuh + */ +#ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_ +#define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_ + +#include +#include +#include +#include "device_adapter.cuh" +#include "../common/device_helpers.cuh" + +namespace xgboost { +namespace data { + +template +struct COOToEntryOp { + AdapterBatchT batch; + __device__ Entry operator()(size_t idx) { + const auto& e = batch.GetElement(idx); + return Entry(e.column_idx, e.value); + } +}; + +// Here the data is already correctly ordered and simply needs to be compacted +// to remove missing data +template +void CopyDataToDMatrix(AdapterBatchT batch, common::Span data, + float missing) { + auto counting = thrust::make_counting_iterator(0llu); + dh::XGBCachingDeviceAllocator alloc; + COOToEntryOp transform_op{batch}; + thrust::transform_iterator + transform_iter(counting, transform_op); + auto begin_output = thrust::device_pointer_cast(data.data()); + dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output, + IsValidFunctor(missing)); +} + +template +void CountRowOffsets(const AdapterBatchT& batch, common::Span offset, + int device_idx, float missing) { + dh::safe_cuda(cudaSetDevice(device_idx)); + IsValidFunctor is_valid(missing); + // Count elements per row + dh::LaunchN(batch.Size(), [=] __device__(size_t idx) { + auto element = batch.GetElement(idx); + if (is_valid(element)) { + atomicAdd(reinterpret_cast( // NOLINT + &offset[element.row_idx]), + static_cast(1)); // NOLINT + } + }); + + dh::XGBCachingDeviceAllocator alloc; + thrust::exclusive_scan(thrust::cuda::par(alloc), + thrust::device_pointer_cast(offset.data()), + thrust::device_pointer_cast(offset.data() + offset.size()), + thrust::device_pointer_cast(offset.data())); +} + +template +size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) { + page->offset.SetDevice(device); + page->data.SetDevice(device); + page->offset.Resize(batch.NumRows() + 1); + auto s_offset = page->offset.DeviceSpan(); + CountRowOffsets(batch, s_offset, device, missing); + auto num_nonzero_ = page->offset.HostVector().back(); + page->data.Resize(num_nonzero_); + CopyDataToDMatrix(batch, page->data.DeviceSpan(), missing); + + return num_nonzero_; +} +} // namespace data +} // namespace xgboost +#endif // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_