Skip to content

Commit

Permalink
Copy data from Ellpack to GHist. (#8215)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Sep 6, 2022
1 parent 7ee10e3 commit 441ffc0
Show file tree
Hide file tree
Showing 16 changed files with 459 additions and 105 deletions.
27 changes: 27 additions & 0 deletions src/common/algorithm.cuh
@@ -0,0 +1,27 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once

#include <thrust/binary_search.h> // thrust::upper_bound
#include <thrust/execution_policy.h> // thrust::seq

#include "xgboost/base.h"
#include "xgboost/span.h"

namespace xgboost {
namespace common {
namespace cuda {
template <typename It>
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) - 1 - first;
return segment_id;
}

template <typename T>
size_t XGBOOST_DEVICE SegmentId(Span<T> segments_ptr, size_t idx) {
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
}
} // namespace cuda
} // namespace common
} // namespace xgboost
16 changes: 16 additions & 0 deletions src/common/algorithm.h
@@ -0,0 +1,16 @@
/*!
* Copyright 2022 by XGBoost Contributors
*/
#pragma once
#include <algorithm> // std::upper_bound
#include <cinttypes> // std::size_t

namespace xgboost {
namespace common {
template <typename It, typename Idx>
auto SegmentId(It first, It last, Idx idx) {
std::size_t segment_id = std::upper_bound(first, last, idx) - 1 - first;
return segment_id;
}
} // namespace common
} // namespace xgboost
139 changes: 105 additions & 34 deletions src/common/column_matrix.h
Expand Up @@ -18,6 +18,7 @@

#include "../data/adapter.h"
#include "../data/gradient_index.h"
#include "algorithm.h"
#include "hist_util.h"

namespace xgboost {
Expand Down Expand Up @@ -135,6 +136,22 @@ class DenseColumnIter : public Column<BinIdxT> {
class ColumnMatrix {
void InitStorage(GHistIndexMatrix const& gmat, double sparse_threshold);

template <typename ColumnBinT, typename BinT, typename RIdx>
void SetBinSparse(BinT bin_id, RIdx rid, bst_feature_t fid, ColumnBinT* local_index) {
if (type_[fid] == kDenseColumn) {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[rid] = bin_id - index_base_[fid];
// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
// kMissingId to the index to avoid missing flags.
missing_flags_[feature_offsets_[fid] + rid] = false;
} else {
ColumnBinT* begin = &local_index[feature_offsets_[fid]];
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid;
++num_nonzeros_[fid];
}
}

public:
// get number of features
bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
Expand All @@ -144,34 +161,66 @@ class ColumnMatrix {
this->InitStorage(gmat, sparse_threshold);
}

/**
* \brief Initialize ColumnMatrix from GHistIndexMatrix with reference to the original
* SparsePage.
*/
void InitFromSparse(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
int32_t n_threads) {
auto batch = data::SparsePageAdapterBatch{page.GetView()};
this->InitStorage(gmat, sparse_threshold);
// ignore base row id here as we always has one column matrix for each sparse page.
this->PushBatch(n_threads, batch, std::numeric_limits<float>::quiet_NaN(), gmat, 0);
}

/**
* \brief Initialize ColumnMatrix from GHistIndexMatrix without reference to actual
* data.
*
* This function requires a binary search for each bin to get back the feature index
* for those bins.
*/
void InitFromGHist(Context const* ctx, GHistIndexMatrix const& gmat) {
auto n_threads = ctx->Threads();
if (!any_missing_) {
// row index is compressed, we need to dispatch it.
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = gmat.Size(), n_threads = n_threads,
n_features = gmat.Features()](auto t) {
using RowBinIdxT = decltype(t);
SetIndexNoMissing(gmat.base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features,
n_threads);
});
} else {
SetIndexMixedColumns(gmat);
}
}

/**
* \brief Push batch of data for Quantile DMatrix support.
*
* \param batch Input data wrapped inside a adapter batch.
* \param gmat The row-major histogram index that contains index for ALL data.
* \param base_rowid The beginning row index for current batch.
*/
template <typename Batch>
void PushBatch(int32_t n_threads, Batch const& batch, float missing, GHistIndexMatrix const& gmat,
size_t base_rowid) {
// pre-fill index_ for dense columns
auto n_features = gmat.Features();
if (!any_missing_) {
missing_flags_.resize(feature_offsets_[n_features], false);
// row index is compressed, we need to dispatch it.
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_features = n_features,
n_threads = n_threads](auto t) {

// use base_rowid from input parameter as gmat is a single matrix that contains all
// the histogram index instead of being only a batch.
DispatchBinType(gmat.index.GetBinTypeSize(), [&, size = batch.Size(), n_threads = n_threads,
n_features = gmat.Features()](auto t) {
using RowBinIdxT = decltype(t);
SetIndexNoMissing(base_rowid, gmat.index.data<RowBinIdxT>(), size, n_features, n_threads);
});
} else {
missing_flags_.resize(feature_offsets_[n_features], true);
SetIndexMixedColumns(base_rowid, batch, gmat, n_features, missing);
SetIndexMixedColumns(base_rowid, batch, gmat, missing);
}
}

// construct column matrix from GHistIndexMatrix
void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
int32_t n_threads) {
auto batch = data::SparsePageAdapterBatch{page.GetView()};
this->InitStorage(gmat, sparse_threshold);
// ignore base row id here as we always has one column matrix for each sparse page.
this->PushBatch(n_threads, batch, std::numeric_limits<float>::quiet_NaN(), gmat, 0);
}

