From 4fcfd9c96e00670c06f0338f8da443052861a994 Mon Sep 17 00:00:00 2001
From: Jiaming Yuan
Date: Mon, 16 May 2022 21:11:50 +0800
Subject: [PATCH] Fix and cleanup for column matrix. (#7901)

* Fix missed type dispatching for dense columns with missing values.
* Code cleanup to reduce special cases.
* Reduce memory usage.
---
 src/common/column_matrix.h                  | 216 +++++++++-----------
 src/common/hist_util.h                      |  24 ++-
 src/common/partition_builder.h              |   2 +-
 src/objective/adaptive.cc                   |   2 +-
 src/tree/updater_approx.cc                  |   2 +-
 src/tree/updater_quantile_hist.cc           |   4 +-
 src/tree/updater_quantile_hist.h            |   2 +-
 tests/cpp/common/test_column_matrix.cc      |   7 +-
 tests/cpp/tree/hist/test_evaluate_splits.cc |   1 -
 tests/cpp/tree/test_quantile_hist.cc        |   2 +-
 10 files changed, 125 insertions(+), 137 deletions(-)

diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 01dfe548b181..77c67620b444 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -125,16 +125,20 @@ class DenseColumnIter : public Column<BinIdxT> {
   }
 };
 
-/*! \brief a collection of columns, with support for construction from
-    GHistIndexMatrix. */
+/**
+ * \brief Column major matrix for gradient index. This matrix contains both dense and
+ * sparse columns; the type of each column is controlled by the sparse threshold. When
+ * the number of missing values in a column is below the threshold, it is classified as
+ * a dense column.
+ */
 class ColumnMatrix {
  public:
   // get number of features
   bst_feature_t GetNumFeature() const { return static_cast<bst_feature_t>(type_.size()); }
 
   // construct column matrix from GHistIndexMatrix
-  inline void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
-                   int32_t n_threads) {
+  void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold,
+            int32_t n_threads) {
     auto const nfeature = static_cast<bst_feature_t>(gmat.cut.Ptrs().size() - 1);
     const size_t nrow = gmat.row_ptr.size() - 1;
     // identify type of each column
@@ -145,13 +149,14 @@ class ColumnMatrix {
     for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
       CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val);
     }
-    bool all_dense = gmat.IsDense();
+
+    bool all_dense_column = true;
     gmat.GetFeatureCounts(&feature_counts_[0]);
     // classify features
     for (bst_feature_t fid = 0; fid < nfeature; ++fid) {
       if (static_cast<double>(feature_counts_[fid]) < sparse_threshold * nrow) {
         type_[fid] = kSparseColumn;
-        all_dense = false;
+        all_dense_column = false;
       } else {
         type_[fid] = kDenseColumn;
       }
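The classification rule in the hunk above is worth making concrete: `feature_counts_[fid]` counts the non-missing entries of feature `fid`, so a column stays dense exactly when at least `sparse_threshold * nrow` of its rows are present. A minimal standalone sketch of the rule, with illustrative names that are not part of the patch:

```cpp
#include <cstddef>
#include <vector>

enum ColumnType { kDenseColumn, kSparseColumn };

// Sketch of the classification in ColumnMatrix::Init: a feature stays dense only
// if its non-missing count reaches sparse_threshold * n_rows; otherwise it is
// stored sparsely with an explicit row index.
std::vector<ColumnType> ClassifyColumns(std::vector<std::size_t> const& feature_counts,
                                        std::size_t n_rows, double sparse_threshold) {
  std::vector<ColumnType> types(feature_counts.size());
  for (std::size_t fid = 0; fid < feature_counts.size(); ++fid) {
    bool dense = static_cast<double>(feature_counts[fid]) >= sparse_threshold * n_rows;
    types[fid] = dense ? kDenseColumn : kSparseColumn;
  }
  return types;
}
```

With the `sparse_threshold` of 0.2 used in the tests below, a feature present in fewer than 20% of rows becomes a sparse column.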
@@ -160,70 +165,51 @@ class ColumnMatrix {
     // want to compute storage boundary for each feature
     // using variants of prefix sum scan
     feature_offsets_.resize(nfeature + 1);
-    size_t accum_index_ = 0;
-    feature_offsets_[0] = accum_index_;
+    size_t accum_index = 0;
+    feature_offsets_[0] = accum_index;
     for (bst_feature_t fid = 1; fid < nfeature + 1; ++fid) {
       if (type_[fid - 1] == kDenseColumn) {
-        accum_index_ += static_cast<size_t>(nrow);
+        accum_index += static_cast<size_t>(nrow);
       } else {
-        accum_index_ += feature_counts_[fid - 1];
+        accum_index += feature_counts_[fid - 1];
       }
-      feature_offsets_[fid] = accum_index_;
+      feature_offsets_[fid] = accum_index;
     }
 
     SetTypeSize(gmat.max_num_bins);
-
-    index_.resize(feature_offsets_[nfeature] * bins_type_size_, 0);
-    if (!all_dense) {
+    auto storage_size =
+        feature_offsets_.back() * static_cast<std::underlying_type_t<BinTypeSize>>(bins_type_size_);
+    index_.resize(storage_size, 0);
+    if (!all_dense_column) {
       row_ind_.resize(feature_offsets_[nfeature]);
     }
 
     // store least bin id for each feature
     index_base_ = const_cast<uint32_t*>(gmat.cut.Ptrs().data());
 
-    const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature);
-    any_missing_ = !noMissingValues;
+    any_missing_ = !gmat.IsDense();
 
     missing_flags_.clear();
-    if (noMissingValues) {
+    // pre-fill index_ for dense columns
+    BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
+    if (!any_missing_) {
       missing_flags_.resize(feature_offsets_[nfeature], false);
+      // the row index is compressed, so we need to dispatch on its type.
+      DispatchBinType(gmat_bin_size, [&, nrow, nfeature, n_threads](auto t) {
+        using RowBinIdxT = decltype(t);
+        SetIndexNoMissing(page, gmat.index.data<RowBinIdxT>(), nrow, nfeature, n_threads);
+      });
     } else {
       missing_flags_.resize(feature_offsets_[nfeature], true);
-    }
-
-    // pre-fill index_ for dense columns
-    if (all_dense) {
-      BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize();
-      if (gmat_bin_size == kUint8BinsTypeSize) {
-        SetIndexAllDense(page, gmat.index.data<uint8_t>(), gmat, nrow, nfeature, noMissingValues,
-                         n_threads);
-      } else if (gmat_bin_size == kUint16BinsTypeSize) {
-        SetIndexAllDense(page, gmat.index.data<uint16_t>(), gmat, nrow, nfeature, noMissingValues,
-                         n_threads);
-      } else {
-        CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize);
-        SetIndexAllDense(page, gmat.index.data<uint32_t>(), gmat, nrow, nfeature, noMissingValues,
-                         n_threads);
-      }
-      /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize
-         but for ColumnMatrix we still have a chance to reduce the memory consumption */
-    } else {
-      if (bins_type_size_ == kUint8BinsTypeSize) {
-        SetIndex<uint8_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
-      } else if (bins_type_size_ == kUint16BinsTypeSize) {
-        SetIndex<uint16_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
-      } else {
-        CHECK_EQ(bins_type_size_, kUint32BinsTypeSize);
-        SetIndex<uint32_t>(page, gmat.index.data<uint32_t>(), gmat, nfeature);
-      }
+      SetIndexMixedColumns(page, gmat.index.data<uint32_t>(), gmat, nfeature);
     }
   }
 
   /* Set the number of bytes based on numeric limit of maximum number of bins provided by user */
-  void SetTypeSize(size_t max_num_bins) {
-    if ((max_num_bins - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
+  void SetTypeSize(size_t max_bin_per_feat) {
+    if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint8_t>::max())) {
       bins_type_size_ = kUint8BinsTypeSize;
-    } else if ((max_num_bins - 1) <= static_cast<int>(std::numeric_limits<uint16_t>::max())) {
+    } else if ((max_bin_per_feat - 1) <= static_cast<int>(std::numeric_limits<uint16_t>::max())) {
       bins_type_size_ = kUint16BinsTypeSize;
     } else {
       bins_type_size_ = kUint32BinsTypeSize;
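To make the bin-width selection concrete: `max_bin_per_feat = 256` gives a largest bin index of 255, which fits in one byte; 65536 still fits in two; anything larger needs four. A self-contained sketch of the same rule (a free function written for illustration, not the class member):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Mirrors ColumnMatrix::SetTypeSize: pick the narrowest unsigned type that can
// hold max_bin_per_feat - 1, the largest bin index that can occur.
inline std::size_t BinIndexBytes(std::size_t max_bin_per_feat) {
  if (max_bin_per_feat - 1 <= static_cast<std::size_t>(std::numeric_limits<std::uint8_t>::max())) {
    return sizeof(std::uint8_t);
  }
  if (max_bin_per_feat - 1 <= static_cast<std::size_t>(std::numeric_limits<std::uint16_t>::max())) {
    return sizeof(std::uint16_t);
  }
  return sizeof(std::uint32_t);
}

int main() {
  assert(BinIndexBytes(256) == 1);    // bin indices 0..255
  assert(BinIndexBytes(65536) == 2);  // bin indices 0..65535
  assert(BinIndexBytes(65537) == 4);  // 65536 needs 32 bits
  return 0;
}
```

The three array entries in the test update below (255 + 1, 65535 + 1, 65535 + 2) walk exactly these boundaries.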
@@ -252,98 +238,78 @@ class ColumnMatrix {
         bin_index, static_cast<bst_bin_t>(index_base_[fidx]), missing_flags_, feature_offset});
   }
 
-  template <typename T>
-  inline void SetIndexAllDense(SparsePage const& page, T const* index, const GHistIndexMatrix& gmat,
-                               const size_t nrow, const size_t nfeature, const bool noMissingValues,
-                               int32_t n_threads) {
-    T* local_index = reinterpret_cast<T*>(&index_[0]);
-
-    /* missing values make sense only for column with type kDenseColumn,
-       and if no missing values were observed it could be handled much faster. */
-    if (noMissingValues) {
-      ParallelFor(nrow, n_threads, [&](auto rid) {
-        const size_t ibegin = rid * nfeature;
-        const size_t iend = (rid + 1) * nfeature;
+  // all columns are dense and have no missing values
+  // FIXME(jiamingy): We don't need a column matrix if there's no missing value.
+  template <typename RowBinIdxT>
+  void SetIndexNoMissing(SparsePage const& page, RowBinIdxT const* row_index,
+                         const size_t n_samples, const size_t n_features, int32_t n_threads) {
+    DispatchBinType(bins_type_size_, [&](auto t) {
+      using ColumnBinT = decltype(t);
+      auto column_index = Span<ColumnBinT>{reinterpret_cast<ColumnBinT*>(index_.data()),
+                                           index_.size() / sizeof(ColumnBinT)};
+      ParallelFor(n_samples, n_threads, [&](auto rid) {
+        const size_t ibegin = rid * n_features;
+        const size_t iend = (rid + 1) * n_features;
         size_t j = 0;
         for (size_t i = ibegin; i < iend; ++i, ++j) {
           const size_t idx = feature_offsets_[j];
-          local_index[idx + rid] = index[i];
+          // No need to add an offset, as the row index is compressed and stores the local index.
+          column_index[idx + rid] = row_index[i];
         }
       });
-    } else {
-      /* to handle rows in all batches, sum of all batch sizes equal to gmat.row_ptr.size() - 1 */
-      auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
-        // T* begin = &local_index[feature_offsets_[fid]];
-        const size_t idx = feature_offsets_[fid];
-        /* rbegin allows to store indexes from specific SparsePage batch */
-        local_index[idx + rid] = bin_id;
-
-        missing_flags_[idx + rid] = false;
-      };
-      this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
-    }
-  }
-
-  // FIXME(jiamingy): In the future we might want to simply use binary search to simplify
-  // this and remove the dependency on SparsePage. This way we can have quantilized
-  // matrix for host similar to `DeviceQuantileDMatrix`.
-  template <typename T, typename BinFn>
-  void SetIndexSparse(SparsePage const& batch, T* index, const GHistIndexMatrix& gmat,
-                      const size_t nfeature, BinFn&& assign_bin) {
-    std::vector<size_t> num_nonzeros(nfeature, 0ul);
-    const xgboost::Entry* data_ptr = batch.data.HostVector().data();
-    const std::vector<bst_row_t>& offset_vec = batch.offset.HostVector();
-    auto rbegin = 0;
-    const size_t batch_size = gmat.Size();
-    CHECK_LT(batch_size, offset_vec.size());
-
-    for (size_t rid = 0; rid < batch_size; ++rid) {
-      const size_t ibegin = gmat.row_ptr[rbegin + rid];
-      const size_t iend = gmat.row_ptr[rbegin + rid + 1];
-      const size_t size = offset_vec[rid + 1] - offset_vec[rid];
-      SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
-
-      CHECK_EQ(ibegin + inst.size(), iend);
-      size_t j = 0;
-      for (size_t i = ibegin; i < iend; ++i, ++j) {
-        const uint32_t bin_id = index[i];
-        auto fid = inst[j].index;
-        assign_bin(bin_id, rid, fid);
-      }
-    }
+    });
   }
 
-  template <typename T>
-  inline void SetIndex(SparsePage const& page, uint32_t const* index, const GHistIndexMatrix& gmat,
-                       const size_t nfeature) {
-    T* local_index = reinterpret_cast<T*>(&index_[0]);
+  /**
+   * \brief Set column index for both dense and sparse columns
+   */
+  void SetIndexMixedColumns(SparsePage const& page, uint32_t const* row_index,
+                            const GHistIndexMatrix& gmat, size_t n_features) {
     std::vector<size_t> num_nonzeros;
-    num_nonzeros.resize(nfeature);
-    std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0);
-
-    auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
-      if (type_[fid] == kDenseColumn) {
-        T* begin = &local_index[feature_offsets_[fid]];
-        begin[rid] = bin_id - index_base_[fid];
-        missing_flags_[feature_offsets_[fid] + rid] = false;
-      } else {
-        T* begin = &local_index[feature_offsets_[fid]];
-        begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
-        row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid;
-        ++num_nonzeros[fid];
+    num_nonzeros.resize(n_features, 0);
+
+    DispatchBinType(bins_type_size_, [&](auto t) {
+      using ColumnBinT = decltype(t);
+      ColumnBinT* local_index = reinterpret_cast<ColumnBinT*>(index_.data());
+
+      auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) {
+        if (type_[fid] == kDenseColumn) {
+          ColumnBinT* begin = &local_index[feature_offsets_[fid]];
+          begin[rid] = bin_id - index_base_[fid];
+          // not thread-safe with std::vector<bool>.
+          missing_flags_[feature_offsets_[fid] + rid] = false;
+        } else {
+          ColumnBinT* begin = &local_index[feature_offsets_[fid]];
+          begin[num_nonzeros[fid]] = bin_id - index_base_[fid];
+          row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid;
+          ++num_nonzeros[fid];
+        }
+      };
+
+      const xgboost::Entry* data_ptr = page.data.HostVector().data();
+      const std::vector<bst_row_t>& offset_vec = page.offset.HostVector();
+      const size_t batch_size = gmat.Size();
+      CHECK_LT(batch_size, offset_vec.size());
+      for (size_t rid = 0; rid < batch_size; ++rid) {
+        const size_t ibegin = gmat.row_ptr[rid];
+        const size_t iend = gmat.row_ptr[rid + 1];
+        const size_t size = offset_vec[rid + 1] - offset_vec[rid];
+        SparsePage::Inst inst = {data_ptr + offset_vec[rid], size};
+
+        CHECK_EQ(ibegin + inst.size(), iend);
+        size_t j = 0;
+        for (size_t i = ibegin; i < iend; ++i, ++j) {
+          const uint32_t bin_id = row_index[i];
+          auto fid = inst[j].index;
+          get_bin_idx(bin_id, rid, fid);
+        }
       }
-    };
-    this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx);
+    });
   }
 
   BinTypeSize GetTypeSize() const { return bins_type_size_; }
   auto GetColumnType(bst_feature_t fidx) const { return type_[fidx]; }
 
-  // This is just an utility function
-  bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) {
-    return n_elements == n_features * n_row;
-  }
-
-  // And this returns part of state
   bool AnyMissing() const { return any_missing_; }
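The dense fill in `SetIndexNoMissing` above is essentially a typed transpose: the row-major compressed gradient index is copied into column-major storage, with `feature_offsets_[j]` locating the slab of column `j` and the row id serving as the in-column position. A simplified single-threaded sketch of that layout transform (standalone, with hypothetical names):

```cpp
#include <cstddef>
#include <vector>

// Transpose a row-major bin-index matrix (n_samples x n_features) into
// column-major storage where column j occupies storage[offsets[j] .. offsets[j + 1]).
template <typename RowBinIdxT, typename ColumnBinT>
void FillColumnMajor(std::vector<RowBinIdxT> const& row_major,
                     std::vector<std::size_t> const& offsets,  // size n_features + 1
                     std::size_t n_samples, std::size_t n_features,
                     std::vector<ColumnBinT>* storage) {
  storage->resize(offsets.back());
  for (std::size_t rid = 0; rid < n_samples; ++rid) {
    for (std::size_t j = 0; j < n_features; ++j) {
      // Every column is dense here, so the in-column position is just rid.
      (*storage)[offsets[j] + rid] = static_cast<ColumnBinT>(row_major[rid * n_features + j]);
    }
  }
}
```

The real member function additionally dispatches the destination element width (`ColumnBinT`) at runtime via `DispatchBinType`, defined in the hist_util.h hunk below.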
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index c203a0eb4357..66188bec6d1d 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -113,7 +113,7 @@ class HistogramCuts {
     auto end = ptrs[column_id + 1];
     auto beg = ptrs[column_id];
     auto it = std::upper_bound(values.cbegin() + beg, values.cbegin() + end, value);
-    bst_bin_t idx = it - values.cbegin();
+    auto idx = it - values.cbegin();
     idx -= !!(idx == end);
     return idx;
   }
@@ -189,12 +189,30 @@ inline HistogramCuts SketchOnDMatrix(DMatrix* m, int32_t max_bins, int32_t n_threads,
   return out;
 }
 
-enum BinTypeSize : uint32_t {
-  kUint8BinsTypeSize = 1,
+enum BinTypeSize : uint8_t {
+  kUint8BinsTypeSize = 1,
   kUint16BinsTypeSize = 2,
   kUint32BinsTypeSize = 4
 };
 
+/**
+ * \brief Dispatch for bin type, fn is a function that accepts a scalar of the bin type.
+ */
+template <typename Fn>
+auto DispatchBinType(BinTypeSize type, Fn&& fn) {
+  switch (type) {
+    case kUint8BinsTypeSize: {
+      return fn(uint8_t{});
+    }
+    case kUint16BinsTypeSize: {
+      return fn(uint16_t{});
+    }
+    case kUint32BinsTypeSize: {
+      return fn(uint32_t{});
+    }
+  }
+}
+
 /**
  * \brief Optionally compressed gradient index. The compression works only with dense
  * data.
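`DispatchBinType` is what lets the column matrix drop its three-way `if`/`else` chains: a runtime `BinTypeSize` tag is turned into a compile-time integer type by passing a zero-valued scalar of that type into a generic lambda, which recovers it with `decltype`. A self-contained usage sketch (a re-implementation for illustration, not the header itself):

```cpp
#include <cstdint>
#include <iostream>

enum BinTypeSize : std::uint8_t {
  kUint8BinsTypeSize = 1,
  kUint16BinsTypeSize = 2,
  kUint32BinsTypeSize = 4
};

// Same shape as the DispatchBinType added above: the switch maps a runtime tag
// to a compile-time type through the scalar it passes to fn.
template <typename Fn>
auto DispatchBinType(BinTypeSize type, Fn&& fn) {
  switch (type) {
    case kUint8BinsTypeSize:
      return fn(std::uint8_t{});
    case kUint16BinsTypeSize:
      return fn(std::uint16_t{});
    default:
      return fn(std::uint32_t{});
  }
}

int main() {
  // The generic lambda is instantiated once per bin type; decltype(t) recovers the type.
  auto bytes = DispatchBinType(kUint16BinsTypeSize, [](auto t) {
    using T = decltype(t);
    return sizeof(T);
  });
  std::cout << bytes << "\n";  // prints 2
}
```

Because every `case` returns, all instantiations of the lambda must agree on the return type, which is why the pattern composes cleanly with `auto` return-type deduction.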
diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h
index 7d065e1c4ff0..29d1de9be437 100644
--- a/src/common/partition_builder.h
+++ b/src/common/partition_builder.h
@@ -108,7 +108,7 @@ class PartitionBuilder {
 
   template <typename BinIdxType, bool any_missing, bool any_cat>
   void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range,
-                 const int32_t split_cond, GHistIndexMatrix const& gmat,
+                 const bst_bin_t split_cond, GHistIndexMatrix const& gmat,
                  const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) {
     common::Span<const size_t> rid_span(rid + range.begin(), rid + range.end());
     common::Span<size_t> left = GetLeftBuffer(node_in_set, range.begin(), range.end());
diff --git a/src/objective/adaptive.cc b/src/objective/adaptive.cc
index 43dc36600013..6ddf39849949 100644
--- a/src/objective/adaptive.cc
+++ b/src/objective/adaptive.cc
@@ -28,7 +28,7 @@ void EncodeTreeLeafHost(RegTree const& tree, std::vector<bst_node_t> const& position,
     sorted_pos[i] = position[ridx[i]];
   }
   // find the first non-sampled row
-  auto begin_pos =
+  size_t begin_pos =
       std::distance(sorted_pos.cbegin(), std::find_if(sorted_pos.cbegin(), sorted_pos.cend(),
                                                       [](bst_node_t nidx) { return nidx >= 0; }));
   CHECK_LE(begin_pos, sorted_pos.size());
diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc
index b507b5220f2a..88a2cfadf256 100644
--- a/src/tree/updater_approx.cc
+++ b/src/tree/updater_approx.cc
@@ -264,7 +264,7 @@ class GlobalApproxUpdater : public TreeUpdater {
 
  public:
   explicit GlobalApproxUpdater(GenericParameter const *ctx, ObjInfo task)
-      : task_{task}, TreeUpdater(ctx) {
+      : TreeUpdater(ctx), task_{task} {
     monitor_.Init(__func__);
   }
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index a4a8ace83b5c..af7dad37fe39 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -355,11 +355,11 @@ void HistRowPartitioner::FindSplitConditions(const std::vector<CPUExpandEntry>& nodes,
     const bst_float split_pt = tree[nid].SplitCond();
     const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
     const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
-    int32_t split_cond = -1;
+    bst_bin_t split_cond = -1;
     // convert floating-point split_pt into corresponding bin_id
     // split_cond = -1 indicates that split_pt is less than all known cut points
     CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
-    for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
+    for (auto bound = lower_bound; bound < upper_bound; ++bound) {
       if (split_pt == gmat.cut.Values()[bound]) {
         split_cond = static_cast<int32_t>(bound);
       }
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index d7c2b4dec3ef..f50e6ab8a0bb 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -324,7 +324,7 @@ class QuantileHistMaker: public TreeUpdater {
   std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;
   ObjInfo task_;
   // Context for number of threads
-  GenericParameter const* ctx_;
+  Context const* ctx_;
   std::unique_ptr<common::Monitor> monitor_;
 };
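Several hunks above migrate `int32_t`/`uint32_t` bin variables to `bst_bin_t`. The reason the bin condition must stay signed is the `-1` sentinel in `FindSplitConditions`. A standalone sketch of that search; the `bst_bin_t` alias is assumed here to be a 32-bit signed integer, matching how the patch uses it:

```cpp
#include <cstdint>
#include <vector>

using bst_bin_t = std::int32_t;  // assumption: xgboost's bin index type is signed 32-bit

// Sketch of the inner search in FindSplitConditions: return the bin whose cut
// value equals split_pt, or -1 when split_pt is below all known cut points.
// The -1 sentinel is why split_cond cannot be an unsigned type.
bst_bin_t SearchSplitCond(std::vector<float> const& cut_values, std::uint32_t lower_bound,
                          std::uint32_t upper_bound, float split_pt) {
  bst_bin_t split_cond = -1;
  for (auto bound = lower_bound; bound < upper_bound; ++bound) {
    if (split_pt == cut_values[bound]) {
      split_cond = static_cast<bst_bin_t>(bound);
    }
  }
  return split_cond;
}
```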
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 363a22176cdf..1122c04d57c2 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -15,6 +15,7 @@ TEST(DenseColumn, Test) {
   int32_t max_num_bins[] = {static_cast<int32_t>(std::numeric_limits<uint8_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 1,
                             static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
+  BinTypeSize last{kUint8BinsTypeSize};
   for (int32_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
     auto sparse_thresh = 0.2;
@@ -24,7 +25,10 @@ TEST(DenseColumn, Test) {
     for (auto const& page : dmat->GetBatches<SparsePage>()) {
       column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
     }
-
+    ASSERT_GE(column_matrix.GetTypeSize(), last);
+    ASSERT_LE(column_matrix.GetTypeSize(), kUint32BinsTypeSize);
+    last = column_matrix.GetTypeSize();
+    ASSERT_FALSE(column_matrix.AnyMissing());
     for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
       for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
         switch (column_matrix.GetTypeSize()) {
@@ -105,6 +109,7 @@ TEST(DenseColumnWithMissing, Test) {
   for (auto const& page : dmat->GetBatches<SparsePage>()) {
     column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
   }
+  ASSERT_TRUE(column_matrix.AnyMissing());
   switch (column_matrix.GetTypeSize()) {
     case kUint8BinsTypeSize: {
       auto col = column_matrix.DenseColumn<uint8_t, true>(0);
diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc
index 81b5812fd27a..493535aab536 100644
--- a/tests/cpp/tree/hist/test_evaluate_splits.cc
+++ b/tests/cpp/tree/hist/test_evaluate_splits.cc
@@ -130,7 +130,6 @@ TEST_F(TestPartitionBasedSplit, CPUHist) {
 namespace {
 auto CompareOneHotAndPartition(bool onehot) {
   int static constexpr kRows = 128, kCols = 1;
-  using GradientSumT = double;
   std::vector<FeatureType> ft(kCols, FeatureType::kCategorical);
   TrainParam param;
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 0c89cd5e82bc..e03933411cfd 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -35,7 +35,7 @@ TEST(QuantileHist, Partitioner) {
 
   for (auto const& page : Xy->GetBatches<SparsePage>()) {
     GHistIndexMatrix gmat;
-    gmat.Init(page, {}, cuts, 64, false, 0.5, ctx.Threads());
+    gmat.Init(page, {}, cuts, 64, true, 0.5, ctx.Threads());
     bst_feature_t const split_ind = 0;
     common::ColumnMatrix column_indices;
     column_indices.Init(page, gmat, 0.5, ctx.Threads());
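The new assertions pin down the contract the fix restores: `AnyMissing()` must be false for fully dense inputs and true once values are missing, and `GetTypeSize()` must grow monotonically as `max_num_bin` crosses each byte-width boundary. For the read path exercised by `DenseColumnWithMissing`, a dense column stores one slot per row plus a missing flag, and the element type must match `GetTypeSize()` for the reinterpreted access to be valid. A standalone sketch of that access pattern (a simplified hypothetical structure, not the xgboost API):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Simplified dense column with missing values: one slot per row, plus a flag.
template <typename ColumnBinT>
struct DenseColumnSketch {
  std::vector<ColumnBinT> local_bins;  // bin id minus the column's base
  std::vector<bool> missing;           // true -> value absent for that row
  std::uint32_t index_base;            // least global bin id of the column

  // Returns the global bin id, or -1 when the entry is missing.
  std::int64_t GlobalBinId(std::size_t rid) const {
    if (missing[rid]) {
      return -1;
    }
    return static_cast<std::int64_t>(local_bins[rid]) + index_base;
  }
};

int main() {
  DenseColumnSketch<std::uint8_t> col{{0, 3, 0}, {false, false, true}, 42};
  for (std::size_t rid = 0; rid < 3; ++rid) {
    std::cout << col.GlobalBinId(rid) << "\n";  // 42, 45, -1
  }
}
```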