diff --git a/demo/guide-python/external_memory.py b/demo/guide-python/external_memory.py index 3e864a53e815..7c4e2ef9bbc2 100644 --- a/demo/guide-python/external_memory.py +++ b/demo/guide-python/external_memory.py @@ -79,7 +79,7 @@ def main(tmpdir: str) -> xgboost.Booster: # Other tree methods including ``hist`` and ``gpu_hist`` also work, but has some # caveats. This is still an experimental feature. - booster = xgboost.train({"tree_method": "approx"}, Xy, evals=[(Xy, "Train")]) + booster = xgboost.train({"tree_method": "hist", "max_depth": 2}, Xy, evals=[(Xy, "Train")], num_boost_round=1) return booster diff --git a/demo/guide-python/feature_weights.py b/demo/guide-python/feature_weights.py index f0b4907aaa42..34c8ed44026b 100644 --- a/demo/guide-python/feature_weights.py +++ b/demo/guide-python/feature_weights.py @@ -27,7 +27,7 @@ def main(args): dtrain.set_info(feature_weights=fw) bst = xgboost.train({'tree_method': 'hist', - 'colsample_bynode': 0.5}, + 'colsample_bynode': 0.2}, dtrain, num_boost_round=10, evals=[(dtrain, 'd')]) feature_map = bst.get_fscore() diff --git a/include/xgboost/data.h b/include/xgboost/data.h index 7399b8265377..6dc214bd6f1e 100644 --- a/include/xgboost/data.h +++ b/include/xgboost/data.h @@ -240,6 +240,7 @@ struct BatchParam { if (hess.empty() && other.hess.empty()) { return gpu_id != other.gpu_id || max_bin != other.max_bin; } + // fixme: sprse_thresh return gpu_id != other.gpu_id || max_bin != other.max_bin || hess.data() != other.hess.data(); } bool operator==(BatchParam const& other) const { diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h index 747004cc0991..9705d06b48a2 100644 --- a/src/common/column_matrix.h +++ b/src/common/column_matrix.h @@ -1,5 +1,5 @@ /*! - * Copyright 2017 by Contributors + * Copyright 2017-2022 by Contributors * \file column_matrix.h * \brief Utility for fast column-wise access * \author Philip Cho @@ -8,21 +8,22 @@ #ifndef XGBOOST_COMMON_COLUMN_MATRIX_H_ #define XGBOOST_COMMON_COLUMN_MATRIX_H_ +#include + +#include #include -#include #include -#include "hist_util.h" +#include + #include "../data/gradient_index.h" +#include "hist_util.h" namespace xgboost { namespace common { class ColumnMatrix; /*! \brief column type */ -enum ColumnType { - kDenseColumn, - kSparseColumn -}; +enum ColumnType : uint8_t { kDenseColumn, kSparseColumn }; /*! \brief a column storage, to be used with ApplySplit. Note that each bin id is stored as index[i] + index_base. 
@@ -34,9 +35,7 @@ class Column { static constexpr int32_t kMissingId = -1; Column(ColumnType type, common::Span index, const uint32_t index_base) - : type_(type), - index_(index), - index_base_(index_base) {} + : type_(type), index_(index), index_base_(index_base) {} virtual ~Column() = default; @@ -65,12 +64,11 @@ class Column { }; template -class SparseColumn: public Column { +class SparseColumn : public Column { public: - SparseColumn(ColumnType type, common::Span index, - uint32_t index_base, common::Span row_ind) - : Column(type, index, index_base), - row_ind_(row_ind) {} + SparseColumn(ColumnType type, common::Span index, uint32_t index_base, + common::Span row_ind) + : Column(type, index, index_base), row_ind_(row_ind) {} const size_t* GetRowData() const { return row_ind_.data(); } @@ -98,9 +96,7 @@ class SparseColumn: public Column { return p - row_data; } - size_t GetRowIdx(size_t idx) const { - return row_ind_.data()[idx]; - } + size_t GetRowIdx(size_t idx) const { return row_ind_.data()[idx]; } private: /* indexes of rows */ @@ -108,11 +104,10 @@ class SparseColumn: public Column { }; template -class DenseColumn: public Column { +class DenseColumn : public Column { public: - DenseColumn(ColumnType type, common::Span index, - uint32_t index_base, const std::vector& missing_flags, - size_t feature_offset) + DenseColumn(ColumnType type, common::Span index, uint32_t index_base, + const std::vector& missing_flags, size_t feature_offset) : Column(type, index, index_base), missing_flags_(missing_flags), feature_offset_(feature_offset) {} @@ -126,9 +121,7 @@ class DenseColumn: public Column { } } - size_t GetInitialState(const size_t first_row_id) const { - return 0; - } + size_t GetInitialState(const size_t first_row_id) const { return 0; } private: /* flags for missing values in dense columns */ @@ -141,28 +134,26 @@ class DenseColumn: public Column { class ColumnMatrix { public: // get number of features - inline bst_uint GetNumFeature() const { - return static_cast(type_.size()); - } + bst_feature_t GetNumFeature() const { return static_cast(type_.size()); } // construct column matrix from GHistIndexMatrix - inline void Init(const GHistIndexMatrix& gmat, double sparse_threshold, int32_t n_threads) { - const int32_t nfeature = static_cast(gmat.cut.Ptrs().size() - 1); + inline void Init(SparsePage const& page, const GHistIndexMatrix& gmat, double sparse_threshold, + int32_t n_threads) { + auto const nfeature = static_cast(gmat.cut.Ptrs().size() - 1); const size_t nrow = gmat.row_ptr.size() - 1; // identify type of each column feature_counts_.resize(nfeature); type_.resize(nfeature); std::fill(feature_counts_.begin(), feature_counts_.end(), 0); uint32_t max_val = std::numeric_limits::max(); - for (int32_t fid = 0; fid < nfeature; ++fid) { + for (bst_feature_t fid = 0; fid < nfeature; ++fid) { CHECK_LE(gmat.cut.Ptrs()[fid + 1] - gmat.cut.Ptrs()[fid], max_val); } bool all_dense = gmat.IsDense(); gmat.GetFeatureCounts(&feature_counts_[0]); // classify features - for (int32_t fid = 0; fid < nfeature; ++fid) { - if (static_cast(feature_counts_[fid]) - < sparse_threshold * nrow) { + for (bst_feature_t fid = 0; fid < nfeature; ++fid) { + if (static_cast(feature_counts_[fid]) < sparse_threshold * nrow) { type_[fid] = kSparseColumn; all_dense = false; } else { @@ -175,7 +166,7 @@ class ColumnMatrix { feature_offsets_.resize(nfeature + 1); size_t accum_index_ = 0; feature_offsets_[0] = accum_index_; - for (int32_t fid = 1; fid < nfeature + 1; ++fid) { + for (bst_feature_t fid = 1; fid < nfeature 
+ 1; ++fid) { if (type_[fid - 1] == kDenseColumn) { accum_index_ += static_cast(nrow); } else { @@ -197,6 +188,7 @@ class ColumnMatrix { const bool noMissingValues = NoMissingValues(gmat.row_ptr[nrow], nrow, nfeature); any_missing_ = !noMissingValues; + missing_flags_.clear(); if (noMissingValues) { missing_flags_.resize(feature_offsets_[nfeature], false); } else { @@ -207,33 +199,33 @@ class ColumnMatrix { if (all_dense) { BinTypeSize gmat_bin_size = gmat.index.GetBinTypeSize(); if (gmat_bin_size == kUint8BinsTypeSize) { - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + SetIndexAllDense(page, gmat.index.data(), gmat, nrow, nfeature, noMissingValues, n_threads); } else if (gmat_bin_size == kUint16BinsTypeSize) { - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + SetIndexAllDense(page, gmat.index.data(), gmat, nrow, nfeature, noMissingValues, n_threads); } else { CHECK_EQ(gmat_bin_size, kUint32BinsTypeSize); - SetIndexAllDense(gmat.index.data(), gmat, nrow, nfeature, noMissingValues, + SetIndexAllDense(page, gmat.index.data(), gmat, nrow, nfeature, noMissingValues, n_threads); } - /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize - but for ColumnMatrix we still have a chance to reduce the memory consumption */ + /* For sparse DMatrix gmat.index.getBinTypeSize() returns always kUint32BinsTypeSize + but for ColumnMatrix we still have a chance to reduce the memory consumption */ } else { if (bins_type_size_ == kUint8BinsTypeSize) { - SetIndex(gmat.index.data(), gmat, nfeature); + SetIndex(page, gmat.index.data(), gmat, nfeature); } else if (bins_type_size_ == kUint16BinsTypeSize) { - SetIndex(gmat.index.data(), gmat, nfeature); + SetIndex(page, gmat.index.data(), gmat, nfeature); } else { - CHECK_EQ(bins_type_size_, kUint32BinsTypeSize); - SetIndex(gmat.index.data(), gmat, nfeature); + CHECK_EQ(bins_type_size_, kUint32BinsTypeSize); + SetIndex(page, gmat.index.data(), gmat, nfeature); } } } /* Set the number of bytes based on numeric limit of maximum number of bins provided by user */ void SetTypeSize(size_t max_num_bins) { - if ( (max_num_bins - 1) <= static_cast(std::numeric_limits::max()) ) { + if ((max_num_bins - 1) <= static_cast(std::numeric_limits::max())) { bins_type_size_ = kUint8BinsTypeSize; } else if ((max_num_bins - 1) <= static_cast(std::numeric_limits::max())) { bins_type_size_ = kUint16BinsTypeSize; @@ -250,123 +242,180 @@ class ColumnMatrix { const size_t feature_offset = feature_offsets_[fid]; // to get right place for certain feature const size_t column_size = feature_offsets_[fid + 1] - feature_offset; - common::Span bin_index = { reinterpret_cast( - &index_[feature_offset * bins_type_size_]), - column_size }; + common::Span bin_index = { + reinterpret_cast(&index_[feature_offset * bins_type_size_]), + column_size}; std::unique_ptr > res; if (type_[fid] == ColumnType::kDenseColumn) { CHECK_EQ(any_missing, any_missing_); res.reset(new DenseColumn(type_[fid], bin_index, index_base_[fid], - missing_flags_, feature_offset)); + missing_flags_, feature_offset)); } else { res.reset(new SparseColumn(type_[fid], bin_index, index_base_[fid], - {&row_ind_[feature_offset], column_size})); + {&row_ind_[feature_offset], column_size})); } return res; } template - inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat, - const size_t nrow, const size_t nfeature, - const bool noMissingValues, int32_t n_threads) { + inline void SetIndexAllDense(SparsePage const& page, T* index, const 
GHistIndexMatrix& gmat, + const size_t nrow, const size_t nfeature, const bool noMissingValues, + int32_t n_threads) { T* local_index = reinterpret_cast(&index_[0]); /* missing values make sense only for column with type kDenseColumn, and if no missing values were observed it could be handled much faster. */ if (noMissingValues) { ParallelFor(nrow, n_threads, [&](auto rid) { - const size_t ibegin = rid*nfeature; - const size_t iend = (rid+1)*nfeature; + const size_t ibegin = rid * nfeature; + const size_t iend = (rid + 1) * nfeature; size_t j = 0; for (size_t i = ibegin; i < iend; ++i, ++j) { - const size_t idx = feature_offsets_[j]; - local_index[idx + rid] = index[i]; + const size_t idx = feature_offsets_[j]; + local_index[idx + rid] = index[i]; } }); } else { /* to handle rows in all batches, sum of all batch sizes equal to gmat.row_ptr.size() - 1 */ - size_t rbegin = 0; - for (const auto &batch : gmat.p_fmat->GetBatches()) { - const xgboost::Entry* data_ptr = batch.data.HostVector().data(); - const std::vector& offset_vec = batch.offset.HostVector(); - const size_t batch_size = batch.Size(); - CHECK_LT(batch_size, offset_vec.size()); - for (size_t rid = 0; rid < batch_size; ++rid) { - const size_t size = offset_vec[rid + 1] - offset_vec[rid]; - SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; - const size_t ibegin = gmat.row_ptr[rbegin + rid]; - const size_t iend = gmat.row_ptr[rbegin + rid + 1]; - CHECK_EQ(ibegin + inst.size(), iend); - size_t j = 0; - size_t fid = 0; - for (size_t i = ibegin; i < iend; ++i, ++j) { - fid = inst[j].index; - const size_t idx = feature_offsets_[fid]; - /* rbegin allows to store indexes from specific SparsePage batch */ - local_index[idx + rbegin + rid] = index[i]; - missing_flags_[idx + rbegin + rid] = false; - } - } - rbegin += batch.Size(); + auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) { + // T* begin = &local_index[feature_offsets_[fid]]; + const size_t idx = feature_offsets_[fid]; + /* rbegin allows to store indexes from specific SparsePage batch */ + local_index[idx + rid] = bin_id; + + missing_flags_[idx + rid] = false; + }; + this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx); + } + } + + // FIXME(jiamingy): In the future we might want to simply use binary search to simplify + // this and remove the dependency on SparsePage. This way we can have quantilized + // matrix for host similar to `DeviceQuantileDMatrix`. 
+ template + void SetIndexSparse(SparsePage const& batch, T* index, const GHistIndexMatrix& gmat, + const size_t nfeature, BinFn&& assign_bin) { + std::vector num_nonzeros(nfeature, 0ul); + const xgboost::Entry* data_ptr = batch.data.HostVector().data(); + const std::vector& offset_vec = batch.offset.HostVector(); + auto rbegin = 0; + const size_t batch_size = gmat.Size(); + CHECK_LT(batch_size, offset_vec.size()); + + for (size_t rid = 0; rid < batch_size; ++rid) { + const size_t ibegin = gmat.row_ptr[rbegin + rid]; + const size_t iend = gmat.row_ptr[rbegin + rid + 1]; + const size_t size = offset_vec[rid + 1] - offset_vec[rid]; + SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; + + CHECK_EQ(ibegin + inst.size(), iend); + size_t j = 0; + for (size_t i = ibegin; i < iend; ++i, ++j) { + const uint32_t bin_id = index[i]; + auto fid = inst[j].index; + assign_bin(bin_id, rid, fid); } } } - template - inline void SetIndex(uint32_t* index, const GHistIndexMatrix& gmat, + template + inline void SetIndex(SparsePage const& page, uint32_t* index, const GHistIndexMatrix& gmat, const size_t nfeature) { + T* local_index = reinterpret_cast(&index_[0]); std::vector num_nonzeros; num_nonzeros.resize(nfeature); std::fill(num_nonzeros.begin(), num_nonzeros.end(), 0); - T* local_index = reinterpret_cast(&index_[0]); - size_t rbegin = 0; - for (const auto &batch : gmat.p_fmat->GetBatches()) { - const xgboost::Entry* data_ptr = batch.data.HostVector().data(); - const std::vector& offset_vec = batch.offset.HostVector(); - const size_t batch_size = batch.Size(); - CHECK_LT(batch_size, offset_vec.size()); - for (size_t rid = 0; rid < batch_size; ++rid) { - const size_t ibegin = gmat.row_ptr[rbegin + rid]; - const size_t iend = gmat.row_ptr[rbegin + rid + 1]; - size_t fid = 0; - const size_t size = offset_vec[rid + 1] - offset_vec[rid]; - SparsePage::Inst inst = {data_ptr + offset_vec[rid], size}; - - CHECK_EQ(ibegin + inst.size(), iend); - size_t j = 0; - for (size_t i = ibegin; i < iend; ++i, ++j) { - const uint32_t bin_id = index[i]; - - fid = inst[j].index; - if (type_[fid] == kDenseColumn) { - T* begin = &local_index[feature_offsets_[fid]]; - begin[rid + rbegin] = bin_id - index_base_[fid]; - missing_flags_[feature_offsets_[fid] + rid + rbegin] = false; - } else { - T* begin = &local_index[feature_offsets_[fid]]; - begin[num_nonzeros[fid]] = bin_id - index_base_[fid]; - row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid + rbegin; - ++num_nonzeros[fid]; - } - } + auto get_bin_idx = [&](auto bin_id, auto rid, bst_feature_t fid) { + if (type_[fid] == kDenseColumn) { + T* begin = &local_index[feature_offsets_[fid]]; + begin[rid] = bin_id - index_base_[fid]; + missing_flags_[feature_offsets_[fid] + rid] = false; + } else { + T* begin = &local_index[feature_offsets_[fid]]; + begin[num_nonzeros[fid]] = bin_id - index_base_[fid]; + row_ind_[feature_offsets_[fid] + num_nonzeros[fid]] = rid; + ++num_nonzeros[fid]; } - rbegin += batch.Size(); - } - } - BinTypeSize GetTypeSize() const { - return bins_type_size_; + }; + this->SetIndexSparse(page, index, gmat, nfeature, get_bin_idx); } + BinTypeSize GetTypeSize() const { return bins_type_size_; } + // This is just an utility function - bool NoMissingValues(const size_t n_elements, - const size_t n_row, const size_t n_features) { + bool NoMissingValues(const size_t n_elements, const size_t n_row, const size_t n_features) { return n_elements == n_features * n_row; } // And this returns part of state - bool AnyMissing() const { - return any_missing_; + bool 
AnyMissing() const { return any_missing_; } + + // IO procedures for external memory. + bool Read(dmlc::SeekStream* fi, uint32_t const* index_base) { + fi->Read(&index_); + fi->Read(&feature_counts_); +#if !DMLC_LITTLE_ENDIAN + // s390x + std::vector::type> int_types; + fi->Read(&int_types); + type_.resize(int_types.size()); + std::transform( + int_types.begin(), int_types.end(), type_.begin(), + [](std::underlying_type::type i) { return static_cast(i); }); +#else + fi->Read(&type_); +#endif // !DMLC_LITTLE_ENDIAN + + fi->Read(&row_ind_); + fi->Read(&feature_offsets_); + index_base_ = index_base; +#if !DMLC_LITTLE_ENDIAN + std::underlying_type::type v; + fi->Read(&v); + bins_type_size_ = static_cast(v); +#else + fi->Read(&bins_type_size_); +#endif + + fi->Read(&any_missing_); + return true; + } + + size_t Write(dmlc::Stream* fo) const { + size_t bytes{0}; + + auto write_vec = [&](auto const& vec) { + fo->Write(vec); + bytes += vec.size() * sizeof(typename std::remove_reference_t::value_type) + + sizeof(uint64_t); + }; + write_vec(index_); + write_vec(feature_counts_); +#if !DMLC_LITTLE_ENDIAN + // s390x + std::vector::type> int_types(type_.size()); + std::transform(type_.begin(), type_.end(), int_types.begin(), [](ColumnType t) { + return static_cast::type>(t); + }); + write_vec(int_types); +#else + write_vec(type_); +#endif // !DMLC_LITTLE_ENDIAN + write_vec(row_ind_); + write_vec(feature_offsets_); + +#if !DMLC_LITTLE_ENDIAN + auto v = static_cast::type>(bins_type_size_); + fo->Write(v); +#else + fo->Write(bins_type_size_); +#endif // DMLC_LITTLE_ENDIAN + bytes += sizeof(bins_type_size_); + fo->Write(any_missing_); + bytes += sizeof(any_missing_); + + return bytes; } private: diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index c14da59a7f60..77e83b65886d 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -165,6 +165,7 @@ void BuildHistKernel(const std::vector &gpair, any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features; const size_t icol_end = any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features; + CHECK_LE(icol_end, gmat.index.Size()); const size_t row_size = icol_end - icol_start; const size_t idx_gh = two * rid[i]; diff --git a/src/common/partition_builder.h b/src/common/partition_builder.h index 5235ea3b9404..811a56ddfef4 100644 --- a/src/common/partition_builder.h +++ b/src/common/partition_builder.h @@ -1,5 +1,5 @@ /*! 
- * Copyright 2021 by Contributors + * Copyright 2021-2022 by Contributors * \file row_set.h * \brief Quick Utility to compute subset of rows * \author Philip Cho, Tianqi Chen @@ -48,16 +48,20 @@ class PartitionBuilder { // Analog of std::stable_partition, but in no-inplace manner template inline std::pair PartitionKernel(const ColumnType& column, - common::Span rid_span, const int32_t split_cond, - common::Span left_part, common::Span right_part) { + common::Span rid_span, + const int32_t split_cond, + common::Span left_part, + common::Span right_part, + size_t base_rowid) { size_t* p_left_part = left_part.data(); size_t* p_right_part = right_part.data(); size_t nleft_elems = 0; size_t nright_elems = 0; - auto state = column.GetInitialState(rid_span.front()); + auto state = column.GetInitialState(rid_span.front() - base_rowid); for (auto rid : rid_span) { - const int32_t bin_id = column.GetBinIdx(rid, &state); + CHECK_GE(rid, base_rowid); + const int32_t bin_id = column.GetBinIdx(rid - base_rowid, &state); if (any_missing && bin_id == ColumnType::kMissingId) { if (default_left) { p_left_part[nleft_elems++] = rid; @@ -97,13 +101,11 @@ class PartitionBuilder { template void Partition(const size_t node_in_set, const size_t nid, const common::Range1d range, - const int32_t split_cond, - const ColumnMatrix& column_matrix, const RegTree& tree, const size_t* rid) { + const int32_t split_cond, const ColumnMatrix& column_matrix, const RegTree& tree, + const size_t* rid, size_t base_rowid) { common::Span rid_span(rid + range.begin(), rid + range.end()); - common::Span left = GetLeftBuffer(node_in_set, - range.begin(), range.end()); - common::Span right = GetRightBuffer(node_in_set, - range.begin(), range.end()); + common::Span left = GetLeftBuffer(node_in_set, range.begin(), range.end()); + common::Span right = GetRightBuffer(node_in_set, range.begin(), range.end()); const bst_uint fid = tree[nid].SplitIndex(); const bool default_left = tree[nid].DefaultLeft(); const auto column_ptr = column_matrix.GetColumn(fid); @@ -114,22 +116,22 @@ class PartitionBuilder { const common::DenseColumn& column = static_cast& >(*(column_ptr.get())); if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); + child_nodes_sizes = PartitionKernel(column, rid_span, split_cond, left, + right, base_rowid); } else { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); + child_nodes_sizes = PartitionKernel(column, rid_span, split_cond, left, + right, base_rowid); } } else { CHECK_EQ(any_missing, true); const common::SparseColumn& column = static_cast& >(*(column_ptr.get())); if (default_left) { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); + child_nodes_sizes = PartitionKernel(column, rid_span, split_cond, left, + right, base_rowid); } else { - child_nodes_sizes = PartitionKernel(column, rid_span, - split_cond, left, right); + child_nodes_sizes = PartitionKernel(column, rid_span, split_cond, left, + right, base_rowid); } } diff --git a/src/data/ellpack_page_source.cu b/src/data/ellpack_page_source.cu index 6d79250a0a63..872cb0cc657f 100644 --- a/src/data/ellpack_page_source.cu +++ b/src/data/ellpack_page_source.cu @@ -1,5 +1,5 @@ /*! 
- * Copyright 2019-2021 XGBoost contributors + * Copyright 2019-2022 XGBoost contributors */ #include #include @@ -12,6 +12,13 @@ namespace data { void EllpackPageSource::Fetch() { dh::safe_cuda(cudaSetDevice(param_.gpu_id)); if (!this->ReadCache()) { + if (count_ != 0 && !sync_) { + // source is initialized to be the 0th page during construction, so when count_ is 0 + // there's no need to increment the source. + ++(*source_); + } + // This is not read from cache so we still need it to be synced with sparse page source. + CHECK_EQ(count_, source_->Iter()); auto const &csr = source_->Page(); this->page_.reset(new EllpackPage{}); auto *impl = this->page_->Impl(); diff --git a/src/data/ellpack_page_source.h b/src/data/ellpack_page_source.h index 9a1551d53749..dc080247287c 100644 --- a/src/data/ellpack_page_source.h +++ b/src/data/ellpack_page_source.h @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2021 by XGBoost Contributors + * Copyright 2019-2022 by XGBoost Contributors */ #ifndef XGBOOST_DATA_ELLPACK_PAGE_SOURCE_H_ @@ -25,15 +25,17 @@ class EllpackPageSource : public PageSourceIncMixIn { std::unique_ptr cuts_; public: - EllpackPageSource( - float missing, int nthreads, bst_feature_t n_features, size_t n_batches, - std::shared_ptr cache, BatchParam param, - std::unique_ptr cuts, bool is_dense, - size_t row_stride, common::Span feature_types, - std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), - is_dense_{is_dense}, row_stride_{row_stride}, param_{std::move(param)}, - feature_types_{feature_types}, cuts_{std::move(cuts)} { + EllpackPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, + std::shared_ptr cache, BatchParam param, + std::unique_ptr cuts, bool is_dense, size_t row_stride, + common::Span feature_types, + std::shared_ptr source) + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, false), + is_dense_{is_dense}, + row_stride_{row_stride}, + param_{std::move(param)}, + feature_types_{feature_types}, + cuts_{std::move(cuts)} { this->source_ = source; this->Fetch(); } diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc index abd80264d8b8..b6f46d86fe57 100644 --- a/src/data/gradient_index.cc +++ b/src/data/gradient_index.cc @@ -147,7 +147,6 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, hit_count.resize(nbins, 0); hit_count_tloc_.resize(n_threads * nbins, 0); - this->p_fmat = p_fmat; size_t new_size = 1; for (const auto &batch : p_fmat->GetBatches()) { new_size += batch.Size(); @@ -167,6 +166,16 @@ void GHistIndexMatrix::Init(DMatrix *p_fmat, int max_bins, double sparse_thresh, prev_sum = row_ptr[rbegin + batch.Size()]; rbegin += batch.Size(); } + this->columns_ = std::make_unique(); + + // hessian is empty when hist tree method is used or when dataset is empty + if (hess.empty() && !std::isnan(sparse_thresh)) { + // hist + CHECK(!sorted_sketch); + for (auto const &page : p_fmat->GetBatches()) { + this->columns_->Init(page, *this, sparse_thresh, n_threads); + } + } } void GHistIndexMatrix::Init(SparsePage const &batch, common::Span ft, @@ -190,6 +199,10 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::SpanPushBatch(batch, ft, rbegin, prev_sum, nbins, n_threads); + this->columns_ = std::make_unique(); + if (!std::isnan(sparse_thresh)) { + this->columns_->Init(batch, *this, sparse_thresh, n_threads); + } } void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) { @@ -206,4 +219,17 @@ void 
GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) { index.Resize((sizeof(uint32_t)) * n_index); } } + +common::ColumnMatrix const &GHistIndexMatrix::Transpose() const { + CHECK(columns_); + return *columns_; +} + +bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) { + return this->columns_->Read(fi, this->cut.Ptrs().data()); +} + +size_t GHistIndexMatrix::WriteColumnPage(dmlc::Stream *fo) const { + return this->columns_->Write(fo); +} } // namespace xgboost diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h index 83da8c7847cb..a1d27a8f7657 100644 --- a/src/data/gradient_index.h +++ b/src/data/gradient_index.h @@ -33,7 +33,6 @@ class GHistIndexMatrix { std::vector hit_count; /*! \brief The corresponding cuts */ common::HistogramCuts cut; - DMatrix* p_fmat; /*! \brief max_bin for each feature. */ size_t max_num_bins; /*! \brief base row index for current page (used by external memory) */ @@ -108,8 +107,12 @@ class GHistIndexMatrix { return row_ptr.empty() ? 0 : row_ptr.size() - 1; } + bool ReadColumnPage(dmlc::SeekStream* fi); + size_t WriteColumnPage(dmlc::Stream* fo) const; + + common::ColumnMatrix const& Transpose() const; + private: - // unused at the moment: https://github.com/dmlc/xgboost/pull/7531 std::unique_ptr columns_; std::vector hit_count_tloc_; bool isDense_; diff --git a/src/data/gradient_index_format.cc b/src/data/gradient_index_format.cc index 19baeb406414..3032499465eb 100644 --- a/src/data/gradient_index_format.cc +++ b/src/data/gradient_index_format.cc @@ -1,5 +1,5 @@ /*! - * Copyright 2021 XGBoost contributors + * Copyright 2021-2022 XGBoost contributors */ #include "sparse_page_writer.h" #include "gradient_index.h" @@ -7,7 +7,7 @@ namespace xgboost { namespace data { - +// fixme: io for column matrix. class GHistIndexRawFormat : public SparsePageFormat { public: bool Read(GHistIndexMatrix* page, dmlc::SeekStream* fi) override { @@ -55,6 +55,8 @@ class GHistIndexRawFormat : public SparsePageFormat { return false; } page->SetDense(is_dense); + + page->ReadColumnPage(fi); return true; } @@ -93,6 +95,8 @@ class GHistIndexRawFormat : public SparsePageFormat { bytes += sizeof(page.base_rowid); fo->Write(page.IsDense()); bytes += sizeof(page.IsDense()); + + bytes += page.WriteColumnPage(fo); return bytes; } }; diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc index 9ec69d904c94..0056cff3987b 100644 --- a/src/data/gradient_index_page_source.cc +++ b/src/data/gradient_index_page_source.cc @@ -7,6 +7,13 @@ namespace xgboost { namespace data { void GradientIndexPageSource::Fetch() { if (!this->ReadCache()) { + if (count_ != 0 && !sync_) { + // source is initialized to be the 0th page during construction, so when count_ is 0 + // there's no need to increment the source. + ++(*source_); + } + // This is not read from cache so we still need it to be synced with sparse page source. 
+ CHECK_EQ(count_, source_->Iter()); auto const& csr = source_->Page(); this->page_.reset(new GHistIndexMatrix()); CHECK_NE(cuts_.Values().size(), 0); diff --git a/src/data/gradient_index_page_source.h b/src/data/gradient_index_page_source.h index 30b53a2943d9..db71c1c6d11d 100644 --- a/src/data/gradient_index_page_source.h +++ b/src/data/gradient_index_page_source.h @@ -22,13 +22,14 @@ class GradientIndexPageSource : public PageSourceIncMixIn { public: GradientIndexPageSource(float missing, int nthreads, bst_feature_t n_features, size_t n_batches, std::shared_ptr cache, BatchParam param, - common::HistogramCuts cuts, bool is_dense, int32_t max_bin_per_feat, + common::HistogramCuts cuts, bool is_dense, common::Span feature_types, std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache), + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, + std::isnan(param.sparse_thresh)), cuts_{std::move(cuts)}, is_dense_{is_dense}, - max_bin_per_feat_{max_bin_per_feat}, + max_bin_per_feat_{param.max_bin}, feature_types_{feature_types}, sparse_thresh_{param.sparse_thresh} { this->source_ = source; diff --git a/src/data/sparse_page_dmatrix.cc b/src/data/sparse_page_dmatrix.cc index a9fd9b7c1499..a90150ce8c83 100644 --- a/src/data/sparse_page_dmatrix.cc +++ b/src/data/sparse_page_dmatrix.cc @@ -159,21 +159,6 @@ BatchSet SparsePageDMatrix::GetSortedColumnBatches() { BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam ¶m) { CHECK_GE(param.max_bin, 2); - if (param.hess.empty() && !param.regen) { - // hist method doesn't support full external memory implementation, so we concatenate - // all index here. - if (!ghist_index_page_ || (param != batch_param_ && param != BatchParam{})) { - this->InitializeSparsePage(); - ghist_index_page_.reset(new GHistIndexMatrix{this, param.max_bin, param.sparse_thresh, - param.regen, ctx_.Threads()}); - this->InitializeSparsePage(); - batch_param_ = param; - } - auto begin_iter = BatchIterator( - new SimpleBatchIteratorImpl(ghist_index_page_)); - return BatchSet(begin_iter); - } - auto id = MakeCache(this, ".gradient_index.page", cache_prefix_, &cache_info_); this->InitializeSparsePage(); if (!cache_info_.at(id)->written || RegenGHist(batch_param_, param)) { @@ -190,10 +175,9 @@ BatchSet SparsePageDMatrix::GetGradientIndex(const BatchParam ghist_index_source_.reset(); CHECK_NE(cuts.Values().size(), 0); auto ft = this->info_.feature_types.ConstHostSpan(); - ghist_index_source_.reset( - new GradientIndexPageSource(this->missing_, this->ctx_.Threads(), this->Info().num_col_, - this->n_batches_, cache_info_.at(id), param, std::move(cuts), - this->IsDense(), param.max_bin, ft, sparse_page_source_)); + ghist_index_source_.reset(new GradientIndexPageSource( + this->missing_, this->ctx_.Threads(), this->Info().num_col_, this->n_batches_, + cache_info_.at(id), param, std::move(cuts), this->IsDense(), ft, sparse_page_source_)); } else { CHECK(ghist_index_source_); ghist_index_source_->Reset(); diff --git a/src/data/sparse_page_dmatrix.cu b/src/data/sparse_page_dmatrix.cu index 82e1f3ce0ae0..b36a0e2a37b9 100644 --- a/src/data/sparse_page_dmatrix.cu +++ b/src/data/sparse_page_dmatrix.cu @@ -11,6 +11,9 @@ namespace data { BatchSet SparsePageDMatrix::GetEllpackBatches(const BatchParam& param) { CHECK_GE(param.gpu_id, 0); CHECK_GE(param.max_bin, 2); + if (!(batch_param_ != BatchParam{})) { + CHECK(param != BatchParam{}) << "Batch parameter is not initialized."; + } auto id = MakeCache(this, ".ellpack.page", cache_prefix_, 
&cache_info_); size_t row_stride = 0; this->InitializeSparsePage(); diff --git a/src/data/sparse_page_source.h b/src/data/sparse_page_source.h index 4bada04c8e05..0a3e32e75e1f 100644 --- a/src/data/sparse_page_source.h +++ b/src/data/sparse_page_source.h @@ -23,6 +23,7 @@ #include "proxy_dmatrix.h" #include "../common/common.h" +#include "../common/timer.h" namespace xgboost { namespace data { @@ -118,26 +119,30 @@ class SparsePageSourceImpl : public BatchIteratorImpl { size_t n_prefetch_batches = std::min(kPreFetch, n_batches_); CHECK_GT(n_prefetch_batches, 0) << "total batches:" << n_batches_; size_t fetch_it = count_; + for (size_t i = 0; i < n_prefetch_batches; ++i, ++fetch_it) { fetch_it %= n_batches_; // ring - if (ring_->at(fetch_it).valid()) { continue; } + if (ring_->at(fetch_it).valid()) { + continue; + } auto const *self = this; // make sure it's const CHECK_LT(fetch_it, cache_info_->offset.size()); ring_->at(fetch_it) = std::async(std::launch::async, [fetch_it, self]() { + common::Timer timer; + timer.Start(); std::unique_ptr> fmt{CreatePageFormat("raw")}; auto n = self->cache_info_->ShardName(); size_t offset = self->cache_info_->offset.at(fetch_it); - std::unique_ptr fi{ - dmlc::SeekStream::CreateForRead(n.c_str())}; + std::unique_ptr fi{dmlc::SeekStream::CreateForRead(n.c_str())}; fi->Seek(offset); CHECK_EQ(fi->Tell(), offset); auto page = std::make_shared(); CHECK(fmt->Read(page.get(), fi.get())); + LOG(INFO) << "Read a page in " << timer.ElapsedSeconds() << " seconds."; return page; }); } - CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), - [](auto const &f) { return f.valid(); }), + CHECK_EQ(std::count_if(ring_->cbegin(), ring_->cend(), [](auto const& f) { return f.valid(); }), n_prefetch_batches) << "Sparse DMatrix assumes forward iteration."; page_ = (*ring_)[count_].get(); @@ -146,12 +151,18 @@ class SparsePageSourceImpl : public BatchIteratorImpl { void WriteCache() { CHECK(!cache_info_->written); + common::Timer timer; + timer.Start(); std::unique_ptr> fmt{CreatePageFormat("raw")}; if (!fo_) { auto n = cache_info_->ShardName(); fo_.reset(dmlc::Stream::Create(n.c_str(), "w")); } auto bytes = fmt->Write(*page_, fo_.get()); + timer.Stop(); + + LOG(INFO) << static_cast(bytes) / 1024.0 / 1024.0 << " MB written in " + << timer.ElapsedSeconds() << " seconds."; cache_info_->offset.push_back(bytes); } @@ -280,15 +291,24 @@ template class PageSourceIncMixIn : public SparsePageSourceImpl { protected: std::shared_ptr source_; + using Super = SparsePageSourceImpl; + // synchronize the row page, `hist` and `gpu_hist` don't need the original sparse page + // so we avoid fetching it. 
+ bool sync_{true}; public: - using SparsePageSourceImpl::SparsePageSourceImpl; + PageSourceIncMixIn(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + std::shared_ptr cache, bool sync) + : Super::SparsePageSourceImpl{missing, nthreads, n_features, n_batches, cache}, sync_{sync} {} + PageSourceIncMixIn& operator++() final { TryLockGuard guard{this->single_threaded_}; - ++(*source_); + if (sync_) { + ++(*source_); + } ++this->count_; - this->at_end_ = source_->AtEnd(); + this->at_end_ = this->count_ == this->n_batches_; if (this->at_end_) { this->cache_info_->Commit(); @@ -299,7 +319,10 @@ class PageSourceIncMixIn : public SparsePageSourceImpl { } else { this->Fetch(); } - CHECK_EQ(source_->Iter(), this->count_); + + if (sync_) { + CHECK_EQ(source_->Iter(), this->count_); + } return *this; } }; @@ -318,12 +341,9 @@ class CSCPageSource : public PageSourceIncMixIn { } public: - CSCPageSource( - float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, - std::shared_ptr cache, - std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, - n_batches, cache) { + CSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, + std::shared_ptr cache, std::shared_ptr source) + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) { this->source_ = source; this->Fetch(); } @@ -349,7 +369,7 @@ class SortedCSCPageSource : public PageSourceIncMixIn { SortedCSCPageSource(float missing, int nthreads, bst_feature_t n_features, uint32_t n_batches, std::shared_ptr cache, std::shared_ptr source) - : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache) { + : PageSourceIncMixIn(missing, nthreads, n_features, n_batches, cache, true) { this->source_ = source; this->Fetch(); } diff --git a/src/tree/hist/histogram.h b/src/tree/hist/histogram.h index 242825b25bb5..6020de28d529 100644 --- a/src/tree/hist/histogram.h +++ b/src/tree/hist/histogram.h @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_TREE_HIST_HISTOGRAM_H_ #define XGBOOST_TREE_HIST_HISTOGRAM_H_ @@ -8,10 +8,11 @@ #include #include -#include "rabit/rabit.h" -#include "xgboost/tree_model.h" #include "../../common/hist_util.h" #include "../../data/gradient_index.h" +#include "expand_entry.h" +#include "rabit/rabit.h" +#include "xgboost/tree_model.h" namespace xgboost { namespace tree { @@ -323,6 +324,25 @@ template class HistogramBuilder { (*sync_count) = std::max(1, n_left); } }; + +// Construct a work space for building histogram. Eventually we should move this +// function into histogram builder once hist tree method supports external memory. 
+template +common::BlockedSpace2d ConstructHistSpace(Partitioner const &partitioners, + std::vector const &nodes_to_build) { + std::vector partition_size(nodes_to_build.size(), 0); + for (auto const &partition : partitioners) { + size_t k = 0; + for (auto node : nodes_to_build) { + auto n_rows_in_node = partition.Partitions()[node.nid].Size(); + partition_size[k] = std::max(partition_size[k], n_rows_in_node); + k++; + } + } + common::BlockedSpace2d space{ + nodes_to_build.size(), [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, 256}; + return space; +} } // namespace tree } // namespace xgboost #endif // XGBOOST_TREE_HIST_HISTOGRAM_H_ diff --git a/src/tree/updater_approx.cc b/src/tree/updater_approx.cc index 4def6940ddc5..55ec4bf93502 100644 --- a/src/tree/updater_approx.cc +++ b/src/tree/updater_approx.cc @@ -94,7 +94,7 @@ class GloablApproxBuilder { rabit::Allreduce(reinterpret_cast(&root_sum), 2); std::vector nodes{best}; size_t i = 0; - auto space = this->ConstructHistSpace(nodes); + auto space = ConstructHistSpace(partitioner_, nodes); for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess))) { histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes, {}, gpair); @@ -145,25 +145,6 @@ class GloablApproxBuilder { monitor_->Stop(__func__); } - // Construct a work space for building histogram. Eventually we should move this - // function into histogram builder once hist tree method supports external memory. - common::BlockedSpace2d ConstructHistSpace( - std::vector const &nodes_to_build) const { - std::vector partition_size(nodes_to_build.size(), 0); - for (auto const &partition : partitioner_) { - size_t k = 0; - for (auto node : nodes_to_build) { - auto n_rows_in_node = partition.Partitions()[node.nid].Size(); - partition_size[k] = std::max(partition_size[k], n_rows_in_node); - k++; - } - } - common::BlockedSpace2d space{nodes_to_build.size(), - [&](size_t nidx_in_set) { return partition_size[nidx_in_set]; }, - 256}; - return space; - } - void BuildHistogram(DMatrix *p_fmat, RegTree *p_tree, std::vector const &valid_candidates, std::vector const &gpair, common::Span hess) { @@ -186,7 +167,7 @@ class GloablApproxBuilder { } size_t i = 0; - auto space = this->ConstructHistSpace(nodes_to_build); + auto space = ConstructHistSpace(partitioner_, nodes_to_build); for (auto const &page : p_fmat->GetBatches(BatchSpec(param_, hess))) { histogram_builder_.BuildHist(i, space, page, p_tree, partitioner_.at(i).Partitions(), nodes_to_build, nodes_to_sub, gpair); diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc index 8c52ff382455..0f1cdc04c8e2 100644 --- a/src/tree/updater_quantile_hist.cc +++ b/src/tree/updater_quantile_hist.cc @@ -61,23 +61,17 @@ void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr &trees) { for (auto tree : trees) { - builder->Update(gmat, column_matrix_, gpair, dmat, tree); + builder->Update(gmat, gpair, dmat, tree); } } void QuantileHistMaker::Update(HostDeviceVector *gpair, DMatrix *dmat, const std::vector &trees) { - auto it = dmat->GetBatches(HistBatch(param_)).begin(); + auto it = dmat->GetBatches( + BatchParam{param_.max_bin, param_.sparse_threshold}) + .begin(); auto p_gmat = it.Page(); - if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) { - updater_monitor_.Start("GmatInitialization"); - column_matrix_.Init(*p_gmat, param_.sparse_threshold, ctx_->Threads()); - updater_monitor_.Stop("GmatInitialization"); - // A proper solution is puting cut matrix in DMatrix, 
see: - // https://github.com/dmlc/xgboost/issues/5143 - is_gmat_initialized_ = true; - } // rescale learning rate according to size of trees float lr = param_.learning_rate; param_.learning_rate = lr / trees.size(); @@ -114,48 +108,29 @@ bool QuantileHistMaker::UpdatePredictionCache( template -template void QuantileHistMaker::Builder::InitRoot( DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h, int *num_leaves, std::vector *expand) { CPUExpandEntry node(RegTree::kRoot, p_tree->GetDepth(0), 0.0f); - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); - nodes_for_explicit_hist_build_.push_back(node); - size_t page_id = 0; - for (auto const& gidx : - p_fmat->GetBatches(HistBatch(param_))) { - this->histogram_builder_->BuildHist( - page_id, gidx, p_tree, row_set_collection_, - nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, gpair_h); + auto space = ConstructHistSpace(row_partitioner_, {node}); + for (auto const &gidx : + p_fmat->GetBatches({param_.max_bin, param_.sparse_threshold})) { + std::vector nodes_to_build{node}; + std::vector nodes_to_sub; + this->histogram_builder_->BuildHist(page_id, space, gidx, p_tree, + row_partitioner_.at(page_id).Partitions(), nodes_to_build, + nodes_to_sub, gpair_h); ++page_id; } { - auto nid = RegTree::kRoot; - GHistRowT hist = this->histogram_builder_->Histogram()[nid]; GradientPairT grad_stat; - if (data_layout_ == DataLayout::kDenseDataZeroBased || - data_layout_ == DataLayout::kDenseDataOneBased) { - auto const& gmat = *(p_fmat->GetBatches(HistBatch(param_)).begin()); - const std::vector &row_ptr = gmat.cut.Ptrs(); - const uint32_t ibegin = row_ptr[fid_least_bins_]; - const uint32_t iend = row_ptr[fid_least_bins_ + 1]; - auto begin = hist.data(); - for (uint32_t i = ibegin; i < iend; ++i) { - const GradientPairT et = begin[i]; - grad_stat.Add(et.GetGrad(), et.GetHess()); - } - } else { - const RowSetCollection::Elem e = row_set_collection_[nid]; - for (const size_t *it = e.begin; it < e.end; ++it) { - grad_stat.Add(gpair_h[*it].GetGrad(), gpair_h[*it].GetHess()); - } - rabit::Allreduce( - reinterpret_cast(&grad_stat), 2); + for (auto const &gpair : gpair_h) { + grad_stat.Add(gpair.GetGrad(), gpair.GetHess()); } + rabit::Allreduce(reinterpret_cast(&grad_stat), 2); auto weight = evaluator_->InitRoot(GradStats{grad_stat}); p_tree->Stat(RegTree::kRoot).sum_hess = grad_stat.GetHess(); @@ -165,7 +140,8 @@ void QuantileHistMaker::Builder::InitRoot( std::vector entries{node}; builder_monitor_.Start("EvaluateSplits"); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); - for (auto const& gmat : p_fmat->GetBatches(HistBatch(param_))) { + for (auto const &gmat : p_fmat->GetBatches( + BatchParam{param_.max_bin, param_.sparse_threshold})) { evaluator_->EvaluateSplits(histogram_builder_->Histogram(), gmat.cut, ft, *p_tree, &entries); break; @@ -178,124 +154,91 @@ void QuantileHistMaker::Builder::InitRoot( ++(*num_leaves); } -template -void QuantileHistMaker::Builder::AddSplitsToTree( - const std::vector& expand, - RegTree *p_tree, - int *num_leaves, - std::vector* nodes_for_apply_split) { - for (auto const& entry : expand) { - if (entry.IsValid(param_, *num_leaves)) { - nodes_for_apply_split->push_back(entry); - evaluator_->ApplyTreeSplit(entry, p_tree); - (*num_leaves)++; - } - } -} - -// Split nodes to 2 sets depending on amount of rows in each node -// Histograms for small nodes will be built explicitly -// Histograms for big nodes will be built by 'Subtraction Trick' -// Exception: in distributed setting, we 
always build the histogram for the left child node -// and use 'Subtraction Trick' to built the histogram for the right child node. -// This ensures that the workers operate on the same set of tree nodes. template -void QuantileHistMaker::Builder::SplitSiblings( - const std::vector &nodes_for_apply_split, - std::vector *nodes_to_evaluate, RegTree *p_tree) { - builder_monitor_.Start("SplitSiblings"); - for (auto const& entry : nodes_for_apply_split) { - int nid = entry.nid; - - const int cleft = (*p_tree)[nid].LeftChild(); - const int cright = (*p_tree)[nid].RightChild(); - const CPUExpandEntry left_node = CPUExpandEntry(cleft, p_tree->GetDepth(cleft), 0.0); - const CPUExpandEntry right_node = CPUExpandEntry(cright, p_tree->GetDepth(cright), 0.0); - nodes_to_evaluate->push_back(left_node); - nodes_to_evaluate->push_back(right_node); - if (row_set_collection_[cleft].Size() < row_set_collection_[cright].Size()) { - nodes_for_explicit_hist_build_.push_back(left_node); - nodes_for_subtraction_trick_.push_back(right_node); - } else { - nodes_for_explicit_hist_build_.push_back(right_node); - nodes_for_subtraction_trick_.push_back(left_node); - } - } - CHECK_EQ(nodes_for_subtraction_trick_.size(), nodes_for_explicit_hist_build_.size()); - builder_monitor_.Stop("SplitSiblings"); -} - -template -template void QuantileHistMaker::Builder::ExpandTree( - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - DMatrix* p_fmat, - RegTree* p_tree, - const std::vector& gpair_h) { + DMatrix *p_fmat, RegTree *p_tree, const std::vector &gpair_h) { builder_monitor_.Start("ExpandTree"); int num_leaves = 0; Driver driver(static_cast(param_.grow_policy)); - std::vector expand; - InitRoot(p_fmat, p_tree, gpair_h, &num_leaves, &expand); - driver.Push(expand[0]); - - int32_t depth = 0; - while (!driver.IsEmpty()) { - expand = driver.Pop(); - depth = expand[0].depth + 1; - std::vector nodes_for_apply_split; - std::vector nodes_to_evaluate; - nodes_for_explicit_hist_build_.clear(); - nodes_for_subtraction_trick_.clear(); - - AddSplitsToTree(expand, p_tree, &num_leaves, &nodes_for_apply_split); - - if (nodes_for_apply_split.size() != 0) { - ApplySplit(nodes_for_apply_split, gmat, column_matrix, p_tree); - SplitSiblings(nodes_for_apply_split, &nodes_to_evaluate, p_tree); - - if (param_.max_depth == 0 || depth < param_.max_depth) { - size_t i = 0; - for (auto const& gidx : p_fmat->GetBatches(HistBatch(param_))) { - this->histogram_builder_->BuildHist( - i, gidx, p_tree, row_set_collection_, - nodes_for_explicit_hist_build_, nodes_for_subtraction_trick_, - gpair_h); - ++i; - } - } else { - int starting_index = std::numeric_limits::max(); - int sync_count = 0; - this->histogram_builder_->AddHistRows( - &starting_index, &sync_count, nodes_for_explicit_hist_build_, - nodes_for_subtraction_trick_, p_tree); + std::vector expand_set; + InitRoot(p_fmat, p_tree, gpair_h, &num_leaves, &expand_set); + driver.Push(expand_set[0]); + expand_set = driver.Pop(); + + std::vector best_splits; + // candidates that can be further splited. + std::vector valid_candidates; + // candidaates that can be applied. + std::vector applied; + + std::vector nodes_to_build; + // subtraction trick. 
+ std::vector nodes_to_sub; + + while (!expand_set.empty()) { + int32_t depth = expand_set.front().depth + 1; + for (auto const& candidate : expand_set) { + if (!candidate.IsValid(param_, num_leaves)) { + continue; } + evaluator_->ApplyTreeSplit(candidate, p_tree); + applied.push_back(candidate); + num_leaves++; - builder_monitor_.Start("EvaluateSplits"); + if (CPUExpandEntry::ChildIsValid(param_, depth, num_leaves)) { + valid_candidates.emplace_back(candidate); + } + } + size_t page_id{0}; + for (auto const &page : + p_fmat->GetBatches({param_.max_bin, param_.sparse_threshold})) { + auto const &column_matrix = page.Transpose(); + auto &part = this->row_partitioner_.at(page_id); + if (column_matrix.AnyMissing()) { + part.template UpdatePosition(ctx_->Threads(), page, column_matrix, applied, p_tree); + } else { + part.template UpdatePosition(ctx_->Threads(), page, column_matrix, applied, p_tree); + } + ++page_id; + } + applied.clear(); + + auto& tree = *p_tree; + if (!valid_candidates.empty()) { + this->BuildHistogram(p_fmat, p_tree, valid_candidates, &nodes_to_build, &nodes_to_sub, + gpair_h); + for (auto const &candidate : valid_candidates) { + int left_child_nidx = tree[candidate.nid].LeftChild(); + int right_child_nidx = tree[candidate.nid].RightChild(); + CPUExpandEntry l_best{left_child_nidx, depth, 0.0}; + CPUExpandEntry r_best{right_child_nidx, depth, 0.0}; + best_splits.push_back(l_best); + best_splits.push_back(r_best); + } + auto const &histograms = histogram_builder_->Histogram(); auto ft = p_fmat->Info().feature_types.ConstHostSpan(); - evaluator_->EvaluateSplits(this->histogram_builder_->Histogram(), - gmat.cut, ft, *p_tree, &nodes_to_evaluate); - builder_monitor_.Stop("EvaluateSplits"); - - for (size_t i = 0; i < nodes_for_apply_split.size(); ++i) { - CPUExpandEntry left_node = nodes_to_evaluate.at(i * 2 + 0); - CPUExpandEntry right_node = nodes_to_evaluate.at(i * 2 + 1); - driver.Push(left_node); - driver.Push(right_node); + for (auto const &gmat : + p_fmat->GetBatches({param_.max_bin, param_.sparse_threshold})) { + evaluator_->EvaluateSplits(histograms, gmat.cut, ft, *p_tree, &best_splits); + break; } } + valid_candidates.clear(); + + driver.Push(best_splits.begin(), best_splits.end()); + + best_splits.clear(); + expand_set = driver.Pop(); } + builder_monitor_.Stop("ExpandTree"); } template -void QuantileHistMaker::Builder::Update( - const GHistIndexMatrix &gmat, - const ColumnMatrix &column_matrix, - HostDeviceVector *gpair, - DMatrix *p_fmat, RegTree *p_tree) { +void QuantileHistMaker::Builder::Update(const GHistIndexMatrix &gmat, + HostDeviceVector *gpair, + DMatrix *p_fmat, RegTree *p_tree) { builder_monitor_.Start("Update"); std::vector* gpair_ptr = &(gpair->HostVector()); @@ -307,13 +250,10 @@ void QuantileHistMaker::Builder::Update( } p_last_fmat_mutable_ = p_fmat; - this->InitData(gmat, *p_fmat, *p_tree, gpair_ptr); + this->InitData(gmat, p_fmat, *p_tree, gpair_ptr); + + ExpandTree(p_fmat, p_tree, *gpair_ptr); - if (column_matrix.AnyMissing()) { - ExpandTree(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr); - } else { - ExpandTree(gmat, column_matrix, p_fmat, p_tree, *gpair_ptr); - } pruner_->Update(gpair, p_fmat, std::vector{p_tree}); builder_monitor_.Stop("Update"); @@ -333,41 +273,33 @@ bool QuantileHistMaker::Builder::UpdatePredictionCache( CHECK_GT(out_preds.Size(), 0U); - size_t n_nodes = row_set_collection_.end() - row_set_collection_.begin(); - - common::BlockedSpace2d space(n_nodes, [&](size_t node) { - return row_set_collection_[node].Size(); - }, 1024); - 
CHECK_EQ(out_preds.DeviceIdx(), GenericParameter::kCpuId); - common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node, common::Range1d r) { - const RowSetCollection::Elem rowset = row_set_collection_[node]; - if (rowset.begin != nullptr && rowset.end != nullptr) { - int nid = rowset.node_id; - bst_float leaf_value; - // if a node is marked as deleted by the pruner, traverse upward to locate - // a non-deleted leaf. - if ((*p_last_tree_)[nid].IsDeleted()) { - while ((*p_last_tree_)[nid].IsDeleted()) { - nid = (*p_last_tree_)[nid].Parent(); + size_t n_nodes = p_last_tree_->GetNodes().size(); + auto evaluator = evaluator_->Evaluator(); + auto const &tree = *p_last_tree_; + auto const &snode = evaluator_->Stats(); + for (auto &part : row_partitioner_) { + CHECK_EQ(part.Size(), n_nodes); + common::BlockedSpace2d space( + part.Size(), [&](size_t node) { return part[node].Size(); }, 1024); + common::ParallelFor2d(space, ctx_->Threads(), [&](size_t nidx, common::Range1d r) { + if (!tree[nidx].IsDeleted() && tree[nidx].IsLeaf()) { + const auto rowset = part[nidx]; + auto const &stats = snode.at(nidx); + auto leaf_value = + evaluator.CalcWeight(nidx, param_, GradStats{stats.stats}) * param_.learning_rate; + for (const size_t *it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { + out_preds(*it) += leaf_value; } - CHECK((*p_last_tree_)[nid].IsLeaf()); - } - leaf_value = (*p_last_tree_)[nid].LeafValue(); - - for (const size_t* it = rowset.begin + r.begin(); it < rowset.begin + r.end(); ++it) { - out_preds(*it) += leaf_value; } - } - }); - + }); + } builder_monitor_.Stop("UpdatePredictionCache"); return true; } -template -void QuantileHistMaker::Builder::InitSampling(const DMatrix& fmat, - std::vector* gpair, - std::vector* row_indices) { +template +void QuantileHistMaker::Builder::InitSampling(const DMatrix &fmat, + std::vector *gpair) { const auto& info = fmat.Info(); auto& rnd = common::GlobalRandom(); std::vector& gpair_ref = *gpair; @@ -407,9 +339,9 @@ size_t QuantileHistMaker::Builder::GetNumberOfTrees() { } template -void QuantileHistMaker::Builder::InitData( - const GHistIndexMatrix &gmat, const DMatrix &fmat, const RegTree &tree, - std::vector *gpair) { +void QuantileHistMaker::Builder::InitData(const GHistIndexMatrix &gmat, DMatrix *fmat, + const RegTree &tree, + std::vector *gpair) { CHECK((param_.max_depth > 0 || param_.max_leaves > 0)) << "max_depth or max_leaves cannot be both 0 (unlimited); " << "at least one should be a positive quantity."; @@ -418,32 +350,37 @@ void QuantileHistMaker::Builder::InitData( << "when grow_policy is depthwise."; } builder_monitor_.Start("InitData"); - const auto& info = fmat.Info(); + const auto& info = fmat->Info(); { - // initialize the row set - row_set_collection_.Clear(); // initialize histogram collection - uint32_t nbins = gmat.cut.Ptrs().back(); // initialize histogram builder dmlc::OMPException exc; exc.Rethrow(); - this->histogram_builder_->Reset(nbins, BatchParam{GenericParameter::kCpuId, param_.max_bin}, - this->ctx_->Threads(), 1, rabit::IsDistributed()); - - std::vector& row_indices = *row_set_collection_.Data(); - row_indices.resize(info.num_row_); - size_t* p_row_indices = row_indices.data(); - // mark subsample and build list of member rows + size_t page_id{0}; + int32_t n_total_bins{0}; + row_partitioner_.clear(); + for (auto const &page : + fmat->GetBatches({param_.max_bin, param_.sparse_threshold})) { + if (n_total_bins == 0) { + n_total_bins = page.cut.TotalBins(); + } else { + CHECK_EQ(n_total_bins, 
page.cut.TotalBins()); + } + row_partitioner_.emplace_back(); + row_partitioner_.back().Init(page.Size(), page.base_rowid, this->ctx_->Threads()); + ++page_id; + } + histogram_builder_->Reset(n_total_bins, BatchParam{param_.max_bin, param_.sparse_threshold}, + ctx_->Threads(), page_id, rabit::IsDistributed()); if (param_.subsample < 1.0f) { CHECK_EQ(param_.sampling_method, TrainParam::kUniform) << "Only uniform sampling is supported, " << "gradient-based sampling is only support by GPU Hist."; builder_monitor_.Start("InitSampling"); - InitSampling(fmat, gpair, &row_indices); + InitSampling(*fmat, gpair); builder_monitor_.Stop("InitSampling"); - CHECK_EQ(row_indices.size(), info.num_row_); // We should check that the partitioning was done correctly // and each row of the dataset fell into exactly one of the categories } @@ -471,41 +408,8 @@ void QuantileHistMaker::Builder::InitData( }); } exc.Rethrow(); - - bool has_neg_hess = false; - for (int32_t tid = 0; tid < n_threads; ++tid) { - if (p_buff[tid]) { - has_neg_hess = true; - } - } - - if (has_neg_hess) { - size_t j = 0; - for (size_t i = 0; i < info.num_row_; ++i) { - if ((*gpair)[i].GetHess() >= 0.0f) { - p_row_indices[j++] = i; - } - } - row_indices.resize(j); - } else { - #pragma omp parallel num_threads(n_threads) - { - exc.Run([&]() { - const size_t tid = omp_get_thread_num(); - const size_t ibegin = tid * block_size; - const size_t iend = std::min(static_cast(ibegin + block_size), - static_cast(info.num_row_)); - for (size_t i = ibegin; i < iend; ++i) { - p_row_indices[i] = i; - } - }); - } - exc.Rethrow(); - } } - row_set_collection_.Init(); - { /* determine layout of data */ const size_t nrow = info.num_row_; @@ -513,7 +417,7 @@ void QuantileHistMaker::Builder::InitData( const size_t nnz = info.num_nonzero_; // number of discrete bins for feature 0 const uint32_t nbins_f0 = gmat.cut.Ptrs()[1] - gmat.cut.Ptrs()[0]; - if (nrow * ncol == nnz) { + if (fmat->IsDense()) { // dense data with zero-based indexing data_layout_ = DataLayout::kDenseDataZeroBased; } else if (nbins_f0 == 0 && nrow * (ncol - 1) == nnz) { @@ -534,140 +438,9 @@ void QuantileHistMaker::Builder::InitData( param_, info, this->ctx_->Threads(), column_sampler_, task_, false}); } - if (data_layout_ == DataLayout::kDenseDataZeroBased - || data_layout_ == DataLayout::kDenseDataOneBased) { - /* specialized code for dense data: - choose the column that has a least positive number of discrete bins. 
- For dense data (with no missing value), - the sum of gradient histogram is equal to snode[nid] */ - const std::vector& row_ptr = gmat.cut.Ptrs(); - const auto nfeature = static_cast(row_ptr.size() - 1); - uint32_t min_nbins_per_feature = 0; - for (bst_uint i = 0; i < nfeature; ++i) { - const uint32_t nbins = row_ptr[i + 1] - row_ptr[i]; - if (nbins > 0) { - if (min_nbins_per_feature == 0 || min_nbins_per_feature > nbins) { - min_nbins_per_feature = nbins; - fid_least_bins_ = i; - } - } - } - CHECK_GT(min_nbins_per_feature, 0U); - } - builder_monitor_.Stop("InitData"); } -template -void QuantileHistMaker::Builder::FindSplitConditions( - const std::vector& nodes, - const RegTree& tree, - const GHistIndexMatrix& gmat, - std::vector* split_conditions) { - const size_t n_nodes = nodes.size(); - split_conditions->resize(n_nodes); - - for (size_t i = 0; i < nodes.size(); ++i) { - const int32_t nid = nodes[i].nid; - const bst_uint fid = tree[nid].SplitIndex(); - const bst_float split_pt = tree[nid].SplitCond(); - const uint32_t lower_bound = gmat.cut.Ptrs()[fid]; - const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1]; - int32_t split_cond = -1; - // convert floating-point split_pt into corresponding bin_id - // split_cond = -1 indicates that split_pt is less than all known cut points - CHECK_LT(upper_bound, - static_cast(std::numeric_limits::max())); - for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) { - if (split_pt == gmat.cut.Values()[bound]) { - split_cond = static_cast(bound); - } - } - (*split_conditions)[i] = split_cond; - } -} -template -void QuantileHistMaker::Builder::AddSplitsToRowSet( - const std::vector& nodes, - RegTree* p_tree) { - const size_t n_nodes = nodes.size(); - for (unsigned int i = 0; i < n_nodes; ++i) { - const int32_t nid = nodes[i].nid; - const size_t n_left = partition_builder_.GetNLeftElems(i); - const size_t n_right = partition_builder_.GetNRightElems(i); - CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild()); - row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), - (*p_tree)[nid].RightChild(), n_left, n_right); - } -} - -template -template -void QuantileHistMaker::Builder::ApplySplit(const std::vector nodes, - const GHistIndexMatrix& gmat, - const ColumnMatrix& column_matrix, - RegTree* p_tree) { - builder_monitor_.Start("ApplySplit"); - // 1. 
-  const size_t n_nodes = nodes.size();
-  std::vector<int32_t> split_conditions;
-  FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
-  // 2.1 Create a blocked space of size SUM(samples in each node)
-  common::BlockedSpace2d space(n_nodes, [&](size_t node_in_set) {
-    int32_t nid = nodes[node_in_set].nid;
-    return row_set_collection_[nid].Size();
-  }, kPartitionBlockSize);
-  // 2.2 Initialize the partition builder
-  // allocate buffers for storage intermediate results by each thread
-  partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
-    const int32_t nid = nodes[node_in_set].nid;
-    const size_t size = row_set_collection_[nid].Size();
-    const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
-    return n_tasks;
-  });
-  // 2.3 Split elements of row_set_collection_ to left and right child-nodes for each node
-  // Store results in intermediate buffers from partition_builder_
-  common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
-    size_t begin = r.begin();
-    const int32_t nid = nodes[node_in_set].nid;
-    const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
-    partition_builder_.AllocateForTask(task_id);
-    switch (column_matrix.GetTypeSize()) {
-      case common::kUint8BinsTypeSize:
-        partition_builder_.template Partition<uint8_t, any_missing>(node_in_set, nid, r,
-                          split_conditions[node_in_set], column_matrix,
-                          *p_tree, row_set_collection_[nid].begin);
-        break;
-      case common::kUint16BinsTypeSize:
-        partition_builder_.template Partition<uint16_t, any_missing>(node_in_set, nid, r,
-                          split_conditions[node_in_set], column_matrix,
-                          *p_tree, row_set_collection_[nid].begin);
-        break;
-      case common::kUint32BinsTypeSize:
-        partition_builder_.template Partition<uint32_t, any_missing>(node_in_set, nid, r,
-                          split_conditions[node_in_set], column_matrix,
-                          *p_tree, row_set_collection_[nid].begin);
-        break;
-      default:
-        CHECK(false);  // no default behavior
-    }
-  });
-  // 3. Compute offsets to copy blocks of row-indexes
-  // from partition_builder_ to row_set_collection_
-  partition_builder_.CalculateRowOffsets();
-
-  // 4. Copy elements from partition_builder_ to row_set_collection_ back
-  // with updated row-indexes for each tree-node
-  common::ParallelFor2d(space, this->ctx_->Threads(), [&](size_t node_in_set, common::Range1d r) {
-    const int32_t nid = nodes[node_in_set].nid;
-    partition_builder_.MergeToArray(node_in_set, r.begin(),
-                                    const_cast<size_t*>(row_set_collection_[nid].begin));
-  });
-  // 5. Add info about splits into row_set_collection_
-  AddSplitsToRowSet(nodes, p_tree);
-  builder_monitor_.Stop("ApplySplit");
-}
-
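
The n_tasks expression in the removed ApplySplit (kept verbatim by UpdatePosition below) is a ceiling division: one task per started kPartitionBlockSize block of a node's rows. A compile-time check of that identity; the names here are illustrative only:

#include <cstddef>

constexpr std::size_t kBlock = 2048;  // stands in for kPartitionBlockSize

constexpr std::size_t NumTasks(std::size_t size) {
  return size / kBlock + !!(size % kBlock);  // equals ceil(size / (double)kBlock)
}

static_assert(NumTasks(1) == 1, "a partial block still costs one task");
static_assert(NumTasks(2 * kBlock) == 2, "exact multiples need no extra task");
static_assert(NumTasks(2 * kBlock + 1) == 3, "one spilled row starts a new block");
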
template struct QuantileHistMaker::Builder<float>;
template struct QuantileHistMaker::Builder<double>;

diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 3f2b07ff972c..32fda2e06865 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -11,9 +11,9 @@
 #include
 #include
-#include
+#include
+#include
 #include
-#include
 #include
 #include
 #include
@@ -83,6 +83,148 @@ struct RandomReplace {
 namespace tree {

+class HistRowPartitioner {
+  static constexpr size_t kPartitionBlockSize = 2048;
+  common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
+  common::RowSetCollection row_set_collection_;
+
+  void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
+                           const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions) {
+    const size_t n_nodes = nodes.size();
+    split_conditions->resize(n_nodes);
+
+    for (size_t i = 0; i < nodes.size(); ++i) {
+      const int32_t nid = nodes[i].nid;
+      const bst_uint fid = tree[nid].SplitIndex();
+      const bst_float split_pt = tree[nid].SplitCond();
+      const uint32_t lower_bound = gmat.cut.Ptrs()[fid];
+      const uint32_t upper_bound = gmat.cut.Ptrs()[fid + 1];
+      int32_t split_cond = -1;
+      // convert floating-point split_pt into corresponding bin_id
+      // split_cond = -1 indicates that split_pt is less than all known cut points
+      CHECK_LT(upper_bound, static_cast<uint32_t>(std::numeric_limits<int32_t>::max()));
+      for (uint32_t bound = lower_bound; bound < upper_bound; ++bound) {
+        if (split_pt == gmat.cut.Values()[bound]) {
+          split_cond = static_cast<int32_t>(bound);
+        }
+      }
+      (*split_conditions)[i] = split_cond;
+    }
+  }
+
+  void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree const* p_tree) {
+    const size_t n_nodes = nodes.size();
+    for (unsigned int i = 0; i < n_nodes; ++i) {
+      const int32_t nid = nodes[i].nid;
+      const size_t n_left = partition_builder_.GetNLeftElems(i);
+      const size_t n_right = partition_builder_.GetNRightElems(i);
+      CHECK_EQ((*p_tree)[nid].LeftChild() + 1, (*p_tree)[nid].RightChild());
+      row_set_collection_.AddSplit(nid, (*p_tree)[nid].LeftChild(), (*p_tree)[nid].RightChild(),
+                                   n_left, n_right);
+    }
+  }
+
+ public:
+  bst_row_t base_rowid = 0;
+
+ public:
+  void Init(size_t n_samples, size_t base_rowid, int32_t n_threads) {
+    row_set_collection_.Clear();
+    const size_t block_size = n_samples / n_threads + !!(n_samples % n_threads);
+    dmlc::OMPException exc;
+    std::vector<size_t>& row_indices = *row_set_collection_.Data();
+    row_indices.resize(n_samples);
+    size_t* p_row_indices = row_indices.data();
+#pragma omp parallel num_threads(n_threads)
+    {
+      exc.Run([&]() {
+        const size_t tid = omp_get_thread_num();
+        const size_t ibegin = tid * block_size;
+        const size_t iend = std::min(static_cast<size_t>(ibegin + block_size), n_samples);
+        for (size_t i = ibegin; i < iend; ++i) {
+          p_row_indices[i] = i + base_rowid;
+        }
+      });
+    }
+    row_set_collection_.Init();
+    this->base_rowid = base_rowid;
+  }
+
+  template <bool any_missing>
+  void UpdatePosition(int32_t n_threads, GHistIndexMatrix const& gmat,
+                      common::ColumnMatrix const& column_matrix,
+                      std::vector<CPUExpandEntry> const& nodes, RegTree const* p_tree) {
+    // 1. Find split condition for each split
+    const size_t n_nodes = nodes.size();
+    std::vector<int32_t> split_conditions;
+    FindSplitConditions(nodes, *p_tree, gmat, &split_conditions);
+    // 2.1 Create a blocked space of size SUM(samples in each node)
+    common::BlockedSpace2d space(
+        n_nodes,
+        [&](size_t node_in_set) {
+          int32_t nid = nodes[node_in_set].nid;
+          return row_set_collection_[nid].Size();
+        },
+        kPartitionBlockSize);
+    // 2.2 Initialize the partition builder
+    // allocate buffers for storing intermediate results by each thread
+    partition_builder_.Init(space.Size(), n_nodes, [&](size_t node_in_set) {
+      const int32_t nid = nodes[node_in_set].nid;
+      const size_t size = row_set_collection_[nid].Size();
+      const size_t n_tasks = size / kPartitionBlockSize + !!(size % kPartitionBlockSize);
+      return n_tasks;
+    });
+    // 2.3 Split elements of row_set_collection_ into left and right child-nodes for each node
+    // Store results in intermediate buffers from partition_builder_
+    common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
+      size_t begin = r.begin();
+      const int32_t nid = nodes[node_in_set].nid;
+      const size_t task_id = partition_builder_.GetTaskIdx(node_in_set, begin);
+      partition_builder_.AllocateForTask(task_id);
+      switch (column_matrix.GetTypeSize()) {
+        case common::kUint8BinsTypeSize:
+          partition_builder_.template Partition<uint8_t, any_missing>(
+              node_in_set, nid, r, split_conditions[node_in_set], column_matrix, *p_tree,
+              row_set_collection_[nid].begin, base_rowid);
+          break;
+        case common::kUint16BinsTypeSize:
+          partition_builder_.template Partition<uint16_t, any_missing>(
+              node_in_set, nid, r, split_conditions[node_in_set], column_matrix, *p_tree,
+              row_set_collection_[nid].begin, base_rowid);
+          break;
+        case common::kUint32BinsTypeSize:
+          partition_builder_.template Partition<uint32_t, any_missing>(
+              node_in_set, nid, r, split_conditions[node_in_set], column_matrix, *p_tree,
+              row_set_collection_[nid].begin, base_rowid);
+          break;
+        default:
+          // no default behavior
+          CHECK(false) << column_matrix.GetTypeSize();
+      }
+    });
+    // 3. Compute offsets to copy blocks of row-indexes
+    // from partition_builder_ to row_set_collection_
+    partition_builder_.CalculateRowOffsets();
+
+    // 4. Copy elements from partition_builder_ back to row_set_collection_
+    // with updated row-indexes for each tree-node
+    common::ParallelFor2d(space, n_threads, [&](size_t node_in_set, common::Range1d r) {
+      const int32_t nid = nodes[node_in_set].nid;
+      partition_builder_.MergeToArray(node_in_set, r.begin(),
+                                      const_cast<size_t*>(row_set_collection_[nid].begin));
+    });
+    // 5. Add info about splits into row_set_collection_
+    AddSplitsToRowSet(nodes, p_tree);
+  }
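
UpdatePosition dispatches on ColumnMatrix::GetTypeSize() because bin ids are stored in the narrowest unsigned integer that can hold the largest bin index, so Partition is instantiated once per width. A sketch of the selection rule this implies; the enum values mirror common::BinTypeSize as used in the diff, but the real choice is made inside ColumnMatrix and may differ in detail:

#include <cstdint>
#include <limits>

enum BinTypeSize : uint8_t {
  kUint8BinsTypeSize = 1,   // bin ids fit in one byte
  kUint16BinsTypeSize = 2,
  kUint32BinsTypeSize = 4,
};

inline BinTypeSize ChooseBinTypeSize(uint32_t max_bin_id) {
  if (max_bin_id <= std::numeric_limits<uint8_t>::max()) return kUint8BinsTypeSize;
  if (max_bin_id <= std::numeric_limits<uint16_t>::max()) return kUint16BinsTypeSize;
  return kUint32BinsTypeSize;
}
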
+
+  auto const& Partitions() const { return row_set_collection_; }
+  size_t Size() const {
+    return std::distance(row_set_collection_.begin(), row_set_collection_.end());
+  }
+  auto& operator[](bst_node_t nidx) { return row_set_collection_[nidx]; }
+  auto const& operator[](bst_node_t nidx) const { return row_set_collection_[nidx]; }
+};
+
 using xgboost::GHistIndexMatrix;
 using xgboost::common::GHistIndexRow;
 using xgboost::common::HistCollection;
@@ -146,10 +288,7 @@ class QuantileHistMaker: public TreeUpdater {
  CPUHistMakerTrainParam hist_maker_param_;
  // training parameter
  TrainParam param_;
-  // column accessor
-  ColumnMatrix column_matrix_;
  DMatrix const* p_last_dmat_ {nullptr};
-  bool is_gmat_initialized_ {false};
  // actual builder that runs the algorithm
  template <typename GradientSumT>
@@ -172,61 +311,62 @@ class QuantileHistMaker: public TreeUpdater {
      builder_monitor_.Init("Quantile::Builder");
    }
    // update one tree, growing
-    void Update(const GHistIndexMatrix& gmat, const ColumnMatrix& column_matrix,
-                HostDeviceVector<GradientPair>* gpair, DMatrix* p_fmat, RegTree* p_tree);
+    void Update(const GHistIndexMatrix& gmat, HostDeviceVector<GradientPair>* gpair,
+                DMatrix* p_fmat, RegTree* p_tree);
    bool UpdatePredictionCache(const DMatrix* data, linalg::VectorView<float> out_preds);
   protected:
    // initialize temp data structure
-    void InitData(const GHistIndexMatrix& gmat,
-                  const DMatrix& fmat,
-                  const RegTree& tree,
+    void InitData(const GHistIndexMatrix& gmat, DMatrix* fmat, const RegTree& tree,
                  std::vector<GradientPair>* gpair);
    size_t GetNumberOfTrees();
-    void InitSampling(const DMatrix& fmat,
-                      std::vector<GradientPair>* gpair,
-                      std::vector<size_t>* row_indices);
-
-    template <bool any_missing>
-    void ApplySplit(std::vector<CPUExpandEntry> nodes,
-                    const GHistIndexMatrix& gmat,
-                    const ColumnMatrix& column_matrix,
-                    RegTree* p_tree);
-
-    void AddSplitsToRowSet(const std::vector<CPUExpandEntry>& nodes, RegTree* p_tree);
-
-
-    void FindSplitConditions(const std::vector<CPUExpandEntry>& nodes, const RegTree& tree,
-                             const GHistIndexMatrix& gmat, std::vector<int32_t>* split_conditions);
-
-    template <bool any_missing>
-    void InitRoot(DMatrix* p_fmat,
-                  RegTree *p_tree,
-                  const std::vector<GradientPair> &gpair_h,
-                  int *num_leaves, std::vector<CPUExpandEntry> *expand);
-
-    // Split nodes to 2 sets depending on amount of rows in each node
-    // Histograms for small nodes will be built explicitly
-    // Histograms for big nodes will be built by 'Subtraction Trick'
-    void SplitSiblings(const std::vector<CPUExpandEntry>& nodes,
-                       std::vector<CPUExpandEntry>* nodes_to_evaluate,
-                       RegTree *p_tree);
+    void InitSampling(const DMatrix& fmat, std::vector<GradientPair>* gpair);
+
+    void InitRoot(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h,
+                  int* num_leaves, std::vector<CPUExpandEntry>* expand);
+
+    void BuildHistogram(DMatrix* p_fmat, RegTree* p_tree,
+                        std::vector<CPUExpandEntry> const& valid_candidates,
+                        std::vector<CPUExpandEntry>* p_to_build,
+                        std::vector<CPUExpandEntry>* p_to_sub,
+                        std::vector<GradientPair> const& gpair) {
+      std::vector<CPUExpandEntry>& nodes_to_build = *p_to_build;
+      nodes_to_build.resize(valid_candidates.size());
+      std::vector<CPUExpandEntry>& nodes_to_sub = *p_to_sub;
+      nodes_to_sub.resize(valid_candidates.size());
+
+      size_t n_idx = 0;
+      for (auto const& c : valid_candidates) {
+        auto left_nidx = (*p_tree)[c.nid].LeftChild();
+        auto right_nidx = (*p_tree)[c.nid].RightChild();
+        auto fewer_right = c.split.right_sum.GetHess() < c.split.left_sum.GetHess();
+
+        auto build_nidx = left_nidx;
+        auto subtract_nidx = right_nidx;
+        if (fewer_right) {
+          std::swap(build_nidx, subtract_nidx);
+        }
+        nodes_to_build[n_idx] = CPUExpandEntry{build_nidx, p_tree->GetDepth(build_nidx), {}};
+        nodes_to_sub[n_idx] = CPUExpandEntry{subtract_nidx,
+                                             p_tree->GetDepth(subtract_nidx), {}};
+        n_idx++;
+      }
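
The fewer_right logic above implements the subtraction trick: only the child with the smaller hessian sum gets its histogram built explicitly, and the sibling is then obtained as parent minus built child, which roughly halves the histogram work. The selection step in isolation (ChooseBuildAndSubtract is a hypothetical helper, mirroring the loop above):

#include <utility>

// Build the cheaper child explicitly; derive the sibling via parent - child.
std::pair<int, int> ChooseBuildAndSubtract(int left_nidx, int right_nidx,
                                           double left_hess, double right_hess) {
  int build_nidx = left_nidx;
  int subtract_nidx = right_nidx;
  if (right_hess < left_hess) {  // fewer_right in the diff
    std::swap(build_nidx, subtract_nidx);
  }
  return {build_nidx, subtract_nidx};
}
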
-    void AddSplitsToTree(const std::vector<CPUExpandEntry>& expand,
-                         RegTree *p_tree,
-                         int *num_leaves,
-                         std::vector<CPUExpandEntry>* nodes_for_apply_split);
+      size_t page_id {0};
+      auto space = ConstructHistSpace(row_partitioner_, nodes_to_build);
+      for (auto const& gidx :
+           p_fmat->GetBatches<GHistIndexMatrix>({param_.max_bin, param_.sparse_threshold})) {
+        histogram_builder_->BuildHist(page_id, space, gidx, p_tree,
+                                      row_partitioner_.at(page_id).Partitions(), nodes_to_build,
+                                      nodes_to_sub, gpair);
+        ++page_id;
+      }
+    }

-    template <bool any_missing>
-    void ExpandTree(const GHistIndexMatrix& gmat,
-                    const ColumnMatrix& column_matrix,
-                    DMatrix* p_fmat,
-                    RegTree* p_tree,
-                    const std::vector<GradientPair>& gpair_h);
+    void ExpandTree(DMatrix* p_fmat, RegTree* p_tree, const std::vector<GradientPair>& gpair_h);

    // --data fields--
    const size_t n_trees_;
@@ -235,31 +375,17 @@ class QuantileHistMaker: public TreeUpdater {
        std::make_shared<common::ColumnSampler>()};
    std::vector<size_t> unused_rows_;
-    // the internal row sets
-    RowSetCollection row_set_collection_;
+    std::vector<HistRowPartitioner> row_partitioner_;
    std::vector<GradientPair> gpair_local_;
-    /*! \brief feature with least # of bins. to be used for dense specialization
-               of InitNewNode() */
-    uint32_t fid_least_bins_;
-
    std::unique_ptr<TreeUpdater> pruner_;
    std::unique_ptr<HistEvaluator<GradientSumT, CPUExpandEntry>> evaluator_;
-    static constexpr size_t kPartitionBlockSize = 2048;
-    common::PartitionBuilder<kPartitionBlockSize> partition_builder_;
-
    // back pointers to tree and data matrix
    const RegTree* p_last_tree_;
    DMatrix const* const p_last_fmat_;
    DMatrix* p_last_fmat_mutable_;
-    // key is the node id which should be calculated by Subtraction Trick, value is the node which
-    // provides the evidence for subtraction
-    std::vector<CPUExpandEntry> nodes_for_subtraction_trick_;
-    // list of nodes whose histograms would be built explicitly.
-    std::vector<CPUExpandEntry> nodes_for_explicit_hist_build_;
-
    enum class DataLayout { kDenseDataZeroBased, kDenseDataOneBased, kSparseData };
    DataLayout data_layout_;
    std::unique_ptr<HistogramBuilder<GradientSumT, CPUExpandEntry>> histogram_builder_;

diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 46d89fe97ed7..2626b6fb3c1d 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -21,7 +21,9 @@ TEST(DenseColumn, Test) {
    GHistIndexMatrix gmat{dmat.get(), max_num_bin, sparse_thresh, false,
                          common::OmpGetNumThreads(0)};
    ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
+    for (auto const& page : dmat->GetBatches<SparsePage>()) {
+      column_matrix.Init(page, gmat, sparse_thresh, common::OmpGetNumThreads(0));
+    }
    for (auto i = 0ull; i < dmat->Info().num_row_; i++) {
      for (auto j = 0ull; j < dmat->Info().num_col_; j++) {
@@ -68,7 +70,9 @@ TEST(SparseColumn, Test) {
  auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
  GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.5f, false, common::OmpGetNumThreads(0)};
  ColumnMatrix column_matrix;
-  column_matrix.Init(gmat, 0.5, common::OmpGetNumThreads(0));
+  for (auto const& page : dmat->GetBatches<SparsePage>()) {
+    column_matrix.Init(page, gmat, 1.0, common::OmpGetNumThreads(0));
+  }
  switch (column_matrix.GetTypeSize()) {
    case kUint8BinsTypeSize: {
      auto col = column_matrix.GetColumn(0);
@@ -106,9 +110,11 @@ TEST(DenseColumnWithMissing, Test) {
                        static_cast<int32_t>(std::numeric_limits<uint16_t>::max()) + 2};
  for (int32_t max_num_bin : max_num_bins) {
    auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
-    GHistIndexMatrix gmat{dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0)};
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin, 0.2, false, common::OmpGetNumThreads(0));
    ColumnMatrix column_matrix;
-    column_matrix.Init(gmat, 0.2, common::OmpGetNumThreads(0));
+    for (auto const& page : dmat->GetBatches<SparsePage>()) {
+      column_matrix.Init(page, gmat, 0.2, common::OmpGetNumThreads(0));
+    }
    switch (column_matrix.GetTypeSize()) {
      case kUint8BinsTypeSize: {
        auto col = column_matrix.GetColumn(0);

diff --git a/tests/cpp/data/test_gradient_index_page_raw_format.cc b/tests/cpp/data/test_gradient_index_page_raw_format.cc
index b24ee8770b8d..fa1a10faa829 100644
--- a/tests/cpp/data/test_gradient_index_page_raw_format.cc
+++ b/tests/cpp/data/test_gradient_index_page_raw_format.cc
@@ -3,6 +3,7 @@
 */
 #include <gtest/gtest.h>

+#include "../../../src/common/column_matrix.h"
 #include "../../../src/data/gradient_index.h"
 #include "../../../src/data/sparse_page_source.h"
 #include "../helpers.h"
@@ -15,33 +16,31 @@ TEST(GHistIndexPageRawFormat, IO) {
  auto m = RandomDataGenerator{100, 14, 0.5}.GenerateDMatrix();
  dmlc::TemporaryDirectory tmpdir;
  std::string path = tmpdir.path + "/ghistindex.page";
+  auto batch = BatchParam{256, 0.5};

  {
    std::unique_ptr<dmlc::Stream> fo{dmlc::Stream::Create(path.c_str(), "w")};
-    for (auto const &index :
-         m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
+    for (auto const &index : m->GetBatches<GHistIndexMatrix>(batch)) {
      format->Write(index, fo.get());
    }
  }

  GHistIndexMatrix page;
-  std::unique_ptr<dmlc::SeekStream> fi{
-      dmlc::SeekStream::CreateForRead(path.c_str())};
+  std::unique_ptr<dmlc::SeekStream> fi{dmlc::SeekStream::CreateForRead(path.c_str())};
  format->Read(&page, fi.get());

-  for (auto const &gidx :
-       m->GetBatches<GHistIndexMatrix>({GenericParameter::kCpuId, 256})) {
+  for (auto const &gidx : m->GetBatches<GHistIndexMatrix>(batch)) {
    auto const &loaded = gidx;
    ASSERT_EQ(loaded.cut.Ptrs(), page.cut.Ptrs());
    ASSERT_EQ(loaded.cut.MinValues(), page.cut.MinValues());
    ASSERT_EQ(loaded.cut.Values(), page.cut.Values());
    ASSERT_EQ(loaded.base_rowid, page.base_rowid);
    ASSERT_EQ(loaded.IsDense(), page.IsDense());
-    ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(),
-                           page.index.begin()));
-    ASSERT_TRUE(std::equal(loaded.index.Offset(),
-                           loaded.index.Offset() + loaded.index.OffsetSize(),
+    ASSERT_TRUE(std::equal(loaded.index.begin(), loaded.index.end(), page.index.begin()));
+    ASSERT_TRUE(std::equal(loaded.index.Offset(), loaded.index.Offset() + loaded.index.OffsetSize(),
                           page.index.Offset()));
+
+    ASSERT_EQ(loaded.Transpose().GetTypeSize(), page.Transpose().GetTypeSize());
  }
}
} // namespace data

diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index 553550e3301f..2dc68bab4144 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -445,6 +445,12 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
 TEST(CPUHistogram, ExternalMemory) {
  int32_t constexpr kBins = 256;
  TestHistogramExternalMemory(BatchParam{kBins, common::Span<float>{}, false}, true);
+
+  float sparse_thresh{0.5};
+  TestHistogramExternalMemory({kBins, sparse_thresh}, false);
+  sparse_thresh = std::numeric_limits<float>::quiet_NaN();
+  TestHistogramExternalMemory({kBins, sparse_thresh}, false);
+
}
} // namespace tree
} // namespace xgboost

diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index fc7c43ad7658..dcc7a60ddb43 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -35,7 +35,7 @@ class QuantileHistMock : public QuantileHistMaker {
                  std::vector<GradientPair>* gpair,
                  DMatrix* p_fmat,
                  const RegTree& tree) {
-      RealImpl::InitData(gmat, *p_fmat, tree, gpair);
+      RealImpl::InitData(gmat, p_fmat, tree, gpair);
      ASSERT_EQ(this->data_layout_, RealImpl::DataLayout::kSparseData);
      /* The creation of HistCutMatrix and GHistIndexMatrix are not technically
@@ -95,130 +95,6 @@ class QuantileHistMock : public QuantileHistMaker {
        }
      }
    }
-
-    void TestInitDataSampling(const GHistIndexMatrix& gmat,
-                              std::vector<GradientPair>* gpair,
-                              DMatrix* p_fmat,
-                              const RegTree& tree) {
-      // check SimpleSkip
-      size_t initial_seed = 777;
-      std::linear_congruential_engine<uint64_t, 16807, 0, uint64_t(1) << 63> eng_first(initial_seed);
-      for (size_t i = 0; i < 100; ++i) {
-        eng_first();
-      }
-      uint64_t initial_seed_th = RandomReplace::SimpleSkip(100, initial_seed, 16807, RandomReplace::kMod);
-      std::linear_congruential_engine<uint64_t, 16807, 0, uint64_t(1) << 63> eng_second(initial_seed_th);
-      ASSERT_EQ(eng_first(), eng_second());
-
-      const size_t nthreads = omp_get_num_threads();
-      // save state of global rng engine
-      auto initial_rnd = common::GlobalRandom();
-      std::vector<size_t> unused_rows_cpy = this->unused_rows_;
-      RealImpl::InitData(gmat, *p_fmat, tree, gpair);
-      std::vector<size_t> row_indices_initial = *(this->row_set_collection_.Data());
-      std::vector<size_t> unused_row_indices_initial = this->unused_rows_;
-      ASSERT_EQ(row_indices_initial.size(), p_fmat->Info().num_row_);
-      auto check_each_row_occurs_in_one_of_arrays = [](const std::vector<size_t>& first,
-                                                       const std::vector<size_t>& second,
-                                                       size_t nrows) {
-        ASSERT_EQ(first.size(), nrows);
-        ASSERT_EQ(second.size(), 0);
-      };
-      check_each_row_occurs_in_one_of_arrays(row_indices_initial, unused_row_indices_initial,
-                                             p_fmat->Info().num_row_);
-
-      for (size_t i_nthreads = 1; i_nthreads < 4; ++i_nthreads) {
-        omp_set_num_threads(i_nthreads);
-        // return initial state of global rng engine
-        common::GlobalRandom() = initial_rnd;
-        this->unused_rows_ = unused_rows_cpy;
-        RealImpl::InitData(gmat, *p_fmat, tree, gpair);
-        std::vector<size_t>& row_indices = *(this->row_set_collection_.Data());
-        ASSERT_EQ(row_indices_initial.size(), row_indices.size());
-        for (size_t i = 0; i < row_indices_initial.size(); ++i) {
-          ASSERT_EQ(row_indices_initial[i], row_indices[i]);
-        }
-        std::vector<size_t>& unused_row_indices = this->unused_rows_;
-        ASSERT_EQ(unused_row_indices_initial.size(), unused_row_indices.size());
-        for (size_t i = 0; i < unused_row_indices_initial.size(); ++i) {
-          ASSERT_EQ(unused_row_indices_initial[i], unused_row_indices[i]);
-        }
-        check_each_row_occurs_in_one_of_arrays(row_indices, unused_row_indices,
-                                               p_fmat->Info().num_row_);
-      }
-      omp_set_num_threads(nthreads);
-    }
-
-    void TestApplySplit(const RegTree& tree) {
-      std::vector<GradientPair> row_gpairs =
-          { {1.23f, 0.24f}, {0.24f, 0.25f}, {0.26f, 0.27f}, {2.27f, 0.28f},
-            {0.27f, 0.29f}, {0.37f, 0.39f}, {-0.47f, 0.49f}, {0.57f, 0.59f} };
-      int32_t constexpr kMaxBins = 4;
-
-      // try out different sparsity to get different number of missing values
-      for (double sparsity : {0.0, 0.1, 0.2}) {
-        // kNRows samples with kNCols features
-        auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
-
-        float sparse_th = 0.0;
-        GHistIndexMatrix gmat{dmat.get(), kMaxBins, sparse_th, false, common::OmpGetNumThreads(0)};
-        ColumnMatrix cm;
-
-        // treat everything as dense, as this is what we intend to test here
-        cm.Init(gmat, sparse_th, common::OmpGetNumThreads(0));
-        RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
-        const size_t num_row = dmat->Info().num_row_;
-        // split by feature 0
-        const size_t bin_id_min = gmat.cut.Ptrs()[0];
-        const size_t bin_id_max = gmat.cut.Ptrs()[1];
-
-        // attempt to split at different bins
-        for (size_t split = 0; split < 4; split++) {
-          size_t left_cnt = 0, right_cnt = 0;
-
-          // manually compute how many samples go left or right
-          for (size_t rid = 0; rid < num_row; ++rid) {
-            for (size_t offset = gmat.row_ptr[rid]; offset < gmat.row_ptr[rid + 1]; ++offset) {
-              const size_t bin_id = gmat.index[offset];
-              if (bin_id >= bin_id_min && bin_id < bin_id_max) {
-                if (bin_id <= split) {
-                  left_cnt++;
-                } else {
-                  right_cnt++;
-                }
-              }
-            }
-          }
-
-          // if any were missing due to sparsity, we add them to the left or to the right
-          size_t missing = kNRows - left_cnt - right_cnt;
-          if (tree[0].DefaultLeft()) {
-            left_cnt += missing;
-          } else {
-            right_cnt += missing;
-          }
-
-          // have one node with kNRows (=8 at the moment) rows, just one task
-          RealImpl::partition_builder_.Init(1, 1, [&](size_t node_in_set) {
-            return 1;
-          });
-          const size_t task_id = RealImpl::partition_builder_.GetTaskIdx(0, 0);
-          RealImpl::partition_builder_.AllocateForTask(task_id);
-          if (cm.AnyMissing()) {
-            RealImpl::partition_builder_.template Partition<uint8_t, true>(0, 0, common::Range1d(0, kNRows),
-                          split, cm, tree, this->row_set_collection_[0].begin);
-          } else {
-            RealImpl::partition_builder_.template Partition<uint8_t, false>(0, 0, common::Range1d(0, kNRows),
-                          split, cm, tree, this->row_set_collection_[0].begin);
-          }
-          RealImpl::partition_builder_.CalculateRowOffsets();
-          ASSERT_EQ(RealImpl::partition_builder_.GetNLeftElems(0), left_cnt);
-          ASSERT_EQ(RealImpl::partition_builder_.GetNRightElems(0), right_cnt);
-        }
-      }
-    }
  };

  int static constexpr kNRows = 8, kNCols = 16;
@@ -249,7 +125,8 @@ class QuantileHistMock : public QuantileHistMaker {
  void TestInitData() {
    int32_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
+    GHistIndexMatrix gmat{dmat_.get(), kMaxBins, param_.sparse_threshold, false,
+                          common::OmpGetNumThreads(0)};

    RegTree tree = RegTree();
    tree.param.UpdateAllowUnknown(cfg_);
@@ -263,33 +140,6 @@ class QuantileHistMock : public QuantileHistMaker {
      float_builder_->TestInitData(gmat, &gpair, dmat_.get(), tree);
    }
  }
-
-  void TestInitDataSampling() {
-    int32_t constexpr kMaxBins = 4;
-    GHistIndexMatrix gmat{dmat_.get(), kMaxBins, 0.0f, false, common::OmpGetNumThreads(0)};
-
-    RegTree tree = RegTree();
-    tree.param.UpdateAllowUnknown(cfg_);
-
-    std::vector<GradientPair> gpair =
-        { {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f}, {0.23f, 0.24f},
-          {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f}, {0.27f, 0.29f} };
-    if (double_builder_) {
-      double_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
-    } else {
-      float_builder_->TestInitDataSampling(gmat, &gpair, dmat_.get(), tree);
-    }
-  }
-
-  void TestApplySplit() {
-    RegTree tree = RegTree();
-    tree.param.UpdateAllowUnknown(cfg_);
-    if (double_builder_) {
-      double_builder_->TestApplySplit(tree);
-    } else {
-      float_builder_->TestApplySplit(tree);
-    }
-  }
};

TEST(QuantileHist, InitData) {
  QuantileHistMock maker_float(cfg, single_precision_histogram);
  maker_float.TestInitData();
}
-
-TEST(QuantileHist, InitDataSampling) {
-  const float subsample = 0.5;
-  std::vector<std::pair<std::string, std::string>> cfg
-      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
-       {"subsample", std::to_string(subsample)}};
-  QuantileHistMock maker(cfg);
-  maker.TestInitDataSampling();
-  const bool single_precision_histogram = true;
-  QuantileHistMock maker_float(cfg, single_precision_histogram);
-  maker_float.TestInitDataSampling();
-}
-
-TEST(QuantileHist, ApplySplit) {
-  std::vector<std::pair<std::string, std::string>> cfg
-      {{"num_feature", std::to_string(QuantileHistMock::GetNumColumns())},
-       {"split_evaluator", "elastic_net"},
-       {"reg_lambda", "0"}, {"reg_alpha", "0"}, {"max_delta_step", "0"},
"0"}, - {"min_child_weight", "0"}}; - QuantileHistMock maker(cfg); - maker.TestApplySplit(); - const bool single_precision_histogram = true; - QuantileHistMock maker_float(cfg, single_precision_histogram); - maker_float.TestApplySplit(); -} - } // namespace tree } // namespace xgboost diff --git a/tests/python/test_data_iterator.py b/tests/python/test_data_iterator.py index 946127d1311f..b5fdeb34cc78 100644 --- a/tests/python/test_data_iterator.py +++ b/tests/python/test_data_iterator.py @@ -1,7 +1,7 @@ import xgboost as xgb from xgboost.data import SingleBatchInternalIter as SingleBatch import numpy as np -from testing import IteratorForTest +from testing import IteratorForTest, non_increasing from typing import Tuple, List import pytest from hypothesis import given, strategies, settings @@ -108,7 +108,7 @@ def run_data_iterator( evals_result=results_from_it, verbose_eval=False, ) - it_predt = from_it.predict(Xy) + assert non_increasing(results_from_it["Train"]["rmse"]) X, y = it.as_arrays() Xy = xgb.DMatrix(X, y) @@ -125,13 +125,14 @@ def run_data_iterator( verbose_eval=False, ) arr_predt = from_arrays.predict(Xy) + assert non_increasing(results_from_arrays["Train"]["rmse"]) - if tree_method != "gpu_hist": - rtol = 1e-1 # flaky - else: - # Model can be sensitive to quantiles, use 1e-2 to relax the test. - np.testing.assert_allclose(it_predt, arr_predt, rtol=1e-2) - rtol = 1e-6 + rtol = 1e-2 + # CPU sketching is more memory efficient but less consistent due to small chunks + if tree_method == "gpu_hist": + it_predt = from_it.predict(Xy) + arr_predt = from_arrays.predict(Xy) + np.testing.assert_allclose(it_predt, arr_predt, rtol=rtol) np.testing.assert_allclose( results_from_it["Train"]["rmse"], diff --git a/tests/python/test_with_dask.py b/tests/python/test_with_dask.py index 55fd22e027ad..1d0fed6e5916 100644 --- a/tests/python/test_with_dask.py +++ b/tests/python/test_with_dask.py @@ -1300,9 +1300,12 @@ def run_quantile(self, name: str) -> None: pytest.skip("Skipping dask tests on Windows") exe: Optional[str] = None - for possible_path in {'./testxgboost', './build/testxgboost', - '../build/cpubuild/testxgboost', - '../cpu-build/testxgboost'}: + for possible_path in { + './testxgboost', + './build/testxgboost', + '../build/testxgboost', + "../build/cpubuild/testxgboost", + }: if os.path.exists(possible_path): exe = possible_path if exe is None: