Support hessian in host sketch container. #7081

Merged
merged 6 commits on Jul 8, 2021
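This change threads an optional per-row hessian vector through the CPU (host) quantile sketching path, so cut points can be weighted by second-order gradients, and it passes the thread count into the container explicitly instead of querying OpenMP inside it. Below is a minimal sketch of how a caller might use the extended entry point; the SketchOnDMatrix signature comes from this diff, while the helper name, the include paths, and the idea of deriving the hessian from gradient pairs are assumptions for illustration.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "xgboost/base.h"                // GradientPair
#include "xgboost/data.h"                // DMatrix
#include "../../src/common/hist_util.h"  // SketchOnDMatrix, HistogramCuts (path assumed)

// Hypothetical helper: extract the hessian from per-row gradient pairs and use
// it to weight the quantile sketch. Passing an empty vector keeps the previous
// behaviour of weighting by sample (or group) weights only.
xgboost::common::HistogramCuts HessianWeightedCuts(
    xgboost::DMatrix* dmat, int32_t max_bins,
    std::vector<xgboost::GradientPair> const& gpair) {
  std::vector<float> hessian(gpair.size());
  for (std::size_t i = 0; i < gpair.size(); ++i) {
    hessian[i] = gpair[i].GetHess();
  }
  return xgboost::common::SketchOnDMatrix(dmat, max_bins, hessian);
}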
4 changes: 4 additions & 0 deletions include/xgboost/generic_parameters.h
@@ -39,6 +39,10 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
* \param require_gpu Whether GPU is explicitly required from user.
*/
void ConfigureGpuId(bool require_gpu);
/*!
* Return the automatically chosen number of threads.
*/
int32_t Threads() const;

// declare parameters
DMLC_DECLARE_PARAMETER(GenericParameter) {
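The new GenericParameter::Threads() accessor centralises thread-count resolution so call sites can stop calling omp_get_max_threads() directly. A hedged sketch of the intended pattern follows; only Threads() itself is added by this diff, while UpdateAllowUnknown and the "nthread" parameter name come from the existing parameter machinery and are assumptions here.

#include <cstdint>
#include <string>

#include "xgboost/generic_parameters.h"  // GenericParameter

// Resolve the number of threads once from the runtime context and reuse it,
// instead of asking OpenMP at every call site.
int32_t ResolveThreads(int requested) {
  xgboost::GenericParameter ctx;
  // Values <= 0 are treated as "choose automatically" (an assumption here).
  ctx.UpdateAllowUnknown(
      xgboost::Args{{"nthread", std::to_string(requested)}});
  return ctx.Threads();
}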
7 changes: 4 additions & 3 deletions src/common/hist_util.h
@@ -110,7 +110,8 @@ class HistogramCuts {
}
};

inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) {
inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins,
std::vector<float> const &hessian = {}) {
HistogramCuts out;
auto const& info = m->Info();
const auto threads = omp_get_max_threads();
@@ -127,9 +128,9 @@ inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) {
}
}
HostSketchContainer container(reduced, max_bins,
HostSketchContainer::UseGroup(info));
HostSketchContainer::UseGroup(info), threads);
for (auto const &page : m->GetBatches<SparsePage>()) {
container.PushRowPage(page, info);
container.PushRowPage(page, info, hessian);
}
container.MakeCuts(&out);
return out;
141 changes: 101 additions & 40 deletions src/common/quantile.cc
@@ -10,19 +10,21 @@ namespace xgboost {
namespace common {

HostSketchContainer::HostSketchContainer(std::vector<bst_row_t> columns_size,
int32_t max_bins, bool use_group)
int32_t max_bins, bool use_group,
int32_t n_threads)
: columns_size_{std::move(columns_size)}, max_bins_{max_bins},
use_group_ind_{use_group} {
use_group_ind_{use_group}, n_threads_{n_threads} {
monitor_.Init(__func__);
CHECK_NE(columns_size_.size(), 0);
sketches_.resize(columns_size_.size());
for (size_t i = 0; i < sketches_.size(); ++i) {
CHECK_GE(n_threads_, 1);
ParallelFor(sketches_.size(), n_threads_, Sched::Auto(), [&](auto i) {
auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
n_bins = std::max(n_bins, static_cast<decltype(n_bins)>(1));
auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
sketches_[i].Init(columns_size_[i], eps);
sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
}
});
}

std::vector<bst_row_t>
@@ -89,40 +91,94 @@ std::vector<bst_feature_t> HostSketchContainer::LoadBalance(
return cols_ptr;
}

void HostSketchContainer::PushRowPage(SparsePage const &page,
MetaInfo const &info) {
namespace {
// Merge the per-row hessian with sample or group weights.
std::vector<float> MergeWeights(MetaInfo const &info,
std::vector<float> const &hessian,
bool use_group, int32_t n_threads) {
CHECK_EQ(hessian.size(), info.num_row_);
std::vector<float> results(hessian.size());
auto const &group_ptr = info.group_ptr_;
if (use_group) {
auto const &group_weights = info.weights_.HostVector();
CHECK_GE(group_ptr.size(), 2);
CHECK_EQ(group_ptr.back(), hessian.size());
size_t cur_group = 0;
for (size_t i = 0; i < hessian.size(); ++i) {
results[i] = hessian[i] * group_weights[cur_group];
if (i == group_ptr[cur_group + 1] - 1) {  // last row of the current group
cur_group++;
}
}
} else {
auto const &sample_weights = info.weights_.HostVector();
ParallelFor(hessian.size(), n_threads, Sched::Auto(),
[&](auto i) { results[i] = hessian[i] * sample_weights[i]; });
}
return results;
}

std::vector<float> UnrollGroupWeights(MetaInfo const &info) {
std::vector<float> const &group_weights = info.weights_.HostVector();
if (group_weights.empty()) {
return group_weights;
}

size_t n_samples = info.num_row_;
auto const &group_ptr = info.group_ptr_;
std::vector<float> results(n_samples);
CHECK_GE(group_ptr.size(), 2);
CHECK_EQ(group_ptr.back(), n_samples);
size_t cur_group = 0;
for (size_t i = 0; i < n_samples; ++i) {
results[i] = group_weights[cur_group];
if (i == group_ptr[cur_group + 1] - 1) {  // last row of the current group
cur_group++;
}
}
return results;
}
} // anonymous namespace

void HostSketchContainer::PushRowPage(
SparsePage const &page, MetaInfo const &info, std::vector<float> const &hessian) {
monitor_.Start(__func__);
int nthread = omp_get_max_threads();
CHECK_EQ(sketches_.size(), info.num_col_);
bst_feature_t n_columns = info.num_col_;
auto is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
CHECK_GE(n_threads_, 1);
CHECK_EQ(sketches_.size(), n_columns);

// Glue these conditions with the ternary operator to avoid making data copies.
auto const &weights =
hessian.empty()
? (use_group_ind_ ? UnrollGroupWeights(info) // use group weight
: info.weights_.HostVector()) // use sample weight
: MergeWeights(
info, hessian, use_group_ind_,
n_threads_); // use hessian merged with group/sample weights
if (!weights.empty()) {
CHECK_EQ(weights.size(), info.num_row_);
}

// Data groups, used in ranking.
std::vector<bst_uint> const &group_ptr = info.group_ptr_;
// Use group index for weights?
auto batch = page.GetView();
// Parallel over columns. Each thread owns a set of consecutive columns.
auto const ncol = static_cast<uint32_t>(info.num_col_);
auto const is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
auto thread_columns_ptr = LoadBalance(page, info.num_col_, nthread);
auto const ncol = static_cast<bst_feature_t>(info.num_col_);
auto thread_columns_ptr = LoadBalance(page, info.num_col_, n_threads_);

dmlc::OMPException exc;
#pragma omp parallel num_threads(nthread)
#pragma omp parallel num_threads(n_threads_)
{
exc.Run([&]() {
auto tid = static_cast<uint32_t>(omp_get_thread_num());
auto const begin = thread_columns_ptr[tid];
auto const end = thread_columns_ptr[tid + 1];
size_t group_ind = 0;

// do not iterate if no columns are assigned to the thread
if (begin < end && end <= ncol) {
for (size_t i = 0; i < batch.Size(); ++i) {
size_t const ridx = page.base_rowid + i;
SparsePage::Inst const inst = batch[i];
if (use_group_ind_) {
group_ind = this->SearchGroupIndFromRow(group_ptr, i + page.base_rowid);
}
size_t w_idx = use_group_ind_ ? group_ind : ridx;
auto w = info.GetWeight(w_idx);
auto w = weights.empty() ? 1.0f : weights[ridx];
auto p_inst = inst.data();
if (is_dense) {
for (size_t ii = begin; ii < end; ii++) {
@@ -201,6 +257,8 @@ void HostSketchContainer::AllReduce(
monitor_.Start(__func__);
auto& num_cuts = *p_num_cuts;
CHECK_EQ(num_cuts.size(), 0);
num_cuts.resize(sketches_.size());

auto &reduced = *p_reduced;
reduced.resize(sketches_.size());

@@ -212,25 +270,23 @@ void HostSketchContainer::AllReduce(
std::vector<bst_row_t> global_column_size(columns_size_);
rabit::Allreduce<rabit::op::Sum>(global_column_size.data(), global_column_size.size());

size_t nbytes = 0;
for (size_t i = 0; i < sketches_.size(); ++i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(std::min(
global_column_size[i], static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
ParallelFor(sketches_.size(), n_threads_, [&](size_t i) {
int32_t intermediate_num_cuts = static_cast<int32_t>(
std::min(global_column_size[i],
static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
if (global_column_size[i] != 0) {
WQSketch::SummaryContainer out;
sketches_[i].GetSummary(&out);
reduced[i].Reserve(intermediate_num_cuts);
CHECK(reduced[i].data);
reduced[i].SetPrune(out, intermediate_num_cuts);
nbytes = std::max(
WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts),
nbytes);
}
num_cuts[i] = intermediate_num_cuts;
});

num_cuts.push_back(intermediate_num_cuts);
}
auto world = rabit::GetWorldSize();
if (world == 1) {
monitor_.Stop(__func__);
return;
}

@@ -242,7 +298,7 @@ size_t nbytes = 0;
&global_sketches);

std::vector<WQSketch::SummaryContainer> final_sketches(n_columns);
ParallelFor(omp_ulong(n_columns), [&](omp_ulong fidx) {
ParallelFor(n_columns, n_threads_, [&](auto fidx) {
int32_t intermediate_num_cuts = num_cuts[fidx];
auto nbytes =
WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts);
@@ -276,7 +332,7 @@ void AddCutPoint(WQuantileSketch<float, float>::SummaryContainer const &summary,
auto& cut_values = cuts->cut_values_.HostVector();
for (size_t i = 1; i < required_cuts; ++i) {
bst_float cpt = summary.data[i].value;
if (i == 1 || cpt > cuts->cut_values_.ConstHostVector().back()) {
if (i == 1 || cpt > cut_values.back()) {
cut_values.push_back(cpt);
}
}
@@ -289,23 +345,28 @@ void HostSketchContainer::MakeCuts(HistogramCuts* cuts) {
this->AllReduce(&reduced, &num_cuts);

cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
std::vector<WQSketch::SummaryContainer> final_summaries(reduced.size());

for (size_t fid = 0; fid < reduced.size(); ++fid) {
WQSketch::SummaryContainer a;
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
ParallelFor(reduced.size(), n_threads_, Sched::Guided(), [&](size_t fidx) {
WQSketch::SummaryContainer &a = final_summaries[fidx];
size_t max_num_bins = std::min(num_cuts[fidx], max_bins_);
a.Reserve(max_num_bins + 1);
CHECK(a.data);
if (num_cuts[fid] != 0) {
a.SetPrune(reduced[fid], max_num_bins + 1);
CHECK(a.data && reduced[fid].data);
if (num_cuts[fidx] != 0) {
a.SetPrune(reduced[fidx], max_num_bins + 1);
CHECK(a.data && reduced[fidx].data);
const bst_float mval = a.data[0].value;
cuts->min_vals_.HostVector()[fid] = mval - fabs(mval) - 1e-5f;
cuts->min_vals_.HostVector()[fidx] = mval - fabs(mval) - 1e-5f;
} else {
// Empty column.
const float mval = 1e-5f;
cuts->min_vals_.HostVector()[fid] = mval;
cuts->min_vals_.HostVector()[fidx] = mval;
}
});

for (size_t fid = 0; fid < reduced.size(); ++fid) {
size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
WQSketch::SummaryContainer const& a = final_summaries[fid];
AddCutPoint(a, max_num_bins, cuts);
// push a value that is greater than anything
const bst_float cpt
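To make the weight handling in MergeWeights and UnrollGroupWeights above concrete, here is a small self-contained sketch of the same unroll-and-merge step on toy ranking data; the two groups, their weights, and the hessian values are invented for illustration.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  // Two query groups covering rows [0, 3) and [3, 5), CSR-style group pointer.
  std::vector<std::size_t> group_ptr{0, 3, 5};
  std::vector<float> group_weights{0.5f, 2.0f};          // one weight per group
  std::vector<float> hessian{1.f, 2.f, 3.f, 4.f, 5.f};   // one value per row

  // Unroll the per-group weights to per-row weights and merge with the hessian,
  // mirroring the loops in quantile.cc.
  std::vector<float> row_weight(hessian.size());
  std::size_t cur_group = 0;
  for (std::size_t i = 0; i < hessian.size(); ++i) {
    row_weight[i] = hessian[i] * group_weights[cur_group];
    if (i == group_ptr[cur_group + 1] - 1) {  // last row of the current group
      ++cur_group;
    }
  }

  // Rows 0-2 are weighted by 0.5, rows 3-4 by 2.0.
  assert(row_weight[2] == 3.f * 0.5f);
  assert(row_weight[3] == 4.f * 2.0f);
  return 0;
}

The resulting per-row weights are what PushRowPage feeds into the per-column WQSketch updates in place of the raw sample weights.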
6 changes: 4 additions & 2 deletions src/common/quantile.h
@@ -710,6 +710,7 @@ class HostSketchContainer {
std::vector<bst_row_t> columns_size_;
int32_t max_bins_;
bool use_group_ind_{false};
int32_t n_threads_;
Monitor monitor_;

public:
@@ -720,7 +721,7 @@
* \param use_group whether weights are assigned to query groups rather than to individual data instances.
* \param n_threads number of threads used for sketching.
*/
HostSketchContainer(std::vector<bst_row_t> columns_size, int32_t max_bins,
bool use_group);
bool use_group, int32_t n_threads);

static bool UseGroup(MetaInfo const &info) {
size_t const num_groups =
@@ -758,7 +759,8 @@
std::vector<int32_t>* p_num_cuts);

/* \brief Push a CSR page; optional per-row hessian values are merged into the sketch weights. */
void PushRowPage(SparsePage const& page, MetaInfo const& info);
void PushRowPage(SparsePage const &page, MetaInfo const &info,
std::vector<float> const &hessian = {});

void MakeCuts(HistogramCuts* cuts);
};
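For completeness, the sketch below drives HostSketchContainer directly with the new n_threads and hessian arguments, mirroring what the updated SketchOnDMatrix in hist_util.h does; the container interface is as declared above, while the include paths and the naive column-size counting loop are assumptions.

#include <cstddef>
#include <cstdint>
#include <vector>

#include "xgboost/data.h"                // DMatrix, SparsePage, MetaInfo
#include "../../src/common/hist_util.h"  // HistogramCuts (path assumed)
#include "../../src/common/quantile.h"   // HostSketchContainer (path assumed)

xgboost::common::HistogramCuts BuildCuts(xgboost::DMatrix* m, int32_t max_bins,
                                         int32_t n_threads,
                                         std::vector<float> const& hessian) {
  auto const& info = m->Info();

  // Count entries per column; a naive stand-in for the column-size computation
  // done elsewhere in the library.
  std::vector<xgboost::bst_row_t> columns_size(info.num_col_, 0);
  for (auto const& page : m->GetBatches<xgboost::SparsePage>()) {
    auto batch = page.GetView();
    for (std::size_t i = 0; i < batch.Size(); ++i) {
      for (auto const& entry : batch[i]) {
        columns_size[entry.index]++;
      }
    }
  }

  xgboost::common::HostSketchContainer container(
      columns_size, max_bins,
      xgboost::common::HostSketchContainer::UseGroup(info), n_threads);
  for (auto const& page : m->GetBatches<xgboost::SparsePage>()) {
    // An empty hessian keeps the old sample/group-weight behaviour.
    container.PushRowPage(page, info, hessian);
  }

  xgboost::common::HistogramCuts cuts;
  container.MakeCuts(&cuts);
  return cuts;
}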