Implement sketching with adapter. (#8019)
trivialfis committed Jun 22, 2022
1 parent 142a208 commit f0c1b84
Showing 6 changed files with 196 additions and 164 deletions.
34 changes: 34 additions & 0 deletions src/common/hist_util.cc
@@ -33,6 +33,40 @@ HistogramCuts::HistogramCuts() {
  cut_ptrs_.HostVector().emplace_back(0);
}

HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins, int32_t n_threads, bool use_sorted,
                              Span<float> const hessian) {
  HistogramCuts out;
  auto const& info = m->Info();
  std::vector<bst_row_t> reduced(info.num_col_, 0);
  for (auto const &page : m->GetBatches<SparsePage>()) {
    auto const &entries_per_column =
        CalcColumnSize(data::SparsePageAdapterBatch{page.GetView()}, info.num_col_, n_threads,
                       [](auto) { return true; });
    CHECK_EQ(entries_per_column.size(), info.num_col_);
    for (size_t i = 0; i < entries_per_column.size(); ++i) {
      reduced[i] += entries_per_column[i];
    }
  }

  if (!use_sorted) {
    HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
                                  n_threads);
    for (auto const& page : m->GetBatches<SparsePage>()) {
      container.PushRowPage(page, info, hessian);
    }
    container.MakeCuts(&out);
  } else {
    SortedSketchContainer container{max_bins, m->Info(), reduced,
                                    HostSketchContainer::UseGroup(info), n_threads};
    for (auto const& page : m->GetBatches<SortedCSCPage>()) {
      container.PushColPage(page, info, hessian);
    }
    container.MakeCuts(&out);
  }

  return out;
}

/*!
 * \brief fill a histogram with zeros in range [begin, end)
 */
37 changes: 2 additions & 35 deletions src/common/hist_util.h
@@ -152,41 +152,8 @@ class HistogramCuts {
 * \param use_sorted Whether we should use SortedCSC for sketching; it's more efficient
 *                   but consumes more memory.
 */
inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
                                     bool use_sorted = false, Span<float> const hessian = {}) {
  HistogramCuts out;
  auto const& info = m->Info();
  std::vector<std::vector<bst_row_t>> column_sizes(n_threads);
  for (auto& column : column_sizes) {
    column.resize(info.num_col_, 0);
  }
  std::vector<bst_row_t> reduced(info.num_col_, 0);
  for (auto const& page : m->GetBatches<SparsePage>()) {
    auto const& entries_per_column =
        HostSketchContainer::CalcColumnSize(page, info.num_col_, n_threads);
    for (size_t i = 0; i < entries_per_column.size(); ++i) {
      reduced[i] += entries_per_column[i];
    }
  }

  if (!use_sorted) {
    HostSketchContainer container(max_bins, m->Info(), reduced, HostSketchContainer::UseGroup(info),
                                  n_threads);
    for (auto const& page : m->GetBatches<SparsePage>()) {
      container.PushRowPage(page, info, hessian);
    }
    container.MakeCuts(&out);
  } else {
    SortedSketchContainer container{max_bins, m->Info(), reduced,
                                    HostSketchContainer::UseGroup(info), n_threads};
    for (auto const& page : m->GetBatches<SortedCSCPage>()) {
      container.PushColPage(page, info, hessian);
    }
    container.MakeCuts(&out);
  }

  return out;
}
HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
                              bool use_sorted = false, Span<float> const hessian = {});
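
For orientation only (not part of this commit): a minimal sketch of calling the new free function. The header include, DMatrix pointer, bin count, and thread count below are illustrative assumptions.

// Hypothetical call site, assuming hist_util.h is included and p_fmat points to a valid DMatrix.
void BuildCutsExample(xgboost::DMatrix* p_fmat, int32_t n_threads) {
  // Row-wise (SparsePage) sketching with 256 bins and no hessian weighting.
  auto cuts = xgboost::common::SketchOnDMatrix(p_fmat, /*max_bins=*/256, n_threads);
  // Per-feature cut points are then available through the HistogramCuts accessors.
  (void)cuts;
}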

enum BinTypeSize : uint8_t {
  kUint8BinsTypeSize = 1,
140 changes: 29 additions & 111 deletions src/common/quantile.cc
@@ -6,6 +6,7 @@
#include <limits>
#include <utility>

#include "../data/adapter.h"
#include "categorical.h"
#include "hist_util.h"
#include "rabit/rabit.h"
@@ -31,72 +31,6 @@ SketchContainerImpl<WQSketch>::SketchContainerImpl(std::vector<bst_row_t> column
  has_categorical_ = std::any_of(feature_types_.cbegin(), feature_types_.cend(), IsCatOp{});
}

template <typename WQSketch>
std::vector<bst_row_t> SketchContainerImpl<WQSketch>::CalcColumnSize(SparsePage const &batch,
                                                                     bst_feature_t const n_columns,
                                                                     size_t const nthreads) {
  auto page = batch.GetView();
  std::vector<std::vector<bst_row_t>> column_sizes(nthreads);
  for (auto &column : column_sizes) {
    column.resize(n_columns, 0);
  }

  ParallelFor(page.Size(), nthreads, [&](omp_ulong i) {
    auto &local_column_sizes = column_sizes.at(omp_get_thread_num());
    auto row = page[i];
    auto const *p_row = row.data();
    for (size_t j = 0; j < row.size(); ++j) {
      local_column_sizes.at(p_row[j].index)++;
    }
  });
  std::vector<bst_row_t> entries_per_columns(n_columns, 0);
  ParallelFor(n_columns, nthreads, [&](bst_omp_uint i) {
    for (auto const &thread : column_sizes) {
      entries_per_columns[i] += thread[i];
    }
  });
  return entries_per_columns;
}

template <typename WQSketch>
std::vector<bst_feature_t> SketchContainerImpl<WQSketch>::LoadBalance(SparsePage const &batch,
                                                                      bst_feature_t n_columns,
                                                                      size_t const nthreads) {
  /* Some sparse datasets have their mass concentrated on a small number of features. To
   * avoid waiting on a few threads that run far longer than the rest, we distribute a
   * different number of columns to each thread according to the number of entries.
   */
  auto page = batch.GetView();
  size_t const total_entries = page.data.size();
  size_t const entries_per_thread = DivRoundUp(total_entries, nthreads);

  std::vector<std::vector<bst_row_t>> column_sizes(nthreads);
  for (auto& column : column_sizes) {
    column.resize(n_columns, 0);
  }
  std::vector<bst_row_t> entries_per_columns =
      CalcColumnSize(batch, n_columns, nthreads);
  std::vector<bst_feature_t> cols_ptr(nthreads + 1, 0);
  size_t count {0};
  size_t current_thread {1};

  for (auto col : entries_per_columns) {
    cols_ptr.at(current_thread)++;  // add one column to thread
    count += col;
    CHECK_LE(count, total_entries);
    if (count > entries_per_thread) {
      current_thread++;
      count = 0;
      cols_ptr.at(current_thread) = cols_ptr[current_thread - 1];
    }
  }
  // Idle threads.
  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
    cols_ptr[current_thread + 1] = cols_ptr[current_thread];
  }
  return cols_ptr;
}
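
For illustration only (not part of the diff): a standalone, simplified sketch of the partitioning idea the removed LoadBalance implemented, assuming the per-column entry counts have already been computed and n_threads is at least one.

#include <cstddef>
#include <vector>

// Split columns across threads so each thread receives roughly the same number of entries.
// cols_ptr[t] .. cols_ptr[t + 1] is the half-open column range owned by thread t.
std::vector<std::size_t> PartitionColumns(std::vector<std::size_t> const& entries_per_column,
                                          std::size_t n_threads) {
  std::size_t total = 0;
  for (auto e : entries_per_column) total += e;
  std::size_t const per_thread = (total + n_threads - 1) / n_threads;  // DivRoundUp

  std::vector<std::size_t> cols_ptr(n_threads + 1, 0);
  std::size_t count = 0, current = 1;
  for (auto e : entries_per_column) {
    ++cols_ptr[current];  // assign one more column to the current thread
    count += e;
    if (count > per_thread && current < n_threads) {
      ++current;
      count = 0;
      cols_ptr[current] = cols_ptr[current - 1];
    }
  }
  for (; current < n_threads; ++current) {  // idle threads get empty ranges
    cols_ptr[current + 1] = cols_ptr[current];
  }
  return cols_ptr;
}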

namespace {
// Function to merge hessian and sample weights
std::vector<float> MergeWeights(MetaInfo const &info, Span<float const> hessian, bool use_group,
@@ -143,54 +78,37 @@ void SketchContainerImpl<WQSketch>::PushRowPage(SparsePage const &page, MetaInfo
    CHECK_EQ(weights.size(), info.num_row_);
  }

  auto batch = page.GetView();
  // Parallel over columns. Each thread owns a set of consecutive columns.
  auto const ncol = static_cast<bst_feature_t>(info.num_col_);
  auto thread_columns_ptr = LoadBalance(page, info.num_col_, n_threads_);

  dmlc::OMPException exc;
#pragma omp parallel num_threads(n_threads_)
  {
    exc.Run([&]() {
      auto tid = static_cast<uint32_t>(omp_get_thread_num());
      auto const begin = thread_columns_ptr[tid];
      auto const end = thread_columns_ptr[tid + 1];

      // do not iterate if no columns are assigned to the thread
      if (begin < end && end <= ncol) {
        for (size_t i = 0; i < batch.Size(); ++i) {
          size_t const ridx = page.base_rowid + i;
          SparsePage::Inst const inst = batch[i];
          auto w = weights.empty() ? 1.0f : weights[ridx];
          auto p_inst = inst.data();
          if (is_dense) {
            for (size_t ii = begin; ii < end; ii++) {
              if (IsCat(feature_types_, ii)) {
                categories_[ii].emplace(p_inst[ii].fvalue);
              } else {
                sketches_[ii].Push(p_inst[ii].fvalue, w);
              }
            }
          } else {
            for (size_t i = 0; i < inst.size(); ++i) {
              auto const& entry = p_inst[i];
              if (entry.index >= begin && entry.index < end) {
                if (IsCat(feature_types_, entry.index)) {
                  categories_[entry.index].emplace(entry.fvalue);
                } else {
                  sketches_[entry.index].Push(entry.fvalue, w);
                }
              }
            }
          }
        }
      }
    });
  }
  exc.Rethrow();
  auto batch = data::SparsePageAdapterBatch{page.GetView()};
  this->PushRowPageImpl(batch, page.base_rowid, OptionalWeights{weights}, page.data.Size(),
                        info.num_col_, is_dense, [](auto) { return true; });
  monitor_.Stop(__func__);
}
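
As a reading aid (not part of the commit): PushRowPage now wraps the SparsePage view in a SparsePageAdapterBatch and forwards it to the shared PushRowPageImpl, so the SparsePage path and the in-memory adapter paths go through one implementation. Below is a hedged sketch of the COO-style batch shape such a shared consumer presumably relies on; it is modeled on the existing adapter batches, and the GetLine/GetElement names are assumptions, not code from this commit.

#include <cstddef>

// Illustrative only: the generic shape of a consumer of an adapter batch, assuming the
// COO-style interface (batch.GetLine(i).GetElement(j) -> {row_idx, column_idx, value})
// exposed by the existing adapter batches.
template <typename Batch, typename IsValid>
void ForEachValidElement(Batch const& batch, IsValid&& is_valid) {
  for (std::size_t i = 0; i < batch.Size(); ++i) {
    auto line = batch.GetLine(i);
    for (std::size_t j = 0; j < line.Size(); ++j) {
      auto elem = line.GetElement(j);
      if (is_valid(elem)) {
        // e.g. push elem.value into the sketch owned by feature elem.column_idx.
      }
    }
  }
}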

template <typename Batch>
void HostSketchContainer::PushAdapterBatch(Batch const &batch, size_t base_rowid,
                                           MetaInfo const &info, size_t nnz, float missing) {
  auto const &h_weights =
      (use_group_ind_ ? detail::UnrollGroupWeights(info) : info.weights_.HostVector());

  auto is_valid = data::IsValidFunctor{missing};
  auto weights = OptionalWeights{Span<float const>{h_weights}};
  // The nnz from info is not reliable, as sketching might be the first pass through the data.
  auto is_dense = nnz == info.num_col_ * info.num_row_;
  this->PushRowPageImpl(batch, base_rowid, weights, nnz, info.num_col_, is_dense, is_valid);
}

#define INSTANTIATE(_type)                                                            \
  template void HostSketchContainer::PushAdapterBatch<data::_type>(                   \
      data::_type const &batch, size_t base_rowid, MetaInfo const &info, size_t nnz,  \
      float missing);

INSTANTIATE(ArrayAdapterBatch)
INSTANTIATE(CSRArrayAdapterBatch)
INSTANTIATE(CSCAdapterBatch)
INSTANTIATE(DataTableAdapterBatch)
INSTANTIATE(SparsePageAdapterBatch)
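
A minimal sketch of driving the new entry point (illustrative, not part of the commit): the container is assumed to be constructed as in SketchOnDMatrix above, and `missing` is whatever value should be treated as absent (for example NaN); the helper name is hypothetical.

// Hypothetical helper: push every SparsePage of a DMatrix through the adapter path,
// then finalize the quantile cuts. Assumes `container` was built with the same
// max_bins, info, column sizes, group flag, and thread count used elsewhere.
xgboost::common::HistogramCuts PushAllPages(xgboost::DMatrix* p_fmat,
                                            xgboost::common::HostSketchContainer* container,
                                            float missing) {
  auto const& info = p_fmat->Info();
  for (auto const& page : p_fmat->GetBatches<xgboost::SparsePage>()) {
    container->PushAdapterBatch(xgboost::data::SparsePageAdapterBatch{page.GetView()},
                                page.base_rowid, info, page.data.Size(), missing);
  }
  xgboost::common::HistogramCuts cuts;
  container->MakeCuts(&cuts);
  return cuts;
}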

namespace {
/**
 * \brief A view over gathered sketch values.