/* Set the number of bytes based on numeric limit of maximum number of bins provided by user */
void SetTypeSize(size_t max_bin_per_feat) {
if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
Expand Down Expand Up @@ -210,6 +259,7 @@ class ColumnMatrix {
template <typename RowBinIdxT>
void SetIndexNoMissing(bst_row_t base_rowid, RowBinIdxT const* row_index, const size_t n_samples,
const size_t n_features, int32_t n_threads) {
missing_flags_.resize(feature_offsets_[n_features], false);
DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
Expand All @@ -232,29 +282,16 @@ class ColumnMatrix {
*/
template <typename Batch>
void SetIndexMixedColumns(size_t base_rowid, Batch const& batch, const GHistIndexMatrix& gmat,
size_t n_features, float missing) {
float missing) {
auto n_features = gmat.Features();
missing_flags_.resize(feature_offsets_[n_features], true);
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[base_rowid];
auto is_valid = data::IsValidFunctor {missing};
num_nonzeros_.resize(n_features, 0);
auto is_valid = data::IsValidFunctor{missing};

DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
num_nonzeros_.resize(n_features, 0);
auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
if (type_[fid] == kDenseColumn) {
ColumnBinT* begin = reinterpret_cast<ColumnBinT*>(&local_index[feature_offsets_[fid]]);
begin[rid] = bin_id - index_base_[fid];
// not thread-safe with bool vector. FIXME(jiamingy): We can directly assign
// kMissingId to the index to avoid missing flags.
missing_flags_[feature_offsets_[fid] + rid] = false;
} else {
ColumnBinT* begin = reinterpret_cast<ColumnBinT*>(&local_index[feature_offsets_[fid]]);
begin[num_nonzeros_[fid]] = bin_id - index_base_[fid];
row_ind_[feature_offsets_[fid] + num_nonzeros_[fid]] = rid;
++num_nonzeros_[fid];
}
};

size_t const batch_size = batch.Size();
size_t k{0};
for (size_t rid = 0; rid < batch_size; ++rid) {
Expand All @@ -264,14 +301,48 @@ class ColumnMatrix {
if (is_valid(coo)) {
auto fid = coo.column_idx;
const uint32_t bin_id = row_index[k];
get_bin_idx(bin_id, rid + base_rowid, fid);
SetBinSparse(bin_id, rid + base_rowid, fid, local_index);
++k;
}
}
}
});
}

/**
* \brief Set column index for both dense and sparse columns, but with only GHistMatrix
* available and requires a search for each bin.
*/
void SetIndexMixedColumns(const GHistIndexMatrix& gmat) {
auto n_features = gmat.Features();
missing_flags_.resize(feature_offsets_[n_features], true);
auto const* row_index = gmat.index.data<uint32_t>() + gmat.row_ptr[gmat.base_rowid];
num_nonzeros_.resize(n_features, 0);
auto const& ptrs = gmat.cut.Ptrs();

DispatchBinType(bins_type_size_, [&](auto t) {
using ColumnBinT = decltype(t);
ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
auto const batch_size = gmat.Size();
size_t k{0};

for (size_t ridx = 0; ridx < batch_size; ++ridx) {
auto r_beg = gmat.row_ptr[ridx];
auto r_end = gmat.row_ptr[ridx + 1];
bst_feature_t fidx{0};
for (size_t j = r_beg; j < r_end; ++j) {
const uint32_t bin_idx = row_index[k];
// find the feature index for current bin.
while (bin_idx >= ptrs[fidx + 1]) {
fidx++;
}
SetBinSparse(bin_idx, ridx, fidx, local_index);
++k;
}
}
});
}

BinTypeSize GetTypeSize() const { return bins_type_size_; }
auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }

