Skip to content

Commit

Permalink
Merge kernels.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Aug 11, 2022
1 parent 555a912 commit 43ae4ad
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 22 deletions.
25 changes: 9 additions & 16 deletions src/data/ellpack_page.cu
Expand Up @@ -288,7 +288,8 @@ ELLPACK_BATCH_SPECIALIZE(data::CupyAdapterBatch)

namespace {
void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const> d_row_ptr,
size_t row_stride, common::CompressedByteT* d_compressed_buffer) {
size_t row_stride, common::CompressedByteT* d_compressed_buffer,
size_t null) {
dh::device_vector<uint8_t> data(page.index.begin(), page.index.end());
auto d_data = dh::ToSpan(data);

Expand All @@ -305,9 +306,10 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const>

auto r_begin = d_row_ptr[ridx];
auto r_end = d_row_ptr[ridx + 1];
size_t rsize = r_end - r_begin;
size_t r_size = r_end - r_begin;

if (ifeature >= rsize) {
if (ifeature >= r_size) {
writer.AtomicWriteSymbol(d_compressed_buffer, null, idx);
return;
}

Expand All @@ -320,15 +322,10 @@ void CopyGHistToEllpack(GHistIndexMatrix const& page, common::Span<size_t const>
using T = decltype(t);
auto ptr = reinterpret_cast<T const*>(d_data.data());
auto bin_idx = ptr[r_begin + ifeature] + offset;
writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, ridx * row_stride + ifeature);
writer.AtomicWriteSymbol(d_compressed_buffer, bin_idx, idx);
});
});
}

// Fills row_counts with the difference between consecutive entries of d_row_ptr:
//   row_counts[k] = d_row_ptr[k + 1] - d_row_ptr[k]
// One index per element of row_counts is dispatched through dh::LaunchN.
// NOTE(review): presumably d_row_ptr holds row_counts.size() + 1 offsets — confirm at call sites.
void RowCountsFromIndptr(common::Span<size_t const> d_row_ptr, common::Span<size_t> row_counts) {
  auto const n_rows = row_counts.size();
  dh::LaunchN(n_rows, [=] XGBOOST_DEVICE(size_t ridx) {
    auto const row_begin = d_row_ptr[ridx];
    auto const row_end = d_row_ptr[ridx + 1];
    row_counts[ridx] = row_end - row_begin;
  });
}
} // anonymous namespace

EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& page,
Expand All @@ -344,17 +341,13 @@ EllpackPageImpl::EllpackPageImpl(Context const* ctx, GHistIndexMatrix const& pag
monitor_.Stop("InitCompressedData");

// copy gidx
auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
common::CompressedByteT* d_compressed_buffer = gidx_buffer.DevicePointer();
dh::device_vector<size_t> row_ptr(page.row_ptr);
auto d_row_ptr = dh::ToSpan(row_ptr);
CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer);

// write null value
dh::device_vector<size_t> row_counts(page.Size());
auto row_counts_span = dh::ToSpan(row_counts);
RowCountsFromIndptr(d_row_ptr, row_counts_span);
WriteNullValues(this, ctx->gpu_id, row_counts_span);
auto accessor = this->GetDeviceAccessor(ctx->gpu_id, ft);
auto null = accessor.NullValue();
CopyGHistToEllpack(page, d_row_ptr, row_stride, d_compressed_buffer, null);
}

// A functor that copies the data from one EllpackPage to another.
Expand Down
9 changes: 3 additions & 6 deletions src/data/iterative_dmatrix.h
Expand Up @@ -40,12 +40,9 @@ namespace data {
*
* - The CPU format and the GPU format are different, the former uses a CSR + CSC for
* histogram index while the latter uses only Ellpack. This results in a design where
* we can obtain the GPU format from CPU but not the other way around since we can't
* recover the CSC from Ellpack. More concretely, if users want to construct a CPU
* version of `QuantileDMatrix`, input data must be on CPU. However, if users want to
* have a GPU version of `QuantileDMatrix`, data can be on either place. We can fix this
* by retaining the feature index information in ellpack if there are feature
* requests.
* we can obtain the GPU format from CPU but the other way around is not yet
* supported. We can search the bin value from ellpack to recover the feature index when
* we support copying data from GPU to CPU.
*/
class IterativeDMatrix : public DMatrix {
MetaInfo info_;
Expand Down

0 comments on commit 43ae4ad

Please sign in to comment.