diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc index 389be8ab7767..8220135d9e32 100644 --- a/amalgamation/xgboost-all0.cc +++ b/amalgamation/xgboost-all0.cc @@ -70,6 +70,7 @@ #include "../src/common/common.cc" #include "../src/common/charconv.cc" #include "../src/common/timer.cc" +#include "../src/common/quantile.cc" #include "../src/common/host_device_vector.cc" #include "../src/common/hist_util.cc" #include "../src/common/json.cc" diff --git a/include/xgboost/data.h b/include/xgboost/data.h index e7350fffeeba..05b75323f7f4 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -239,6 +239,21 @@ struct BatchParam { } }; +struct HostSparsePageView { + using Inst = common::Span; + + common::Span offset; + common::Span data; + + Inst operator[](size_t i) const { + auto size = *(offset.data() + i + 1) - *(offset.data() + i); + return {data.data() + *(offset.data() + i), + static_cast(size)}; + } + + size_t Size() const { return offset.size() == 0 ? 0 : offset.size() - 1; } +}; + /*! * \brief In-memory storage unit of sparse batch, stored in CSR format. */ @@ -270,6 +285,11 @@ class SparsePage { static_cast(size)}; } + HostSparsePageView GetView() const { + return {offset.ConstHostSpan(), data.ConstHostSpan()}; + } + + /*! \brief constructor */ SparsePage() { this->Clear(); diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index f8e42f2f454a..efd62d701ccd 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -113,346 +113,12 @@ void GHistIndexMatrix::ResizeIndex(const size_t rbegin, const SparsePage& batch, } HistogramCuts::HistogramCuts() { - monitor_.Init(__FUNCTION__); cut_ptrs_.HostVector().emplace_back(0); } -// Dispatch to specific builder. -void HistogramCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) { - auto const& info = dmat->Info(); - size_t const total = info.num_row_ * info.num_col_; - size_t const nnz = info.num_nonzero_; - float const sparsity = static_cast(nnz) / static_cast(total); - // Use a small number to avoid calling `dmat->GetColumnBatches'. - float constexpr kSparsityThreshold = 0.0005; - // FIXME(trivialfis): Distributed environment is not supported. - if (sparsity < kSparsityThreshold && (!rabit::IsDistributed())) { - LOG(INFO) << "Building quantile cut on a sparse dataset."; - SparseCuts cuts(this); - cuts.Build(dmat, max_num_bins); - } else { - LOG(INFO) << "Building quantile cut on a dense dataset or distributed environment."; - DenseCuts cuts(this); - cuts.Build(dmat, max_num_bins); - } - LOG(INFO) << "Total number of hist bins: " << cut_ptrs_.HostVector().back(); -} - -bool CutsBuilder::UseGroup(DMatrix* dmat) { - auto& info = dmat->Info(); - return CutsBuilder::UseGroup(info); -} - -bool CutsBuilder::UseGroup(MetaInfo const& info) { - size_t const num_groups = info.group_ptr_.size() == 0 ? - 0 : info.group_ptr_.size() - 1; - // Use group index for weights? - bool const use_group_ind = num_groups != 0 && - (info.weights_.Size() != info.num_row_); - return use_group_ind; -} - -void SparseCuts::SingleThreadBuild(SparsePage const& page, MetaInfo const& info, - uint32_t max_num_bins, - bool const use_group_ind, - uint32_t beg_col, uint32_t end_col, - uint32_t thread_id) { - CHECK_GE(end_col, beg_col); - - // Data groups, used in ranking. 
- std::vector const& group_ptr = info.group_ptr_; - auto &local_min_vals = p_cuts_->min_vals_.HostVector(); - auto &local_cuts = p_cuts_->cut_values_.HostVector(); - auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector(); - local_min_vals.resize(end_col - beg_col, 0); - - for (uint32_t col_id = beg_col; col_id < page.Size() && col_id < end_col; ++col_id) { - // Using a local variable makes things easier, but at the cost of memory trashing. - WQSketch sketch; - common::Span const column = page[col_id]; - uint32_t const n_bins = std::min(static_cast(column.size()), - max_num_bins); - if (n_bins == 0) { - // cut_ptrs_ is initialized with a zero, so there's always an element at the back - CHECK_GE(local_ptrs.size(), 1); - local_ptrs.emplace_back(local_ptrs.back()); - continue; - } - - sketch.Init(info.num_row_, 1.0 / (n_bins * WQSketch::kFactor)); - for (auto const& entry : column) { - uint32_t weight_ind = 0; - if (use_group_ind) { - auto row_idx = entry.index; - uint32_t group_ind = - this->SearchGroupIndFromRow(group_ptr, page.base_rowid + row_idx); - weight_ind = group_ind; - } else { - weight_ind = entry.index; - } - sketch.Push(entry.fvalue, info.GetWeight(weight_ind)); - } - - WQSketch::SummaryContainer out_summary; - sketch.GetSummary(&out_summary); - WQSketch::SummaryContainer summary; - summary.Reserve(n_bins + 1); - summary.SetPrune(out_summary, n_bins + 1); - - // Can be use data[1] as the min values so that we don't need to - // store another array? - float mval = summary.data[0].value; - local_min_vals[col_id - beg_col] = mval - (fabs(mval) + 1e-5); - - this->AddCutPoint(summary, max_num_bins); - - bst_float cpt = (summary.size > 0) ? - summary.data[summary.size - 1].value : - local_min_vals[col_id - beg_col]; - cpt += fabs(cpt) + 1e-5; - local_cuts.emplace_back(cpt); - - local_ptrs.emplace_back(local_cuts.size()); - } -} - -std::vector SparseCuts::LoadBalance(SparsePage const& page, - size_t const nthreads) { - /* Some sparse datasets have their mass concentrating on small - * number of features. To avoid wating for a few threads running - * forever, we here distirbute different number of columns to - * different threads according to number of entries. */ - size_t const total_entries = page.data.Size(); - size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads); - - std::vector cols_ptr(nthreads+1, 0); - size_t count {0}; - size_t current_thread {1}; - - for (size_t col_id = 0; col_id < page.Size(); ++col_id) { - auto const column = page[col_id]; - cols_ptr[current_thread]++; // add one column to thread - count += column.size(); - if (count > entries_per_thread + 1) { - current_thread++; - count = 0; - cols_ptr[current_thread] = cols_ptr[current_thread-1]; - } - } - // Idle threads. - for (; current_thread < cols_ptr.size() - 1; ++current_thread) { - cols_ptr[current_thread+1] = cols_ptr[current_thread]; - } - - return cols_ptr; -} - -void SparseCuts::Build(DMatrix* dmat, uint32_t const max_num_bins) { - monitor_.Start(__FUNCTION__); - // Use group index for weights? 
- auto use_group = UseGroup(dmat); - uint32_t nthreads = omp_get_max_threads(); - CHECK_GT(nthreads, 0); - std::vector cuts_containers(nthreads); - std::vector> sparse_cuts(nthreads); - for (size_t i = 0; i < nthreads; ++i) { - sparse_cuts[i].reset(new SparseCuts(&cuts_containers[i])); - } - - for (auto const& page : dmat->GetBatches()) { - CHECK_LE(page.Size(), dmat->Info().num_col_); - monitor_.Start("Load balance"); - std::vector col_ptr = LoadBalance(page, nthreads); - monitor_.Stop("Load balance"); - // We here decouples the logic between build and parallelization - // to simplify things a bit. -#pragma omp parallel for num_threads(nthreads) schedule(static) - for (omp_ulong i = 0; i < nthreads; ++i) { - common::Monitor t_monitor; - t_monitor.Init("SingleThreadBuild: " + std::to_string(i)); - t_monitor.Start(std::to_string(i)); - sparse_cuts[i]->SingleThreadBuild(page, dmat->Info(), max_num_bins, use_group, - col_ptr[i], col_ptr[i+1], i); - t_monitor.Stop(std::to_string(i)); - } - - this->Concat(sparse_cuts, dmat->Info().num_col_); - } - - monitor_.Stop(__FUNCTION__); -} - -void SparseCuts::Concat( - std::vector> const& cuts, uint32_t n_cols) { - monitor_.Start(__FUNCTION__); - uint32_t nthreads = omp_get_max_threads(); - auto &local_min_vals = p_cuts_->min_vals_.HostVector(); - auto &local_cuts = p_cuts_->cut_values_.HostVector(); - auto &local_ptrs = p_cuts_->cut_ptrs_.HostVector(); - local_min_vals.resize(n_cols, std::numeric_limits::max()); - size_t min_vals_tail = 0; - - for (uint32_t t = 0; t < nthreads; ++t) { - auto& thread_min_vals = cuts[t]->p_cuts_->min_vals_.HostVector(); - auto& thread_cuts = cuts[t]->p_cuts_->cut_values_.HostVector(); - auto& thread_ptrs = cuts[t]->p_cuts_->cut_ptrs_.HostVector(); - - // concat csc pointers. - size_t const old_ptr_size = local_ptrs.size(); - local_ptrs.resize( - thread_ptrs.size() + local_ptrs.size() - 1); - size_t const new_icp_size = local_ptrs.size(); - auto tail = local_ptrs[old_ptr_size-1]; - for (size_t j = old_ptr_size; j < new_icp_size; ++j) { - local_ptrs[j] = tail + thread_ptrs[j-old_ptr_size+1]; - } - // concat csc values - size_t const old_iv_size = local_cuts.size(); - local_cuts.resize( - thread_cuts.size() + local_cuts.size()); - size_t const new_iv_size = local_cuts.size(); - for (size_t j = old_iv_size; j < new_iv_size; ++j) { - local_cuts[j] = thread_cuts[j-old_iv_size]; - } - // merge min values - for (size_t j = 0; j < thread_min_vals.size(); ++j) { - local_min_vals.at(min_vals_tail + j) = - std::min(local_min_vals.at(min_vals_tail + j), thread_min_vals.at(j)); - } - min_vals_tail += thread_min_vals.size(); - } - monitor_.Stop(__FUNCTION__); -} - -void DenseCuts::Build(DMatrix* p_fmat, uint32_t max_num_bins) { - monitor_.Start(__FUNCTION__); - const MetaInfo& info = p_fmat->Info(); - - // safe factor for better accuracy - std::vector sketchs; - - const int nthread = omp_get_max_threads(); - - unsigned const nstep = - static_cast((info.num_col_ + nthread - 1) / nthread); - unsigned const ncol = static_cast(info.num_col_); - sketchs.resize(info.num_col_); - for (auto& s : sketchs) { - s.Init(info.num_row_, 1.0 / (max_num_bins * WQSketch::kFactor)); - } - - // Data groups, used in ranking. - std::vector const& group_ptr = info.group_ptr_; - size_t const num_groups = group_ptr.size() == 0 ? 0 : group_ptr.size() - 1; - // Use group index for weights? 
- bool const use_group = UseGroup(p_fmat); - const bool isDense = p_fmat->IsDense(); - for (const auto &batch : p_fmat->GetBatches()) { - size_t group_ind = 0; - if (use_group) { - group_ind = this->SearchGroupIndFromRow(group_ptr, batch.base_rowid); - } -#pragma omp parallel num_threads(nthread) firstprivate(group_ind, use_group) - { - CHECK_EQ(nthread, omp_get_num_threads()); - auto tid = static_cast(omp_get_thread_num()); - unsigned begin = std::min(nstep * tid, ncol); - unsigned end = std::min(nstep * (tid + 1), ncol); - - // do not iterate if no columns are assigned to the thread - if (begin < end && end <= ncol) { - for (size_t i = 0; i < batch.Size(); ++i) { // NOLINT(*) - size_t const ridx = batch.base_rowid + i; - SparsePage::Inst const inst = batch[i]; - if (use_group && - group_ptr[group_ind] == ridx && - // maximum equals to weights.size() - 1 - group_ind < num_groups - 1) { - // move to next group - group_ind++; - } - size_t w_idx = use_group ? group_ind : ridx; - auto w = info.GetWeight(w_idx); - if (isDense) { - auto data = inst.data(); - for (size_t ii = begin; ii < end; ii++) { - sketchs[ii].Push(data[ii].fvalue, w); - } - } else { - for (auto const& entry : inst) { - if (entry.index >= begin && entry.index < end) { - sketchs[entry.index].Push(entry.fvalue, w); - } - } - } - } - } - } - } - - Init(&sketchs, max_num_bins, info.num_row_); - monitor_.Stop(__FUNCTION__); -} - -/** - * \param [in,out] in_sketchs - * \param max_num_bins The maximum number bins. - * \param max_rows Number of rows in this DMatrix. - */ -void DenseCuts::Init -(std::vector* in_sketchs, uint32_t max_num_bins, size_t max_rows) { - monitor_.Start(__func__); - std::vector& sketchs = *in_sketchs; - - // Compute how many cuts samples we need at each node - // Do not require more than the number of total rows in training data - // This allows efficient training on wide data - size_t global_max_rows = max_rows; - rabit::Allreduce(&global_max_rows, 1); - size_t intermediate_num_cuts = - std::min(global_max_rows, static_cast(max_num_bins * WQSketch::kFactor)); - // gather the histogram data - rabit::SerializeReducer sreducer; - std::vector summary_array; - summary_array.resize(sketchs.size()); - for (size_t i = 0; i < sketchs.size(); ++i) { - WQSketch::SummaryContainer out; - sketchs[i].GetSummary(&out); - summary_array[i].Reserve(intermediate_num_cuts); - summary_array[i].SetPrune(out, intermediate_num_cuts); - } - CHECK_EQ(summary_array.size(), in_sketchs->size()); - size_t nbytes = WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts); - // TODO(chenqin): rabit failure recovery assumes no boostrap onetime call after loadcheckpoint - // we need to move this allreduce before loadcheckpoint call in future - sreducer.Allreduce(dmlc::BeginPtr(summary_array), nbytes, summary_array.size()); - p_cuts_->min_vals_.HostVector().resize(sketchs.size()); - - for (size_t fid = 0; fid < summary_array.size(); ++fid) { - WQSketch::SummaryContainer a; - a.Reserve(max_num_bins + 1); - a.SetPrune(summary_array[fid], max_num_bins + 1); - const bst_float mval = a.data[0].value; - p_cuts_->min_vals_.HostVector()[fid] = mval - (fabs(mval) + 1e-5); - AddCutPoint(a, max_num_bins); - // push a value that is greater than anything - const bst_float cpt - = (a.size > 0) ? 
a.data[a.size - 1].value : p_cuts_->min_vals_.HostVector()[fid]; - // this must be bigger than last value in a scale - const bst_float last = cpt + (fabs(cpt) + 1e-5); - p_cuts_->cut_values_.HostVector().push_back(last); - - // Ensure that every feature gets at least one quantile point - CHECK_LE(p_cuts_->cut_values_.HostVector().size(), std::numeric_limits::max()); - auto cut_size = static_cast(p_cuts_->cut_values_.HostVector().size()); - CHECK_GT(cut_size, p_cuts_->cut_ptrs_.HostVector().back()); - p_cuts_->cut_ptrs_.HostVector().push_back(cut_size); - } - monitor_.Stop(__func__); -} - void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) { - cut.Build(p_fmat, max_bins); + cut = SketchOnDMatrix(p_fmat, max_bins); + max_num_bins = max_bins; const int32_t nthread = omp_get_max_threads(); const uint32_t nbins = cut.Ptrs().back(); @@ -1048,12 +714,11 @@ void BuildHistKernel(const std::vector& gpair, } } -template -void GHistBuilder::BuildHist(const std::vector& gpair, - const RowSetCollection::Elem row_indices, - const GHistIndexMatrix& gmat, - GHistRowT hist, - bool isDense) { +template +void GHistBuilder::BuildHist( + const std::vector &gpair, + const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, + GHistRowT hist, bool isDense) { const size_t nrows = row_indices.Size(); const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); diff --git a/src/common/hist_util.cu b/src/common/hist_util.cu index 6953a556b939..ebd38b7aecd9 100644 --- a/src/common/hist_util.cu +++ b/src/common/hist_util.cu @@ -313,7 +313,6 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, device, num_cuts_per_feature, has_weights); HistogramCuts cuts; - DenseCuts dense_cuts(&cuts); SketchContainer sketch_container(max_bins, dmat->Info().num_col_, dmat->Info().num_row_, device); @@ -324,7 +323,7 @@ HistogramCuts DeviceSketch(int device, DMatrix* dmat, int max_bins, for (auto begin = 0ull; begin < batch_nnz; begin += sketch_batch_num_elements) { size_t end = std::min(batch_nnz, size_t(begin + sketch_batch_num_elements)); if (has_weights) { - bool is_ranking = CutsBuilder::UseGroup(dmat); + bool is_ranking = HostSketchContainer::UseGroup(dmat->Info()); dh::caching_device_vector groups(info.group_ptr_.cbegin(), info.group_ptr_.cend()); ProcessWeightedBatch( diff --git a/src/common/hist_util.cuh b/src/common/hist_util.cuh index 8dca9fdb9bef..f1034040c1ab 100644 --- a/src/common/hist_util.cuh +++ b/src/common/hist_util.cuh @@ -306,7 +306,7 @@ void AdapterDeviceSketch(Batch batch, int num_bins, size_t end = std::min(batch.Size(), size_t(begin + sketch_batch_num_elements)); ProcessWeightedSlidingWindow(batch, info, num_cuts_per_feature, - CutsBuilder::UseGroup(info), missing, device, num_cols, begin, end, + HostSketchContainer::UseGroup(info), missing, device, num_cols, begin, end, sketch_container); } } else { diff --git a/src/common/hist_util.h b/src/common/hist_util.h index dbb0b35e4cea..d86b73135f34 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -17,6 +17,7 @@ #include #include "row_set.h" +#include "common.h" #include "threading_utils.h" #include "../tree/param.h" #include "./quantile.h" @@ -34,15 +35,8 @@ using GHistIndexRow = Span; // A CSC matrix representing histogram cuts, used in CPU quantile hist. // The cut values represent upper bounds of bins containing approximately equal numbers of elements class HistogramCuts { - // Using friends to avoid creating a virtual class, since HistogramCuts is used as value - // object in many places. 
- friend class SparseCuts; - friend class DenseCuts; - friend class CutsBuilder; - protected: using BinIdx = uint32_t; - common::Monitor monitor_; public: HostDeviceVector cut_values_; // NOLINT @@ -75,16 +69,12 @@ class HistogramCuts { } HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) { - monitor_ = std::move(that.monitor_); cut_ptrs_ = std::move(that.cut_ptrs_); cut_values_ = std::move(that.cut_values_); min_vals_ = std::move(that.min_vals_); return *this; } - /* \brief Build histogram cuts. */ - void Build(DMatrix* dmat, uint32_t const max_num_bins); - /* \brief How many bins a feature has. */ uint32_t FeatureBins(uint32_t feature) const { return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature]; @@ -118,86 +108,42 @@ class HistogramCuts { } }; -/* \brief An interface for building quantile cuts. - * - * `DenseCuts' always assumes there are `max_bins` for each feature, which makes it not - * suitable for sparse dataset. On the other hand `SparseCuts' uses `GetColumnBatches', - * which doubles the memory usage, hence can not be applied to dense dataset. - */ -class CutsBuilder { - public: - using WQSketch = common::WQuantileSketch; - /* \brief return whether group for ranking is used. */ - static bool UseGroup(DMatrix* dmat); - static bool UseGroup(MetaInfo const& info); - - protected: - HistogramCuts* p_cuts_; - - public: - explicit CutsBuilder(HistogramCuts* p_cuts) : p_cuts_{p_cuts} {} - virtual ~CutsBuilder() = default; - - static uint32_t SearchGroupIndFromRow(std::vector const &group_ptr, - size_t const base_rowid) { - CHECK_LT(base_rowid, group_ptr.back()) - << "Row: " << base_rowid << " is not found in any group."; - auto it = - std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid); - bst_group_t group_ind = it - group_ptr.cbegin() - 1; - return group_ind; - } - - void AddCutPoint(WQSketch::SummaryContainer const& summary, int max_bin) { - size_t required_cuts = std::min(summary.size, static_cast(max_bin)); - for (size_t i = 1; i < required_cuts; ++i) { - bst_float cpt = summary.data[i].value; - if (i == 1 || cpt > p_cuts_->cut_values_.ConstHostVector().back()) { - p_cuts_->cut_values_.HostVector().push_back(cpt); +inline HistogramCuts SketchOnDMatrix(DMatrix *m, int32_t max_bins) { + HistogramCuts out; + auto const& info = m->Info(); + const auto threads = omp_get_max_threads(); + std::vector> column_sizes(threads); + for (auto& column : column_sizes) { + column.resize(info.num_col_, 0); + } + for (auto const& page : m->GetBatches()) { + page.data.HostVector(); + page.offset.HostVector(); + ParallelFor(page.Size(), threads, [&](size_t i) { + auto &local_column_sizes = column_sizes.at(omp_get_thread_num()); + auto row = page[i]; + auto const *p_row = row.data(); + for (size_t j = 0; j < row.size(); ++j) { + local_column_sizes.at(p_row[j].index)++; } - } + }); } + std::vector reduced(info.num_col_, 0); - /* \brief Build histogram indices. */ - virtual void Build(DMatrix* dmat, uint32_t const max_num_bins) = 0; -}; - -/*! \brief Cut configuration for sparse dataset. */ -class SparseCuts : public CutsBuilder { - /* \brief Distribute columns to each thread according to number of entries. */ - static std::vector LoadBalance(SparsePage const& page, size_t const nthreads); - Monitor monitor_; - - public: - explicit SparseCuts(HistogramCuts* container) : - CutsBuilder(container) { - monitor_.Init(__FUNCTION__); - } - - /* \brief Concatonate the built cuts in each thread. 
*/
-  void Concat(std::vector<std::unique_ptr<SparseCuts>> const& cuts, uint32_t n_cols);
-  /* \brief Build histogram indices in single thread. */
-  void SingleThreadBuild(SparsePage const& page, MetaInfo const& info,
-                         uint32_t max_num_bins,
-                         bool const use_group_ind,
-                         uint32_t beg, uint32_t end, uint32_t thread_id);
-  void Build(DMatrix* dmat, uint32_t const max_num_bins) override;
-};
-
-/*! \brief Cut configuration for dense dataset. */
-class DenseCuts : public CutsBuilder {
- protected:
-  Monitor monitor_;
+  ParallelFor(info.num_col_, threads, [&](size_t i) {
+    for (auto const &thread : column_sizes) {
+      reduced[i] += thread[i];
+    }
+  });
 
- public:
-  explicit DenseCuts(HistogramCuts* container) :
-      CutsBuilder(container) {
-    monitor_.Init(__FUNCTION__);
+  HostSketchContainer container(reduced, max_bins,
+                                HostSketchContainer::UseGroup(info));
+  for (auto const &page : m->GetBatches<SparsePage>()) {
+    container.PushRowPage(page, info);
   }
-  void Init(std::vector<WQSketch>* sketchs, uint32_t max_num_bins, size_t max_rows);
-  void Build(DMatrix* p_fmat, uint32_t max_num_bins) override;
-};
-
+  container.MakeCuts(&out);
+  return out;
+}
 
 enum BinTypeSize {
   kUint8BinsTypeSize  = 1,
diff --git a/src/common/quantile.cc b/src/common/quantile.cc
new file mode 100644
index 000000000000..374864c8f4b0
--- /dev/null
+++ b/src/common/quantile.cc
@@ -0,0 +1,193 @@
+/*!
+ * Copyright 2020 by XGBoost Contributors
+ */
+#include <limits>
+#include <utility>
+#include "quantile.h"
+#include "hist_util.h"
+
+namespace xgboost {
+namespace common {
+
+HostSketchContainer::HostSketchContainer(std::vector<size_t> columns_size,
+                                         int32_t max_bins, bool use_group)
+    : columns_size_{std::move(columns_size)}, max_bins_{max_bins},
+      use_group_ind_{use_group} {
+  monitor_.Init(__func__);
+  CHECK_NE(columns_size_.size(), 0);
+  sketches_.resize(columns_size_.size());
+  for (size_t i = 0; i < sketches_.size(); ++i) {
+    auto n_bins = std::min(static_cast<size_t>(max_bins_), columns_size_[i]);
+    n_bins = std::max(n_bins, static_cast<size_t>(1));
+    auto eps = 1.0 / (static_cast<float>(n_bins) * WQSketch::kFactor);
+    sketches_[i].Init(columns_size_[i], eps);
+    sketches_[i].inqueue.queue.resize(sketches_[i].limit_size * 2);
+  }
+}
+
+std::vector<size_t> LoadBalance(SparsePage const &page,
+                                std::vector<size_t> columns_size,
+                                size_t const nthreads) {
+  /* Some sparse datasets have their mass concentrated on a small number of
+   * features.  To avoid waiting for a few threads that run forever, we here
+   * distribute different numbers of columns to different threads according
+   * to the number of entries. */
+  size_t const total_entries = page.data.Size();
+  size_t const entries_per_thread = common::DivRoundUp(total_entries, nthreads);
+
+  std::vector<size_t> cols_ptr(nthreads + 1, 0);
+  size_t count {0};
+  size_t current_thread {1};
+
+  for (auto col : columns_size) {
+    cols_ptr[current_thread]++;  // add one column to thread
+    count += col;
+    if (count > entries_per_thread + 1) {
+      current_thread++;
+      count = 0;
+      cols_ptr[current_thread] = cols_ptr[current_thread - 1];
+    }
+  }
+  // Idle threads.
+  for (; current_thread < cols_ptr.size() - 1; ++current_thread) {
+    cols_ptr[current_thread + 1] = cols_ptr[current_thread];
+  }
+
+  return cols_ptr;
+}
+
+void HostSketchContainer::PushRowPage(SparsePage const &page,
+                                      MetaInfo const &info) {
+  monitor_.Start(__func__);
+  int nthread = omp_get_max_threads();
+  CHECK_EQ(sketches_.size(), info.num_col_);
+
+  // Data groups, used in ranking.
+  std::vector<bst_group_t> const &group_ptr = info.group_ptr_;
+  // Use group index for weights?
+  auto batch = page.GetView();
+  dmlc::OMPException exec;
+  // Parallel over columns.
+  // Assuming the data is dense, each thread owns a set of consecutive columns.
+  auto const ncol = static_cast<size_t>(info.num_col_);
+  auto const is_dense = info.num_nonzero_ == info.num_col_ * info.num_row_;
+  auto thread_columns_ptr = LoadBalance(page, columns_size_, nthread);
+
+#pragma omp parallel num_threads(nthread)
+  {
+    exec.Run([&]() {
+      auto tid = static_cast<size_t>(omp_get_thread_num());
+      auto const begin = thread_columns_ptr[tid];
+      auto const end = thread_columns_ptr[tid + 1];
+      size_t group_ind = 0;
+
+      // do not iterate if no columns are assigned to the thread
+      if (begin < end && end <= ncol) {
+        for (size_t i = 0; i < batch.Size(); ++i) {
+          size_t const ridx = page.base_rowid + i;
+          SparsePage::Inst const inst = batch[i];
+          if (use_group_ind_) {
+            group_ind = this->SearchGroupIndFromRow(group_ptr, i + page.base_rowid);
+          }
+          size_t w_idx = use_group_ind_ ? group_ind : ridx;
+          auto w = info.GetWeight(w_idx);
+          auto p_inst = inst.data();
+          if (is_dense) {
+            for (size_t ii = begin; ii < end; ii++) {
+              sketches_[ii].Push(p_inst[ii].fvalue, w);
+            }
+          } else {
+            for (size_t j = 0; j < inst.size(); ++j) {
+              auto const &entry = p_inst[j];
+              if (entry.index >= begin && entry.index < end) {
+                sketches_[entry.index].Push(entry.fvalue, w);
+              }
+            }
+          }
+        }
+      }
+    });
+  }
+  exec.Rethrow();
+  monitor_.Stop(__func__);
+}
+
+void AddCutPoint(WQuantileSketch<bst_float, bst_float>::SummaryContainer const &summary,
+                 int max_bin, HistogramCuts *cuts) {
+  size_t required_cuts = std::min(summary.size, static_cast<size_t>(max_bin));
+  auto &cut_values = cuts->cut_values_.HostVector();
+  for (size_t i = 1; i < required_cuts; ++i) {
+    bst_float cpt = summary.data[i].value;
+    if (i == 1 || cpt > cuts->cut_values_.ConstHostVector().back()) {
+      cut_values.push_back(cpt);
+    }
+  }
+}
+
+void HostSketchContainer::MakeCuts(HistogramCuts *cuts) {
+  monitor_.Start(__func__);
+  rabit::Allreduce<rabit::op::Sum>(columns_size_.data(), columns_size_.size());
+  std::vector<WQSketch::SummaryContainer> reduced(sketches_.size());
+  std::vector<int32_t> num_cuts;
+  size_t nbytes = 0;
+  for (size_t i = 0; i < sketches_.size(); ++i) {
+    int32_t intermediate_num_cuts = static_cast<int32_t>(std::min(
+        columns_size_[i], static_cast<size_t>(max_bins_ * WQSketch::kFactor)));
+    if (columns_size_[i] != 0) {
+      WQSketch::SummaryContainer out;
+      sketches_[i].GetSummary(&out);
+      reduced[i].Reserve(intermediate_num_cuts);
+      CHECK(reduced[i].data);
+      reduced[i].SetPrune(out, intermediate_num_cuts);
+    }
+    num_cuts.push_back(intermediate_num_cuts);
+    nbytes = std::max(
+        WQSketch::SummaryContainer::CalcMemCost(intermediate_num_cuts), nbytes);
+  }
+
+  if (rabit::IsDistributed()) {
+    // FIXME(trivialfis): This call will allocate nbytes * num_columns on rabit, which
+    // may generate an OOM error when data is sparse.  To fix it, we need to:
+    // - gather the column offsets over all workers.
+    // - run rabit::allgather on sketch data to collect all data.
+    // - merge all gathered sketches based on worker offsets and column offsets of data
+    //   from each worker.
+    // See the GPU implementation for details.
+    rabit::SerializeReducer<WQSketch::SummaryContainer> sreducer;
+    sreducer.Allreduce(dmlc::BeginPtr(reduced), nbytes, reduced.size());
+  }
+
+  cuts->min_vals_.HostVector().resize(sketches_.size(), 0.0f);
+  for (size_t fid = 0; fid < reduced.size(); ++fid) {
+    WQSketch::SummaryContainer a;
+    size_t max_num_bins = std::min(num_cuts[fid], max_bins_);
+    a.Reserve(max_num_bins + 1);
+    CHECK(a.data);
+    if (columns_size_[fid] != 0) {
+      a.SetPrune(reduced[fid], max_num_bins + 1);
+      CHECK(a.data && reduced[fid].data);
+      const bst_float mval = a.data[0].value;
+      cuts->min_vals_.HostVector()[fid] = mval - fabs(mval) - 1e-5f;
+    } else {
+      // Empty column.
+      const float mval = 1e-5f;
+      cuts->min_vals_.HostVector()[fid] = mval;
+    }
+    AddCutPoint(a, max_num_bins, cuts);
+    // push a value that is greater than anything
+    const bst_float cpt
+        = (a.size > 0) ? a.data[a.size - 1].value : cuts->min_vals_.HostVector()[fid];
+    // this must be bigger than the last value in a scale
+    const bst_float last = cpt + (fabs(cpt) + 1e-5f);
+    cuts->cut_values_.HostVector().push_back(last);
+
+    // Ensure that every feature gets at least one quantile point
+    CHECK_LE(cuts->cut_values_.HostVector().size(), std::numeric_limits<uint32_t>::max());
+    auto cut_size = static_cast<uint32_t>(cuts->cut_values_.HostVector().size());
+    CHECK_GT(cut_size, cuts->cut_ptrs_.HostVector().back());
+    cuts->cut_ptrs_.HostVector().push_back(cut_size);
+  }
+  monitor_.Stop(__func__);
+}
+}  // namespace common
+}  // namespace xgboost
diff --git a/src/common/quantile.cu b/src/common/quantile.cu
index 8910e41cd2e5..52d0e37e97be 100644
--- a/src/common/quantile.cu
+++ b/src/common/quantile.cu
@@ -20,7 +20,7 @@
 namespace xgboost {
 namespace common {
 
-using WQSketch = DenseCuts::WQSketch;
+using WQSketch = HostSketchContainer::WQSketch;
 using SketchEntry = WQSketch::Entry;
 
 // Algorithm 4 in XGBoost's paper, using binary search to find i.
diff --git a/src/common/quantile.h b/src/common/quantile.h
index 49345d13f406..11e2530f748e 100644
--- a/src/common/quantile.h
+++ b/src/common/quantile.h
@@ -9,12 +9,15 @@
 #include <dmlc/base.h>
 #include <xgboost/logging.h>
+#include <algorithm>
 #include <cmath>
 #include <vector>
 #include <cstring>
 #include <string>
 
+#include "timer.h"
+
 namespace xgboost {
 namespace common {
 /*!
@@ -682,6 +685,57 @@ template <typename DType, typename RType>
 class WXQuantileSketch :
     public QuantileSketchTemplate<DType, RType, WXQSummary<DType, RType>> {
 };
+
+class HistogramCuts;
+
+/*!
+ * A sketch matrix storing sketches for each feature.
+ */
+class HostSketchContainer {
+ public:
+  using WQSketch = WQuantileSketch<bst_float, bst_float>;
+
+ private:
+  std::vector<WQSketch> sketches_;
+  std::vector<size_t> columns_size_;
+  int32_t max_bins_;
+  bool use_group_ind_{false};
+  Monitor monitor_;
+
+ public:
+  /* \brief Initialize necessary info.
+   *
+   * \param columns_size Size of each column.
+   * \param max_bins     maximum number of bins for each feature.
+   * \param use_group    whether ranking-group indices are used when looking
+   *                     up instance weights.
+   */
+  HostSketchContainer(std::vector<size_t> columns_size, int32_t max_bins,
+                      bool use_group);
+
+  static bool UseGroup(MetaInfo const &info) {
+    size_t const num_groups =
+        info.group_ptr_.size() == 0 ? 0 : info.group_ptr_.size() - 1;
+    // Use group index for weights?
+    bool const use_group_ind =
+        num_groups != 0 && (info.weights_.Size() != info.num_row_);
+    return use_group_ind;
+  }
+
+  static uint32_t SearchGroupIndFromRow(std::vector<bst_group_t> const &group_ptr,
+                                        size_t const base_rowid) {
+    CHECK_LT(base_rowid, group_ptr.back())
+        << "Row: " << base_rowid << " is not found in any group.";
+    bst_group_t group_ind =
+        std::upper_bound(group_ptr.cbegin(), group_ptr.cend() - 1, base_rowid) -
+        group_ptr.cbegin() - 1;
+    return group_ind;
+  }
+
+  /* \brief Push a CSR matrix. */
+  void PushRowPage(SparsePage const& page, MetaInfo const& info);
+
+  void MakeCuts(HistogramCuts* cuts);
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_QUANTILE_H_
diff --git a/src/common/threading_utils.h b/src/common/threading_utils.h
index 3f0519b054c6..2c6f5ba11e90 100644
--- a/src/common/threading_utils.h
+++ b/src/common/threading_utils.h
@@ -6,9 +6,9 @@
 #ifndef XGBOOST_COMMON_THREADING_UTILS_H_
 #define XGBOOST_COMMON_THREADING_UTILS_H_
 
+#include <dmlc/common.h>
 #include <vector>
 #include <algorithm>
-
 #include "xgboost/logging.h"
 
 namespace xgboost {
@@ -115,17 +115,32 @@ void ParallelFor2d(const BlockedSpace2d& space, int nthreads, Func func) {
   nthreads = std::min(nthreads, omp_get_max_threads());
   nthreads = std::max(nthreads, 1);
 
+  dmlc::OMPException omp_exc;
 #pragma omp parallel num_threads(nthreads)
   {
-    size_t tid = omp_get_thread_num();
-    size_t chunck_size = num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
+    omp_exc.Run([&]() {
+      size_t tid = omp_get_thread_num();
+      size_t chunk_size =
+          num_blocks_in_space / nthreads + !!(num_blocks_in_space % nthreads);
+
+      size_t begin = chunk_size * tid;
+      size_t end = std::min(begin + chunk_size, num_blocks_in_space);
+      for (auto i = begin; i < end; i++) {
+        func(space.GetFirstDimension(i), space.GetRange(i));
+      }
+    });
+  }
+  omp_exc.Rethrow();
+}
 
-    size_t begin = chunck_size * tid;
-    size_t end = std::min(begin + chunck_size, num_blocks_in_space);
-    for (auto i = begin; i < end; i++) {
-      func(space.GetFirstDimension(i), space.GetRange(i));
-    }
+template <typename Func>
+void ParallelFor(size_t size, size_t nthreads, Func fn) {
+  dmlc::OMPException omp_exc;
+#pragma omp parallel for num_threads(nthreads)
+  for (omp_ulong i = 0; i < size; ++i) {
+    omp_exc.Run(fn, i);
   }
+  omp_exc.Rethrow();
 }
 }  // namespace common
diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc
index 1c84f4947514..8d116999e62c 100644
--- a/src/predictor/cpu_predictor.cc
+++ b/src/predictor/cpu_predictor.cc
@@ -44,18 +44,16 @@ bst_float PredValue(const SparsePage::Inst &inst,
 
 template <size_t kUnrollLen>
 struct SparsePageView {
-  SparsePage const* page;
   bst_row_t base_rowid;
+  HostSparsePageView view;
   static size_t constexpr kUnroll = kUnrollLen;
 
   explicit SparsePageView(SparsePage const *p)
-      : page{p}, base_rowid{page->base_rowid} {
-    // Pull to host before entering omp block, as this is not thread safe.
- page->data.HostVector(); - page->offset.HostVector(); + : base_rowid{p->base_rowid} { + view = p->GetView(); } - SparsePage::Inst operator[](size_t i) { return (*page)[i]; } - size_t Size() const { return page->Size(); } + SparsePage::Inst operator[](size_t i) { return view[i]; } + size_t Size() const { return view.Size(); } }; template diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc index 7a0ff9a47215..0fad360f4298 100644 --- a/tests/cpp/common/test_hist_util.cc +++ b/tests/cpp/common/test_hist_util.cc @@ -158,86 +158,20 @@ TEST(CutsBuilder, SearchGroupInd) { HistogramCuts hmat; - size_t group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0); + size_t group_ind = HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 0); ASSERT_EQ(group_ind, 0); - group_ind = CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5); + group_ind = HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 5); ASSERT_EQ(group_ind, 2); + EXPECT_ANY_THROW(HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17)); + p_mat->Info().Validate(-1); - EXPECT_THROW(CutsBuilder::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17), + EXPECT_THROW(HostSketchContainer::SearchGroupIndFromRow(p_mat->Info().group_ptr_, 17), dmlc::Error); std::vector group_ptr {0, 1, 2}; - CHECK_EQ(CutsBuilder::SearchGroupIndFromRow(group_ptr, 1), 1); -} - -TEST(SparseCuts, SingleThreadedBuild) { - size_t constexpr kRows = 267; - size_t constexpr kCols = 31; - size_t constexpr kBins = 256; - - auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - - common::GHistIndexMatrix hmat; - hmat.Init(p_fmat.get(), kBins); - - HistogramCuts cuts; - SparseCuts indices(&cuts); - auto const& page = *(p_fmat->GetBatches().begin()); - indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0); - - ASSERT_EQ(hmat.cut.Ptrs().size(), cuts.Ptrs().size()); - ASSERT_EQ(hmat.cut.Ptrs(), cuts.Ptrs()); - ASSERT_EQ(hmat.cut.Values(), cuts.Values()); - ASSERT_EQ(hmat.cut.MinValues(), cuts.MinValues()); -} - -TEST(SparseCuts, MultiThreadedBuild) { - size_t constexpr kRows = 17; - size_t constexpr kCols = 15; - size_t constexpr kBins = 255; - - omp_ulong ori_nthreads = omp_get_max_threads(); - omp_set_num_threads(16); - - auto Compare = -#if defined(_MSC_VER) // msvc fails to capture - [kBins](DMatrix* p_fmat) { -#else - [](DMatrix* p_fmat) { -#endif - HistogramCuts threaded_container; - SparseCuts threaded_indices(&threaded_container); - threaded_indices.Build(p_fmat, kBins); - - HistogramCuts container; - SparseCuts indices(&container); - auto const& page = *(p_fmat->GetBatches().begin()); - indices.SingleThreadBuild(page, p_fmat->Info(), kBins, false, 0, page.Size(), 0); - - ASSERT_EQ(container.Ptrs().size(), threaded_container.Ptrs().size()); - ASSERT_EQ(container.Values().size(), threaded_container.Values().size()); - - for (uint32_t i = 0; i < container.Ptrs().size(); ++i) { - ASSERT_EQ(container.Ptrs()[i], threaded_container.Ptrs()[i]); - } - for (uint32_t i = 0; i < container.Values().size(); ++i) { - ASSERT_EQ(container.Values()[i], threaded_container.Values()[i]); - } - }; - - { - auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix(); - Compare(p_fmat.get()); - } - - { - auto p_fmat = RandomDataGenerator(kRows, kCols, 0.0001).GenerateDMatrix(); - Compare(p_fmat.get()); - } - - omp_set_num_threads(ori_nthreads); + CHECK_EQ(HostSketchContainer::SearchGroupIndFromRow(group_ptr, 1), 1); } TEST(HistUtil, 
DenseCutsCategorical) { @@ -250,9 +184,7 @@ TEST(HistUtil, DenseCutsCategorical) { std::vector x_sorted(x); std::sort(x_sorted.begin(), x_sorted.end()); auto dmat = GetDMatrixFromData(x, n, 1); - HistogramCuts cuts; - DenseCuts dense(&cuts); - dense.Build(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); auto cuts_from_sketch = cuts.Values(); EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); @@ -264,15 +196,14 @@ TEST(HistUtil, DenseCutsCategorical) { TEST(HistUtil, DenseCutsAccuracyTest) { int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; + int sizes[] = {100}; + // omp_set_num_threads(1); int num_columns = 5; for (auto num_rows : sizes) { auto x = GenerateRandom(num_rows, num_columns); auto dmat = GetDMatrixFromData(x, num_rows, num_columns); for (auto num_bins : bin_sizes) { - HistogramCuts cuts; - DenseCuts dense(&cuts); - dense.Build(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -288,9 +219,7 @@ TEST(HistUtil, DenseCutsAccuracyTestWeights) { auto w = GenerateRandomWeights(num_rows); dmat->Info().weights_.HostVector() = w; for (auto num_bins : bin_sizes) { - HistogramCuts cuts; - DenseCuts dense(&cuts); - dense.Build(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -306,65 +235,7 @@ TEST(HistUtil, DenseCutsExternalMemory) { auto dmat = GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir); for (auto num_bins : bin_sizes) { - HistogramCuts cuts; - DenseCuts dense(&cuts); - dense.Build(dmat.get(), num_bins); - ValidateCuts(cuts, dmat.get(), num_bins); - } - } -} - -TEST(HistUtil, SparseCutsAccuracyTest) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; - int num_columns = 5; - for (auto num_rows : sizes) { - auto x = GenerateRandom(num_rows, num_columns); - auto dmat = GetDMatrixFromData(x, num_rows, num_columns); - for (auto num_bins : bin_sizes) { - HistogramCuts cuts; - SparseCuts sparse(&cuts); - sparse.Build(dmat.get(), num_bins); - ValidateCuts(cuts, dmat.get(), num_bins); - } - } -} - -TEST(HistUtil, SparseCutsCategorical) { - int categorical_sizes[] = {2, 6, 8, 12}; - int num_bins = 256; - int sizes[] = {25, 100, 1000}; - for (auto n : sizes) { - for (auto num_categories : categorical_sizes) { - auto x = GenerateRandomCategoricalSingleColumn(n, num_categories); - std::vector x_sorted(x); - std::sort(x_sorted.begin(), x_sorted.end()); - auto dmat = GetDMatrixFromData(x, n, 1); - HistogramCuts cuts; - SparseCuts sparse(&cuts); - sparse.Build(dmat.get(), num_bins); - auto cuts_from_sketch = cuts.Values(); - EXPECT_LT(cuts.MinValues()[0], x_sorted.front()); - EXPECT_GT(cuts_from_sketch.front(), x_sorted.front()); - EXPECT_GE(cuts_from_sketch.back(), x_sorted.back()); - EXPECT_EQ(cuts_from_sketch.size(), num_categories); - } - } -} - -TEST(HistUtil, SparseCutsExternalMemory) { - int bin_sizes[] = {2, 16, 256, 512}; - int sizes[] = {100, 1000, 1500}; - int num_columns = 5; - for (auto num_rows : sizes) { - auto x = GenerateRandom(num_rows, num_columns); - dmlc::TemporaryDirectory tmpdir; - auto dmat = - GetExternalMemoryDMatrixFromData(x, num_rows, num_columns, 50, tmpdir); - for (auto num_bins : bin_sizes) { - HistogramCuts cuts; - SparseCuts dense(&cuts); - dense.Build(dmat.get(), num_bins); + HistogramCuts cuts = SketchOnDMatrix(dmat.get(), num_bins); 
ValidateCuts(cuts, dmat.get(), num_bins); } } @@ -391,25 +262,6 @@ TEST(HistUtil, IndexBinBound) { } } -TEST(HistUtil, SparseIndexBinBound) { - uint64_t bin_sizes[] = { static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 2 }; - BinTypeSize expected_bin_type_sizes[] = { kUint32BinsTypeSize, - kUint32BinsTypeSize, - kUint32BinsTypeSize }; - size_t constexpr kRows = 100; - size_t constexpr kCols = 10; - - size_t bin_id = 0; - for (auto max_bin : bin_sizes) { - auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatrix(); - common::GHistIndexMatrix hmat; - hmat.Init(p_fmat.get(), max_bin); - EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize()); - } -} - template void CheckIndexData(T* data_ptr, uint32_t* offsets, const common::GHistIndexMatrix& hmat, size_t n_cols) { @@ -448,25 +300,61 @@ TEST(HistUtil, IndexBinData) { } } -TEST(HistUtil, SparseIndexBinData) { - uint64_t bin_sizes[] = { static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 1, - static_cast(std::numeric_limits::max()) + 2 }; - size_t constexpr kRows = 100; - size_t constexpr kCols = 10; +void TestSketchFromWeights(bool with_group) { + size_t constexpr kRows = 300, kCols = 20, kBins = 256; + size_t constexpr kGroups = 10; + auto m = + RandomDataGenerator{kRows, kCols, 0}.Device(0).GenerateDMatrix(); + common::HistogramCuts cuts = SketchOnDMatrix(m.get(), kBins); + + MetaInfo info; + auto& h_weights = info.weights_.HostVector(); + if (with_group) { + h_weights.resize(kGroups); + } else { + h_weights.resize(kRows); + } + std::fill(h_weights.begin(), h_weights.end(), 1.0f); - for (auto max_bin : bin_sizes) { - auto p_fmat = RandomDataGenerator(kRows, kCols, 0.2).GenerateDMatrix(); - common::GHistIndexMatrix hmat; - hmat.Init(p_fmat.get(), max_bin); - EXPECT_EQ(hmat.index.Offset(), nullptr); + std::vector groups(kGroups); + if (with_group) { + for (size_t i = 0; i < kGroups; ++i) { + groups[i] = kRows / kGroups; + } + info.SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + } + + info.num_row_ = kRows; + info.num_col_ = kCols; + + // Assign weights. 
+ if (with_group) { + m->Info().SetInfo("group", groups.data(), DataType::kUInt32, kGroups); + } + + m->Info().SetInfo("weight", h_weights.data(), DataType::kFloat32, h_weights.size()); + m->Info().num_col_ = kCols; + m->Info().num_row_ = kRows; + ASSERT_EQ(cuts.Ptrs().size(), kCols + 1); + ValidateCuts(cuts, m.get(), kBins); - uint32_t* data_ptr = hmat.index.data(); - for (size_t i = 0; i < hmat.index.Size(); ++i) { - EXPECT_EQ(data_ptr[i], hmat.index[i]); + if (with_group) { + HistogramCuts non_weighted = SketchOnDMatrix(m.get(), kBins); + for (size_t i = 0; i < cuts.Values().size(); ++i) { + EXPECT_EQ(cuts.Values()[i], non_weighted.Values()[i]); + } + for (size_t i = 0; i < cuts.MinValues().size(); ++i) { + ASSERT_EQ(cuts.MinValues()[i], non_weighted.MinValues()[i]); + } + for (size_t i = 0; i < cuts.Ptrs().size(); ++i) { + ASSERT_EQ(cuts.Ptrs().at(i), non_weighted.Ptrs().at(i)); } } } +TEST(HistUtil, SketchFromWeights) { + TestSketchFromWeights(true); + TestSketchFromWeights(false); +} } // namespace common } // namespace xgboost diff --git a/tests/cpp/common/test_hist_util.cu b/tests/cpp/common/test_hist_util.cu index 83e9595f75f1..b225acb2039d 100644 --- a/tests/cpp/common/test_hist_util.cu +++ b/tests/cpp/common/test_hist_util.cu @@ -24,10 +24,8 @@ namespace common { template HistogramCuts GetHostCuts(AdapterT *adapter, int num_bins, float missing) { - HistogramCuts cuts; - DenseCuts builder(&cuts); data::SimpleDMatrix dmat(adapter, missing, 1); - builder.Build(&dmat, num_bins); + HistogramCuts cuts = SketchOnDMatrix(&dmat, num_bins); return cuts; } @@ -39,9 +37,7 @@ TEST(HistUtil, DeviceSketch) { auto dmat = GetDMatrixFromData(x, num_rows, num_columns); auto device_cuts = DeviceSketch(0, dmat.get(), num_bins); - HistogramCuts host_cuts; - DenseCuts builder(&host_cuts); - builder.Build(dmat.get(), num_bins); + HistogramCuts host_cuts = SketchOnDMatrix(dmat.get(), num_bins); EXPECT_EQ(device_cuts.Values(), host_cuts.Values()); EXPECT_EQ(device_cuts.Ptrs(), host_cuts.Ptrs()); @@ -460,7 +456,11 @@ void TestAdapterSketchFromWeights(bool with_group) { &storage); MetaInfo info; auto& h_weights = info.weights_.HostVector(); - h_weights.resize(kRows); + if (with_group) { + h_weights.resize(kGroups); + } else { + h_weights.resize(kRows); + } std::fill(h_weights.begin(), h_weights.end(), 1.0f); std::vector groups(kGroups); diff --git a/tests/cpp/common/test_quantile.cc b/tests/cpp/common/test_quantile.cc new file mode 100644 index 000000000000..c273658e54cb --- /dev/null +++ b/tests/cpp/common/test_quantile.cc @@ -0,0 +1,77 @@ +#include +#include "test_quantile.h" +#include "../../../src/common/quantile.h" +#include "../../../src/common/hist_util.h" + +namespace xgboost { +namespace common { +TEST(Quantile, SameOnAllWorkers) { + std::string msg{"Skipping Quantile AllreduceBasic test"}; + size_t constexpr kWorkers = 4; + InitRabitContext(msg, kWorkers); + auto world = rabit::GetWorldSize(); + if (world != 1) { + CHECK_EQ(world, kWorkers); + } else { + return; + } + + constexpr size_t kRows = 1000, kCols = 100; + RunWithSeedsAndBins( + kRows, [=](int32_t seed, size_t n_bins, MetaInfo const &info) { + auto rank = rabit::GetRank(); + HostDeviceVector storage; + auto m = RandomDataGenerator{kRows, kCols, 0} + .Device(0) + .Seed(rank + seed) + .GenerateDMatrix(); + auto cuts = SketchOnDMatrix(m.get(), n_bins); + std::vector cut_values(cuts.Values().size() * world, 0); + std::vector< + typename std::remove_reference_t::value_type> + cut_ptrs(cuts.Ptrs().size() * world, 0); + std::vector 
cut_min_values(cuts.MinValues().size() * world, 0); + + size_t value_size = cuts.Values().size(); + rabit::Allreduce(&value_size, 1); + size_t ptr_size = cuts.Ptrs().size(); + rabit::Allreduce(&ptr_size, 1); + CHECK_EQ(ptr_size, kCols + 1); + size_t min_value_size = cuts.MinValues().size(); + rabit::Allreduce(&min_value_size, 1); + CHECK_EQ(min_value_size, kCols); + + size_t value_offset = value_size * rank; + std::copy(cuts.Values().begin(), cuts.Values().end(), + cut_values.begin() + value_offset); + size_t ptr_offset = ptr_size * rank; + std::copy(cuts.Ptrs().cbegin(), cuts.Ptrs().cend(), + cut_ptrs.begin() + ptr_offset); + size_t min_values_offset = min_value_size * rank; + std::copy(cuts.MinValues().cbegin(), cuts.MinValues().cend(), + cut_min_values.begin() + min_values_offset); + + rabit::Allreduce(cut_values.data(), cut_values.size()); + rabit::Allreduce(cut_ptrs.data(), cut_ptrs.size()); + rabit::Allreduce(cut_min_values.data(), cut_min_values.size()); + + for (int32_t i = 0; i < world; i++) { + for (size_t j = 0; j < value_size; ++j) { + size_t idx = i * value_size + j; + ASSERT_NEAR(cuts.Values().at(j), cut_values.at(idx), kRtEps); + } + + for (size_t j = 0; j < ptr_size; ++j) { + size_t idx = i * ptr_size + j; + ASSERT_EQ(cuts.Ptrs().at(j), cut_ptrs.at(idx)); + } + + for (size_t j = 0; j < min_value_size; ++j) { + size_t idx = i * min_value_size + j; + ASSERT_EQ(cuts.MinValues().at(j), cut_min_values.at(idx)); + } + } + }); +} +} // namespace common +} // namespace xgboost diff --git a/tests/cpp/common/test_quantile.cu b/tests/cpp/common/test_quantile.cu index e16edfef06b7..f7c7e22e3650 100644 --- a/tests/cpp/common/test_quantile.cu +++ b/tests/cpp/common/test_quantile.cu @@ -1,4 +1,5 @@ #include +#include "test_quantile.h" #include "../helpers.h" #include "../../../src/common/hist_util.cuh" #include "../../../src/common/quantile.cuh" @@ -16,32 +17,6 @@ TEST(GPUQuantile, Basic) { ASSERT_EQ(sketch.Data().size(), 0); } -template void RunWithSeedsAndBins(size_t rows, Fn fn) { - std::vector seeds(4); - SimpleLCG lcg; - SimpleRealUniformDistribution dist(3, 1000); - std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); }); - - std::vector bins(8); - for (size_t i = 0; i < bins.size() - 1; ++i) { - bins[i] = i * 35 + 2; - } - bins.back() = rows + 80; // provide a bin number greater than rows. - - std::vector infos(2); - auto& h_weights = infos.front().weights_.HostVector(); - h_weights.resize(rows); - std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); }); - - for (auto seed : seeds) { - for (auto n_bin : bins) { - for (auto const& info : infos) { - fn(seed, n_bin, info); - } - } - } -} - void TestSketchUnique(float sparsity) { constexpr size_t kRows = 1000, kCols = 100; RunWithSeedsAndBins(kRows, [kRows, kCols, sparsity](int32_t seed, size_t n_bins, MetaInfo const& info) { @@ -297,31 +272,12 @@ TEST(GPUQuantile, MergeDuplicated) { } } -void InitRabitContext(std::string msg) { - auto n_gpus = AllVisibleGPUs(); - auto port = std::getenv("DMLC_TRACKER_PORT"); - std::string port_str; - if (port) { - port_str = port; - } else { - LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up."; - return; - } - - std::vector envs{ - "DMLC_TRACKER_PORT=" + port_str, - "DMLC_TRACKER_URI=127.0.0.1", - "DMLC_NUM_WORKER=" + std::to_string(n_gpus)}; - char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])}; - rabit::Init(3, c_envs); -} - TEST(GPUQuantile, AllReduceBasic) { // This test is supposed to run by a python test that setups the environment. 
std::string msg {"Skipping AllReduce test"}; #if defined(__linux__) && defined(XGBOOST_USE_NCCL) - InitRabitContext(msg); auto n_gpus = AllVisibleGPUs(); + InitRabitContext(msg, n_gpus); auto world = rabit::GetWorldSize(); if (world != 1) { ASSERT_EQ(world, n_gpus); @@ -407,9 +363,9 @@ TEST(GPUQuantile, AllReduceBasic) { TEST(GPUQuantile, SameOnAllWorkers) { std::string msg {"Skipping SameOnAllWorkers test"}; #if defined(__linux__) && defined(XGBOOST_USE_NCCL) - InitRabitContext(msg); - auto world = rabit::GetWorldSize(); auto n_gpus = AllVisibleGPUs(); + InitRabitContext(msg, n_gpus); + auto world = rabit::GetWorldSize(); if (world != 1) { ASSERT_EQ(world, n_gpus); } else { diff --git a/tests/cpp/common/test_quantile.h b/tests/cpp/common/test_quantile.h new file mode 100644 index 000000000000..7dea0b17deb3 --- /dev/null +++ b/tests/cpp/common/test_quantile.h @@ -0,0 +1,54 @@ +#include +#include +#include +#include + +#include "../helpers.h" + +namespace xgboost { +namespace common { +inline void InitRabitContext(std::string msg, size_t n_workers) { + auto port = std::getenv("DMLC_TRACKER_PORT"); + std::string port_str; + if (port) { + port_str = port; + } else { + LOG(WARNING) << msg << " as `DMLC_TRACKER_PORT` is not set up."; + return; + } + + std::vector envs{ + "DMLC_TRACKER_PORT=" + port_str, + "DMLC_TRACKER_URI=127.0.0.1", + "DMLC_NUM_WORKER=" + std::to_string(n_workers)}; + char* c_envs[] {&(envs[0][0]), &(envs[1][0]), &(envs[2][0])}; + rabit::Init(3, c_envs); +} + +template void RunWithSeedsAndBins(size_t rows, Fn fn) { + std::vector seeds(4); + SimpleLCG lcg; + SimpleRealUniformDistribution dist(3, 1000); + std::generate(seeds.begin(), seeds.end(), [&](){ return dist(&lcg); }); + + std::vector bins(8); + for (size_t i = 0; i < bins.size() - 1; ++i) { + bins[i] = i * 35 + 2; + } + bins.back() = rows + 80; // provide a bin number greater than rows. 
+ + std::vector infos(2); + auto& h_weights = infos.front().weights_.HostVector(); + h_weights.resize(rows); + std::generate(h_weights.begin(), h_weights.end(), [&]() { return dist(&lcg); }); + + for (auto seed : seeds) { + for (auto n_bin : bins) { + for (auto const& info : infos) { + fn(seed, n_bin, info); + } + } + } +} +} // namespace common +} // namespace xgboost diff --git a/tests/python-gpu/test_gpu_with_dask.py b/tests/python-gpu/test_gpu_with_dask.py index f118c7188528..e5c90181373f 100644 --- a/tests/python-gpu/test_gpu_with_dask.py +++ b/tests/python-gpu/test_gpu_with_dask.py @@ -233,12 +233,14 @@ def runit(worker_addr, rabit_args): assert ret.returncode == 0, msg @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu @pytest.mark.gtest def test_quantile_basic(self): self.run_quantile('AllReduceBasic') @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.skipif(**tm.no_dask_cuda()) @pytest.mark.mgpu @pytest.mark.gtest def test_quantile_same_on_all_workers(self): diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index b4be33ed348f..dc5c155e6027 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1,11 +1,16 @@ import testing as tm import pytest +import unittest import xgboost as xgb import sys import numpy as np import json import asyncio from sklearn.datasets import make_classification +import os +import subprocess +from hypothesis import given, strategies, settings, note +from test_updaters import hist_parameter_strategy, exact_parameter_strategy if sys.platform.startswith("win"): pytest.skip("Skipping dask tests on Windows", allow_module_level=True) @@ -14,12 +19,16 @@ try: from distributed import LocalCluster, Client + from distributed.utils_test import client, loop, cluster_fixture import dask.dataframe as dd import dask.array as da from xgboost.dask import DaskDMatrix except ImportError: LocalCluster = None Client = None + client = None + loop = None + cluster_fixture = None dd = None da = None DaskDMatrix = None @@ -461,3 +470,97 @@ def test_with_asyncio(): asyncio.run(run_dask_regressor_asyncio(address)) asyncio.run(run_dask_classifier_asyncio(address)) + + +class TestWithDask: + def run_updater_test(self, client, params, num_rounds, dataset, + tree_method): + params['tree_method'] = tree_method + params = dataset.set_params(params) + # multi class doesn't handle empty dataset well (empty + # means at least 1 worker has data). + if params['objective'] == "multi:softmax": + return + # It doesn't make sense to distribute a completely + # empty dataset. 
+ if dataset.X.shape[0] == 0: + return + + chunk = 128 + X = da.from_array(dataset.X, + chunks=(chunk, dataset.X.shape[1])) + y = da.from_array(dataset.y, chunks=(chunk, )) + if dataset.w is not None: + w = da.from_array(dataset.w, chunks=(chunk, )) + else: + w = None + + m = xgb.dask.DaskDMatrix( + client, data=X, label=y, weight=w) + history = xgb.dask.train(client, params=params, dtrain=m, + num_boost_round=num_rounds, + evals=[(m, 'train')])['history'] + note(history) + assert tm.non_increasing(history['train'][dataset.metric]) + + @given(params=hist_parameter_strategy, + num_rounds=strategies.integers(10, 20), + dataset=tm.dataset_strategy) + @settings(deadline=None) + def test_hist(self, params, num_rounds, dataset, client): + self.run_updater_test(client, params, num_rounds, dataset, 'hist') + + @given(params=exact_parameter_strategy, + num_rounds=strategies.integers(10, 20), + dataset=tm.dataset_strategy) + @settings(deadline=None) + def test_approx(self, client, params, num_rounds, dataset): + self.run_updater_test(client, params, num_rounds, dataset, 'approx') + + def run_quantile(self, name): + if sys.platform.startswith("win"): + pytest.skip("Skipping dask tests on Windows") + + exe = None + for possible_path in {'./testxgboost', './build/testxgboost', + '../build/testxgboost', + '../cpu-build/testxgboost', + '../gpu-build/testxgboost'}: + if os.path.exists(possible_path): + exe = possible_path + if exe is None: + return + + test = "--gtest_filter=Quantile." + name + + def runit(worker_addr, rabit_args): + port = None + # setup environment for running the c++ part. + for arg in rabit_args: + if arg.decode('utf-8').startswith('DMLC_TRACKER_PORT'): + port = arg.decode('utf-8') + port = port.split('=') + env = os.environ.copy() + env[port[0]] = port[1] + return subprocess.run([exe, test], env=env, stdout=subprocess.PIPE) + + with LocalCluster(n_workers=4) as cluster: + with Client(cluster) as client: + workers = list(xgb.dask._get_client_workers(client).keys()) + rabit_args = client.sync( + xgb.dask._get_rabit_args, workers, client) + futures = client.map(runit, + workers, + pure=False, + workers=workers, + rabit_args=rabit_args) + results = client.gather(futures) + for ret in results: + msg = ret.stdout.decode('utf-8') + assert msg.find('1 test from Quantile') != -1, msg + assert ret.returncode == 0, msg + + @pytest.mark.skipif(**tm.no_dask()) + @pytest.mark.gtest + def test_quantile_basic(self): + self.run_quantile('SameOnAllWorkers')