Expand Down
13 changes: 2 additions & 11 deletions src/common/device_helpers.cuh
Expand Up @@ -35,6 +35,7 @@
#include "xgboost/global_config.h"

#include "common.h"
#include "algorithm.cuh"

#ifdef XGBOOST_USE_NCCL
#include "nccl.h"
Expand Down Expand Up @@ -1556,17 +1557,7 @@ XGBOOST_DEVICE thrust::transform_iterator<FuncT, IterT, ReturnT> MakeTransformIt
return thrust::transform_iterator<FuncT, IterT, ReturnT>(iter, func);
}

template <typename It>
size_t XGBOOST_DEVICE SegmentId(It first, It last, size_t idx) {
size_t segment_id = thrust::upper_bound(thrust::seq, first, last, idx) -
1 - first;
return segment_id;
}

template <typename T>
size_t XGBOOST_DEVICE SegmentId(xgboost::common::Span<T> segments_ptr, size_t idx) {
return SegmentId(segments_ptr.cbegin(), segments_ptr.cend(), idx);
}
using xgboost::common::cuda::SegmentId; // import it for compatibility

namespace detail {
template <typename Key, typename KeyOutIt>
Expand Down
35 changes: 31 additions & 4 deletions src/common/hist_util.h
Expand Up @@ -22,6 +22,7 @@
#include "row_set.h"
#include "threading_utils.h"
#include "timer.h"
#include "algorithm.h" // SegmentId

namespace xgboost {
class GHistIndexMatrix;
Expand Down Expand Up @@ -130,19 +131,23 @@ class HistogramCuts {
/**
* \brief Search the bin index for categorical feature.
*/
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
auto const &ptrs = this->Ptrs();
auto const &vals = this->Values();
bst_bin_t SearchCatBin(float value, bst_feature_t fidx, std::vector<uint32_t> const& ptrs,
std::vector<float> const& vals) const {
auto end = ptrs.at(fidx + 1) + vals.cbegin();
auto beg = ptrs[fidx] + vals.cbegin();
// Truncates the value in case it's not perfectly rounded.
auto v = static_cast<float>(common::AsCat(value));
auto v = static_cast<float>(common::AsCat(value));
auto bin_idx = std::lower_bound(beg, end, v) - vals.cbegin();
if (bin_idx == ptrs.at(fidx + 1)) {
bin_idx -= 1;
}
return bin_idx;
}
bst_bin_t SearchCatBin(float value, bst_feature_t fidx) const {
auto const& ptrs = this->Ptrs();
auto const& vals = this->Values();
return this->SearchCatBin(value, fidx, ptrs, vals);
}
bst_bin_t SearchCatBin(Entry const& e) const { return SearchCatBin(e.fvalue, e.index); }
};

Expand Down Expand Up @@ -189,6 +194,28 @@ auto DispatchBinType(BinTypeSize type, Fn&& fn) {
* storage class.
*/
struct Index {
// Inside the compressor, bin_idx is the index for cut value across all features. By
// subtracting it with starting pointer of each feature, we can reduce it to smaller
// value and store it with smaller types. Usable only with dense data.
//
// For sparse input we have to store an addition feature index (similar to sparse matrix
// formats like CSR) for each bin in index field to choose the right offset.
template <typename T>
struct CompressBin {
uint32_t const* offsets;

template <typename Bin, typename Feat>
auto operator()(Bin bin_idx, Feat fidx) const {
return static_cast<T>(bin_idx - offsets[fidx]);
}
};

template <typename T>
CompressBin<T> MakeCompressor() const {
uint32_t const* offsets = this->Offset();
return CompressBin<T>{offsets};
}

Index() { SetBinTypeSize(binTypeSize_); }
Index(const Index& i) = delete;
Index& operator=(Index i) = delete;
Expand Down
11 changes: 11 additions & 0 deletions src/data/ellpack_page.cu
Expand Up @@ -547,4 +547,15 @@ EllpackDeviceAccessor EllpackPageImpl::GetDeviceAccessor(
NumSymbols()),
feature_types};
}
EllpackDeviceAccessor EllpackPageImpl::GetHostAccessor(
common::Span<FeatureType const> feature_types) const {
return {Context::kCpuId,
cuts_,
is_dense,
row_stride,
base_rowid,
n_rows,
common::CompressedIterator<uint32_t>(gidx_buffer.ConstHostPointer(), NumSymbols()),
feature_types};
}
} // namespace xgboost
19 changes: 13 additions & 6 deletions src/data/ellpack_page.cuh
Expand Up @@ -43,12 +43,18 @@ struct EllpackDeviceAccessor {
base_rowid(base_rowid),
n_rows(n_rows) ,gidx_iter(gidx_iter),
feature_types{feature_types} {
cuts.cut_values_.SetDevice(device);
cuts.cut_ptrs_.SetDevice(device);
cuts.min_vals_.SetDevice(device);
gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan();
feature_segments = cuts.cut_ptrs_.ConstDeviceSpan();
min_fvalue = cuts.min_vals_.ConstDeviceSpan();
if (device == Context::kCpuId) {
gidx_fvalue_map = cuts.cut_values_.ConstHostSpan();
feature_segments = cuts.cut_ptrs_.ConstHostSpan();
min_fvalue = cuts.min_vals_.ConstHostSpan();
} else {
cuts.cut_values_.SetDevice(device);
cuts.cut_ptrs_.SetDevice(device);
cuts.min_vals_.SetDevice(device);
gidx_fvalue_map = cuts.cut_values_.ConstDeviceSpan();
feature_segments = cuts.cut_ptrs_.ConstDeviceSpan();
min_fvalue = cuts.min_vals_.ConstDeviceSpan();
}
}
// Get a matrix element, uses binary search for look up Return NaN if missing
// Given a row index and a feature index, returns the corresponding cut value
Expand Down Expand Up @@ -202,6 +208,7 @@ class EllpackPageImpl {
EllpackDeviceAccessor
GetDeviceAccessor(int device,
common::Span<FeatureType const> feature_types = {}) const;
EllpackDeviceAccessor GetHostAccessor(common::Span<FeatureType const> feature_types = {}) const;

private:
/*!
Expand Down

0 comments on commit 441ffc0

Please sign in to comment.