Skip to content

Commit

Permalink
Make SimpleDMatrix ctor reusable. (#7075)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 6, 2021
1 parent d7e1fa7 commit 116d711
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 61 deletions.
67 changes: 6 additions & 61 deletions src/data/simple_dmatrix.cu
@@ -1,89 +1,34 @@
/*!
* Copyright 2019 by Contributors
* Copyright 2019-2021 by XGBoost Contributors
* \file simple_dmatrix.cu
*/
#include <thrust/copy.h>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <xgboost/data.h>
#include "../common/random.h"
#include "./simple_dmatrix.h"
#include "simple_dmatrix.cuh"
#include "simple_dmatrix.h"
#include "device_adapter.cuh"

namespace xgboost {
namespace data {


template <typename AdapterBatchT>
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
int device_idx, float missing) {
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
auto element = batch.GetElement(idx);
if (is_valid(element)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&offset[element.row_idx]),
static_cast<unsigned long long>(1)); // NOLINT
}
});

dh::XGBCachingDeviceAllocator<char> alloc;
thrust::exclusive_scan(thrust::cuda::par(alloc),
thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
}

template <typename AdapterBatchT>
struct COOToEntryOp {
AdapterBatchT batch;
__device__ Entry operator()(size_t idx) {
const auto& e = batch.GetElement(idx);
return Entry(e.column_idx, e.value);
}
};

// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterT>
void CopyDataToDMatrix(AdapterT* adapter, common::Span<Entry> data,
float missing) {
auto batch = adapter->Value();
auto counting = thrust::make_counting_iterator(0llu);
dh::XGBCachingDeviceAllocator<char> alloc;
COOToEntryOp<decltype(batch)> transform_op{batch};
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
transform_iter(counting, transform_op);
auto begin_output = thrust::device_pointer_cast(data.data());
dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
IsValidFunctor(missing));
}

// Does not currently support metainfo as no on-device data source contains this
// Current implementation assumes a single batch. More batches can
// be supported in future. Does not currently support inferring row/column size
template <typename AdapterT>
SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
dh::safe_cuda(cudaSetDevice(adapter->DeviceIdx()));

CHECK(adapter->NumRows() != kAdapterUnknownSize);
CHECK(adapter->NumColumns() != kAdapterUnknownSize);

adapter->BeforeFirst();
adapter->Next();
auto& batch = adapter->Value();
sparse_page_.offset.SetDevice(adapter->DeviceIdx());
sparse_page_.data.SetDevice(adapter->DeviceIdx());

// Enforce single batch
CHECK(!adapter->Next());
sparse_page_.offset.Resize(adapter->NumRows() + 1);
auto s_offset = sparse_page_.offset.DeviceSpan();
CountRowOffsets(batch, s_offset, adapter->DeviceIdx(), missing);
info_.num_nonzero_ = sparse_page_.offset.HostVector().back();
sparse_page_.data.Resize(info_.num_nonzero_);
CopyDataToDMatrix(adapter, sparse_page_.data.DeviceSpan(), missing);

info_.num_nonzero_ = CopyToSparsePage(adapter->Value(), adapter->DeviceIdx(),
missing, &sparse_page_);
info_.num_col_ = adapter->NumColumns();
info_.num_row_ = adapter->NumRows();
// Synchronise worker columns
Expand Down
78 changes: 78 additions & 0 deletions src/data/simple_dmatrix.cuh
@@ -0,0 +1,78 @@
/*!
* Copyright 2019-2021 by XGBoost Contributors
* \file simple_dmatrix.cuh
*/
#ifndef XGBOOST_DATA_SIMPLE_DMATRIX_CUH_
#define XGBOOST_DATA_SIMPLE_DMATRIX_CUH_

#include <thrust/copy.h>
#include <thrust/scan.h>
#include <thrust/execution_policy.h>
#include "device_adapter.cuh"
#include "../common/device_helpers.cuh"

namespace xgboost {
namespace data {

template <typename AdapterBatchT>
struct COOToEntryOp {
AdapterBatchT batch;
__device__ Entry operator()(size_t idx) {
const auto& e = batch.GetElement(idx);
return Entry(e.column_idx, e.value);
}
};

// Here the data is already correctly ordered and simply needs to be compacted
// to remove missing data
template <typename AdapterBatchT>
void CopyDataToDMatrix(AdapterBatchT batch, common::Span<Entry> data,
float missing) {
auto counting = thrust::make_counting_iterator(0llu);
dh::XGBCachingDeviceAllocator<char> alloc;
COOToEntryOp<decltype(batch)> transform_op{batch};
thrust::transform_iterator<decltype(transform_op), decltype(counting)>
transform_iter(counting, transform_op);
auto begin_output = thrust::device_pointer_cast(data.data());
dh::CopyIf(transform_iter, transform_iter + batch.Size(), begin_output,
IsValidFunctor(missing));
}

template <typename AdapterBatchT>
void CountRowOffsets(const AdapterBatchT& batch, common::Span<bst_row_t> offset,
int device_idx, float missing) {
dh::safe_cuda(cudaSetDevice(device_idx));
IsValidFunctor is_valid(missing);
// Count elements per row
dh::LaunchN(batch.Size(), [=] __device__(size_t idx) {
auto element = batch.GetElement(idx);
if (is_valid(element)) {
atomicAdd(reinterpret_cast<unsigned long long*>( // NOLINT
&offset[element.row_idx]),
static_cast<unsigned long long>(1)); // NOLINT
}
});

dh::XGBCachingDeviceAllocator<char> alloc;
thrust::exclusive_scan(thrust::cuda::par(alloc),
thrust::device_pointer_cast(offset.data()),
thrust::device_pointer_cast(offset.data() + offset.size()),
thrust::device_pointer_cast(offset.data()));
}

template <typename AdapterBatchT>
size_t CopyToSparsePage(AdapterBatchT const& batch, int32_t device, float missing, SparsePage* page) {
page->offset.SetDevice(device);
page->data.SetDevice(device);
page->offset.Resize(batch.NumRows() + 1);
auto s_offset = page->offset.DeviceSpan();
CountRowOffsets(batch, s_offset, device, missing);
auto num_nonzero_ = page->offset.HostVector().back();
page->data.Resize(num_nonzero_);
CopyDataToDMatrix(batch, page->data.DeviceSpan(), missing);

return num_nonzero_;
}
} // namespace data
} // namespace xgboost
#endif // XGBOOST_DATA_SIMPLE_DMATRIX_CUH_

0 comments on commit 116d711

Please sign in to comment.