diff --git a/src/data/gradient_index.cc b/src/data/gradient_index.cc
index 6f8c5ee9f44e..08f03f1a14d3 100644
--- a/src/data/gradient_index.cc
+++ b/src/data/gradient_index.cc
@@ -7,6 +7,7 @@
 #include <algorithm>
 #include <limits>
 #include <memory>
+#include <utility>  // std::forward
 
 #include "../common/column_matrix.h"
 #include "../common/hist_util.h"
@@ -43,7 +44,7 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
   auto ft = p_fmat->Info().feature_types.ConstHostSpan();
 
   for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
-    this->PushBatch(batch, ft, nbins, n_threads);
+    this->PushBatch(batch, ft, n_threads);
   }
   this->columns_ = std::make_unique<common::ColumnMatrix>();
@@ -57,62 +58,29 @@ GHistIndexMatrix::GHistIndexMatrix(DMatrix *p_fmat, bst_bin_t max_bins_per_feat,
   }
 }
 
+GHistIndexMatrix::GHistIndexMatrix(MetaInfo const &info, common::HistogramCuts &&cuts,
+                                   bst_bin_t max_bin_per_feat)
+    : row_ptr(info.num_row_ + 1, 0),
+      hit_count(cuts.TotalBins(), 0),
+      cut{std::forward<common::HistogramCuts>(cuts)},
+      max_num_bins(max_bin_per_feat),
+      isDense_{info.num_col_ * info.num_row_ == info.num_nonzero_} {}
+
 GHistIndexMatrix::~GHistIndexMatrix() = default;
 
 void GHistIndexMatrix::PushBatch(SparsePage const &batch, common::Span<FeatureType const> ft,
-                                 bst_bin_t n_total_bins, int32_t n_threads) {
+                                 int32_t n_threads) {
   auto page = batch.GetView();
   auto it = common::MakeIndexTransformIter([&](size_t ridx) { return page[ridx].size(); });
   common::PartialSum(n_threads, it, it + page.Size(), static_cast<size_t>(0), row_ptr.begin());
-  // The number of threads is pegged to the batch size. If the OMP block is parallelized
-  // on anything other than the batch/block size, it should be reassigned.
-  const size_t batch_threads =
-      std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
-
-  const size_t n_index = row_ptr[batch.Size()];  // number of entries in this page
-  ResizeIndex(n_index, isDense_);
-
-  CHECK_GT(cut.Values().size(), 0U);
-
-  if (isDense_) {
-    index.SetBinOffset(cut.Ptrs());
-  }
-  uint32_t const *offsets = index.Offset();
-
-  auto n_bins_total = cut.TotalBins();
-  auto is_valid = [](auto) { return true; };  // SparsePage always contains valid entries
   data::SparsePageAdapterBatch adapter_batch{page};
-  if (isDense_) {
-    // Inside the lambda functions, bin_idx is the index of the cut value across all
-    // features. By subtracting the starting pointer of each feature, we can reduce
-    // it to a smaller value and compress it to smaller types.
-    common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
-      using T = decltype(dtype);
-      common::Span<T> index_data_span = {index.data<T>(), index.Size()};
-      SetIndexData(
-          index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
-          [offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
-    });
-  } else {
-    /* For a sparse DMatrix we have to store the index of the feature for each bin
-       in the index field to choose the right offset. So offset is nullptr and index is
-       not reduced. */
-    common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
-    SetIndexData(index_data_span, ft, batch_threads, adapter_batch, is_valid, n_bins_total,
-                 [](auto idx, auto) { return idx; });
-  }
-
-  common::ParallelFor(n_total_bins, n_threads, [&](bst_omp_uint idx) {
-    for (int32_t tid = 0; tid < n_threads; ++tid) {
-      hit_count[idx] += hit_count_tloc_[tid * n_total_bins + idx];
-      hit_count_tloc_[tid * n_total_bins + idx] = 0;  // reset for next batch
-    }
-  });
+  auto is_valid = [](auto) { return true; };  // SparsePage always contains valid entries
+  PushBatchImpl(n_threads, adapter_batch, 0, is_valid, ft);
 }
 
-void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType const> ft,
-                            common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
-                            bool isDense, double sparse_thresh, int32_t n_threads) {
+GHistIndexMatrix::GHistIndexMatrix(SparsePage const &batch, common::Span<FeatureType const> ft,
+                                   common::HistogramCuts const &cuts, int32_t max_bins_per_feat,
+                                   bool isDense, double sparse_thresh, int32_t n_threads) {
   CHECK_GE(n_threads, 1);
   base_rowid = batch.base_rowid;
   isDense_ = isDense;
@@ -127,13 +95,30 @@ void GHistIndexMatrix::Init(SparsePage const &batch, common::Span<FeatureType c
-  this->PushBatch(batch, ft, nbins, n_threads);
+  this->PushBatch(batch, ft, n_threads);
   this->columns_ = std::make_unique<common::ColumnMatrix>();
   if (!std::isnan(sparse_thresh)) {
     this->columns_->Init(batch, *this, sparse_thresh, n_threads);
   }
 }
 
+template <typename Batch>
+void GHistIndexMatrix::PushAdapterBatchColumns(Context const *ctx, Batch const &batch,
+                                               float missing, size_t rbegin) {
+  CHECK(columns_);
+  this->columns_->PushBatch(ctx->Threads(), batch, missing, *this, rbegin);
+}
+
+#define INSTANTIATION_PUSH(BatchT)                                  \
+  template void GHistIndexMatrix::PushAdapterBatchColumns<BatchT>(  \
+      Context const *ctx, BatchT const &batch, float missing, size_t rbegin);
+
+INSTANTIATION_PUSH(data::CSRArrayAdapterBatch)
+INSTANTIATION_PUSH(data::ArrayAdapterBatch)
+INSTANTIATION_PUSH(data::SparsePageAdapterBatch)
+
+#undef INSTANTIATION_PUSH
+
 void GHistIndexMatrix::ResizeIndex(const size_t n_index, const bool isDense) {
   if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
     // compress dense index to uint8
@@ -156,6 +141,57 @@ common::ColumnMatrix const &GHistIndexMatrix::Transpose() const {
   return *columns_;
 }
 
+float GHistIndexMatrix::GetFvalue(size_t ridx, size_t fidx, bool is_cat) const {
+  auto const &values = cut.Values();
+  auto const &mins = cut.MinValues();
+  auto const &ptrs = cut.Ptrs();
+  if (is_cat) {
+    auto f_begin = ptrs[fidx];
+    auto f_end = ptrs[fidx + 1];
+    auto begin = RowIdx(ridx);
+    auto end = RowIdx(ridx + 1);
+    auto gidx = BinarySearchBin(begin, end, index, f_begin, f_end);
+    if (gidx == -1) {
+      return std::numeric_limits<float>::quiet_NaN();
+    }
+    return values[gidx];
+  }
+
+  auto lower = static_cast<bst_bin_t>(cut.Ptrs()[fidx]);
+  auto get_bin_idx = [&](auto &column) {
+    auto bin_idx = column[ridx];
+    if (bin_idx == common::DenseColumnIter<uint8_t, true>::kMissingId) {
+      return std::numeric_limits<float>::quiet_NaN();
+    }
+    if (bin_idx == lower) {
+      return mins[fidx];
+    }
+    return values[bin_idx - 1];
+  };
+
+  if (columns_->GetColumnType(fidx) == common::kDenseColumn) {
+    if (columns_->AnyMissing()) {
+      return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
+        auto column = columns_->DenseColumn<decltype(dtype), true>(fidx);
+        return get_bin_idx(column);
+      });
+    } else {
+      return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
+        auto column = columns_->DenseColumn<decltype(dtype), false>(fidx);
+        return get_bin_idx(column);
+      });
+    }
+  } else {
+    return common::DispatchBinType(columns_->GetTypeSize(), [&](auto dtype) {
+      auto column = columns_->SparseColumn<decltype(dtype)>(fidx, 0);
+      return get_bin_idx(column);
+    });
+  }
+
+  SPAN_CHECK(false);
+  return std::numeric_limits<float>::quiet_NaN();
+}
+
 bool GHistIndexMatrix::ReadColumnPage(dmlc::SeekStream *fi) {
   return this->columns_->Read(fi, this->cut.Ptrs().data());
 }
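Reviewer note on GetFvalue above: it recovers only a representative value per bin (the feature's minimum for the feature's first bin, otherwise the upper bound of the previous bin), not the original input value. A minimal standalone sketch of that recovery rule, assuming a HistogramCuts-like layout; `CutsSketch` and `RepresentativeValue` are illustrative names, not part of this patch:

```cpp
#include <cstddef>
#include <vector>

// Illustrative stand-in for common::HistogramCuts' per-feature cut layout.
struct CutsSketch {
  std::vector<std::size_t> ptrs;  // ptrs[fidx]..ptrs[fidx+1]: global bin range of feature fidx
  std::vector<float> mins;        // mins[fidx]: lower bound of the feature's first bin
  std::vector<float> values;      // values[bin_idx]: upper bound of global bin bin_idx
};

// Map a global bin index back to a representative float, mirroring the rule in
// GHistIndexMatrix::GetFvalue: the first bin of a feature maps to its min value,
// every other bin maps to the upper bound of the previous bin.
float RepresentativeValue(CutsSketch const& cuts, std::size_t fidx, std::size_t bin_idx) {
  std::size_t lower = cuts.ptrs[fidx];
  if (bin_idx == lower) {
    return cuts.mins[fidx];
  }
  return cuts.values[bin_idx - 1];
}
```

The new PushBatch test below depends on this rule producing identical values for the incrementally built matrix and for the page fetched through GetBatches.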
diff --git a/src/data/gradient_index.h b/src/data/gradient_index.h
index d7091757df52..71c199f81be6 100644
--- a/src/data/gradient_index.h
+++ b/src/data/gradient_index.h
@@ -4,13 +4,17 @@
  */
 #ifndef XGBOOST_DATA_GRADIENT_INDEX_H_
 #define XGBOOST_DATA_GRADIENT_INDEX_H_
+
+#include <algorithm>  // std::min
 #include <memory>
 #include <vector>
 
 #include "../common/categorical.h"
 #include "../common/hist_util.h"
+#include "../common/numeric.h"
 #include "../common/threading_utils.h"
 #include "adapter.h"
+#include "proxy_dmatrix.h"
 #include "xgboost/base.h"
 #include "xgboost/data.h"
 
@@ -18,7 +22,6 @@ namespace xgboost {
 namespace common {
 class ColumnMatrix;
 }  // namespace common
-
 /*!
  * \brief preprocessed global index matrix, in CSR format
  *
@@ -26,24 +29,39 @@ class ColumnMatrix;
  * index for CPU histogram. On GPU ellpack page is used.
  */
 class GHistIndexMatrix {
+  // Get the number of valid entries in each row.
+  template <typename AdapterBatchT>
+  auto GetRowCounts(AdapterBatchT const& batch, float missing, int32_t n_threads) {
+    std::vector<size_t> valid_counts(batch.Size(), 0);
+    common::ParallelFor(batch.Size(), n_threads, [&](size_t i) {
+      auto line = batch.GetLine(i);
+      for (size_t j = 0; j < line.Size(); ++j) {
+        data::COOTuple elem = line.GetElement(j);
+        if (data::IsValidFunctor{missing}(elem)) {
+          valid_counts[i]++;
+        }
+      }
+    });
+    return valid_counts;
+  }
+
   /**
   * \brief Push a page into index matrix, the function is only necessary because hist has
   *        partial support for external memory.
   */
-  void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft,
-                 bst_bin_t n_total_bins, int32_t n_threads);
+  void PushBatch(SparsePage const& batch, common::Span<FeatureType const> ft, int32_t n_threads);
 
   template <typename Batch, typename BinIdxType, typename GetOffset, typename IsValid>
-  void SetIndexData(common::Span<BinIdxType> index_data_span, common::Span<FeatureType const> ft,
-                    size_t batch_threads, Batch const& batch, IsValid&& is_valid, size_t nbins,
-                    GetOffset&& get_offset) {
+  void SetIndexData(common::Span<BinIdxType> index_data_span, size_t rbegin,
+                    common::Span<FeatureType const> ft, size_t batch_threads, Batch const& batch,
+                    IsValid&& is_valid, size_t nbins, GetOffset&& get_offset) {
     auto batch_size = batch.Size();
     BinIdxType* index_data = index_data_span.data();
     auto const& ptrs = cut.Ptrs();
     auto const& values = cut.Values();
     common::ParallelFor(batch_size, batch_threads, [&](size_t i) {
       auto line = batch.GetLine(i);
-      size_t ibegin = row_ptr[i];  // index of first entry for current block
+      size_t ibegin = row_ptr[rbegin + i];  // index of first entry for current block
       size_t k = 0;
       auto tid = omp_get_thread_num();
       for (size_t j = 0; j < line.Size(); ++j) {
@@ -63,6 +81,49 @@ class GHistIndexMatrix {
     });
   }
 
+  template <typename Batch, typename IsValid>
+  void PushBatchImpl(int32_t n_threads, Batch const& batch, size_t rbegin, IsValid&& is_valid,
+                     common::Span<FeatureType const> ft) {
+    // The number of threads is pegged to the batch size. If the OMP block is parallelized
+    // on anything other than the batch/block size, it should be reassigned.
+    size_t batch_threads =
+        std::max(static_cast<size_t>(1), std::min(batch.Size(), static_cast<size_t>(n_threads)));
+
+    auto n_bins_total = cut.TotalBins();
+    const size_t n_index = row_ptr[rbegin + batch.Size()];  // number of entries in this page
+    ResizeIndex(n_index, isDense_);
+    if (isDense_) {
+      index.SetBinOffset(cut.Ptrs());
+    }
+    uint32_t const* offsets = index.Offset();
+    if (isDense_) {
+      // Inside the lambda functions, bin_idx is the index of the cut value across all
+      // features. By subtracting the starting pointer of each feature, we can reduce
+      // it to a smaller value and compress it to smaller types.
+      common::DispatchBinType(index.GetBinTypeSize(), [&](auto dtype) {
+        using T = decltype(dtype);
+        common::Span<T> index_data_span = {index.data<T>(), index.Size()};
+        SetIndexData(
+            index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
+            [offsets](auto bin_idx, auto fidx) { return static_cast<T>(bin_idx - offsets[fidx]); });
+      });
+    } else {
+      /* For a sparse DMatrix we have to store the index of the feature for each bin
+         in the index field to choose the right offset. So offset is nullptr and index is
+         not reduced. */
+      common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
+      SetIndexData(index_data_span, rbegin, ft, batch_threads, batch, is_valid, n_bins_total,
+                   [](auto idx, auto) { return idx; });
+    }
+
+    common::ParallelFor(n_bins_total, n_threads, [&](bst_omp_uint idx) {
+      for (int32_t tid = 0; tid < n_threads; ++tid) {
+        hit_count[idx] += hit_count_tloc_[tid * n_bins_total + idx];
+        hit_count_tloc_[tid * n_bins_total + idx] = 0;  // reset for next batch
+      }
+    });
+  }
+
  public:
   /*! \brief row pointer to rows by element position */
   std::vector<size_t> row_ptr;
@@ -77,15 +138,53 @@ class GHistIndexMatrix {
   /*! \brief base row index for current page (used by external memory) */
   size_t base_rowid{0};
 
-  GHistIndexMatrix();
+  ~GHistIndexMatrix();
+  /**
+   * \brief Constructor for SimpleDMatrix.
+   */
   GHistIndexMatrix(DMatrix* x, bst_bin_t max_bins_per_feat, double sparse_thresh,
                    bool sorted_sketch, int32_t n_threads, common::Span<float> hess = {});
-  ~GHistIndexMatrix();
+  /**
+   * \brief Constructor for IterativeDMatrix. Initialize basic information and prepare
+   *        for push batch.
+   */
+  GHistIndexMatrix(MetaInfo const& info, common::HistogramCuts&& cuts, bst_bin_t max_bin_per_feat);
+  /**
+   * \brief Constructor for external memory.
+   */
+  GHistIndexMatrix(SparsePage const& page, common::Span<FeatureType const> ft,
+                   common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
+                   double sparse_thresh, int32_t n_threads);
+  GHistIndexMatrix();  // also for ext mem, empty ctor so that we can read the cache back.
+
+  template <typename Batch>
+  void PushAdapterBatch(Context const* ctx, size_t rbegin, size_t prev_sum, Batch const& batch,
+                        float missing, common::Span<FeatureType const> ft, double sparse_thresh,
+                        size_t n_samples_total) {
+    auto n_bins_total = cut.TotalBins();
+    hit_count_tloc_.clear();
+    hit_count_tloc_.resize(ctx->Threads() * n_bins_total, 0);
+
+    auto n_threads = ctx->Threads();
+    auto valid_counts = GetRowCounts(batch, missing, n_threads);
+
+    auto it = common::MakeIndexTransformIter([&](size_t ridx) { return valid_counts[ridx]; });
+    common::PartialSum(n_threads, it, it + batch.Size(), prev_sum, row_ptr.begin() + rbegin);
+    auto is_valid = data::IsValidFunctor{missing};
+
+    PushBatchImpl(ctx->Threads(), batch, rbegin, is_valid, ft);
+
+    if (rbegin + batch.Size() == n_samples_total) {
+      // finished
+      CHECK(!std::isnan(sparse_thresh));
+      this->columns_ = std::make_unique<common::ColumnMatrix>(*this, sparse_thresh);
+    }
+  }
 
-  // Create a global histogram matrix, given cut. Used by external memory
-  void Init(SparsePage const& page, common::Span<FeatureType const> ft,
-            common::HistogramCuts const& cuts, int32_t max_bins_per_feat, bool is_dense,
-            double sparse_thresh, int32_t n_threads);
+  // Call ColumnMatrix::PushBatch
+  template <typename Batch>
+  void PushAdapterBatchColumns(Context const* ctx, Batch const& batch, float missing,
+                               size_t rbegin);
 
   void ResizeIndex(const size_t n_index, const bool isDense);
 
@@ -117,6 +216,8 @@ class GHistIndexMatrix {
 
   common::ColumnMatrix const& Transpose() const;
 
+  float GetFvalue(size_t ridx, size_t fidx, bool is_cat) const;
+
  private:
   std::unique_ptr<common::ColumnMatrix> columns_;
   std::vector<size_t> hit_count_tloc_;
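The header now exposes the whole incremental-build flow for IterativeDMatrix. A rough usage sketch of the intended call sequence, assuming the caller prepares the context, metadata, cuts, and one of the instantiated adapter batch types (e.g. data::SparsePageAdapterBatch); `BuildIncrementally` is a hypothetical helper, not part of this patch:

```cpp
#include <cstddef>
#include <utility>
#include <vector>

#include "../../src/data/gradient_index.h"  // GHistIndexMatrix; adjust the path as needed

namespace xgboost {
// Sketch: drive the two-phase incremental build. All arguments are assumed to
// be prepared by the caller; BatchT must be one of the instantiated adapters.
template <typename BatchT>
void BuildIncrementally(Context const* ctx, MetaInfo const& info, common::HistogramCuts cuts,
                        std::vector<BatchT> const& batches, common::Span<FeatureType const> ft,
                        bst_bin_t max_bins, float missing, double sparse_thresh) {
  GHistIndexMatrix gmat{info, std::move(cuts), max_bins};

  // Pass 1: fill row_ptr/index/hit_count. The ColumnMatrix is allocated only
  // when the final batch arrives (rbegin + batch.Size() == n_samples_total).
  std::size_t rbegin = 0, prev_sum = 0;
  for (auto const& batch : batches) {
    gmat.PushAdapterBatch(ctx, rbegin, prev_sum, batch, missing, ft, sparse_thresh,
                          info.num_row_);
    rbegin += batch.Size();
    prev_sum = gmat.row_ptr[rbegin];
  }
  // Pass 2: columns_ exists now, so the transposed column view can be filled.
  rbegin = 0;
  for (auto const& batch : batches) {
    gmat.PushAdapterBatchColumns(ctx, batch, missing, rbegin);
    rbegin += batch.Size();
  }
}
}  // namespace xgboost
```

Note the two passes: PushAdapterBatch allocates columns_ only once the final batch is seen, and PushAdapterBatchColumns CHECKs that columns_ exists, so the column transpose has to be filled in a second sweep. The new test below gets away with a single loop only because it pushes exactly one batch.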
diff --git a/src/data/gradient_index_page_source.cc b/src/data/gradient_index_page_source.cc
index 09d8ada8070b..6fa2f07e0ddd 100644
--- a/src/data/gradient_index_page_source.cc
+++ b/src/data/gradient_index_page_source.cc
@@ -15,10 +15,9 @@ void GradientIndexPageSource::Fetch() {
     // This is not read from cache so we still need it to be synced with sparse page source.
     CHECK_EQ(count_, source_->Iter());
     auto const& csr = source_->Page();
-    this->page_.reset(new GHistIndexMatrix());
     CHECK_NE(cuts_.Values().size(), 0);
-    this->page_->Init(*csr, feature_types_, cuts_, max_bin_per_feat_, is_dense_, sparse_thresh_,
-                      nthreads_);
+    this->page_.reset(new GHistIndexMatrix(*csr, feature_types_, cuts_, max_bin_per_feat_,
+                                           is_dense_, sparse_thresh_, nthreads_));
     this->WriteCache();
   }
 }
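One detail worth spelling out, since PushBatchImpl (header diff above) now owns it for every input path: on the dense path the stored value is bin_idx - offsets[fidx], which is what lets ResizeIndex shrink the index storage to uint8/uint16. A toy, self-contained illustration with invented bin counts:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Invented layout: 3 features with 200 bins each. Global bin indices run up
  // to 599 and need 16 bits, but (bin - ptrs[fidx]) is always < 200, so the
  // dense index fits in uint8 -- the compression ResizeIndex picks.
  std::vector<std::uint32_t> ptrs{0, 200, 400, 600};

  std::uint32_t global_bin = 407;  // a bin belonging to feature 2
  std::uint32_t fidx = 2;
  auto stored = static_cast<std::uint8_t>(global_bin - ptrs[fidx]);  // 7

  // The read path adds the feature's offset back to recover the global bin.
  std::uint32_t restored = stored + ptrs[fidx];
  std::cout << int(stored) << " -> " << restored << "\n";  // prints: 7 -> 407
}
```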
diff --git a/tests/cpp/data/test_gradient_index.cc b/tests/cpp/data/test_gradient_index.cc
index 8f5d7d3d759d..6e5d1312d98d 100644
--- a/tests/cpp/data/test_gradient_index.cc
+++ b/tests/cpp/data/test_gradient_index.cc
@@ -4,6 +4,7 @@
 #include <gtest/gtest.h>
 #include <xgboost/data.h>
 
+#include "../../../src/common/column_matrix.h"
 #include "../../../src/data/gradient_index.h"
 #include "../helpers.h"
 
@@ -65,5 +66,46 @@ TEST(GradientIndex, FromCategoricalBasic) {
     ASSERT_EQ(common::AsCat(x[i]), common::AsCat(bin_value));
   }
 }
+
+TEST(GradientIndex, PushBatch) {
+  size_t constexpr kRows = 64, kCols = 4;
+  bst_bin_t max_bins = 64;
+  float st = 0.5;
+
+  auto test = [&](float sparsity) {
+    auto m = RandomDataGenerator{kRows, kCols, sparsity}.GenerateDMatrix(true);
+    auto cuts = common::SketchOnDMatrix(m.get(), max_bins, common::OmpGetNumThreads(0), false, {});
+    common::HistogramCuts copy_cuts = cuts;
+
+    ASSERT_EQ(m->Info().num_row_, kRows);
+    ASSERT_EQ(m->Info().num_col_, kCols);
+    GHistIndexMatrix gmat{m->Info(), std::move(copy_cuts), max_bins};
+
+    for (auto const &page : m->GetBatches<SparsePage>()) {
+      SparsePageAdapterBatch batch{page.GetView()};
+      gmat.PushAdapterBatch(m->Ctx(), 0, 0, batch, std::numeric_limits<float>::quiet_NaN(), {}, st,
+                            m->Info().num_row_);
+      gmat.PushAdapterBatchColumns(m->Ctx(), batch, std::numeric_limits<float>::quiet_NaN(), 0);
+    }
+    for (auto const &page : m->GetBatches<GHistIndexMatrix>(BatchParam{max_bins, st})) {
+      for (size_t i = 0; i < kRows; ++i) {
+        for (size_t j = 0; j < kCols; ++j) {
+          auto v0 = gmat.GetFvalue(i, j, false);
+          auto v1 = page.GetFvalue(i, j, false);
+          if (sparsity == 0.0) {
+            ASSERT_FALSE(std::isnan(v0));
+          }
+          if (!std::isnan(v0)) {
+            ASSERT_EQ(v0, v1);
+          }
+        }
+      }
+    }
+  };
+
+  test(0.0f);
+  test(0.5f);
+  test(0.9f);
+}
 }  // namespace data
 }  // namespace xgboost
diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc
index 71eee8bb3bcd..34c4d48e6dc1 100644
--- a/tests/cpp/predictor/test_predictor.cc
+++ b/tests/cpp/predictor/test_predictor.cc
@@ -66,6 +66,14 @@ void TestTrainingPrediction(size_t rows, size_t bins,
     learner->UpdateOneIter(i, p_hist);
   }
 
+  Json model{Object{}};
+  learner->SaveModel(&model);
+
+  learner.reset(Learner::Create({}));
+  learner->LoadModel(model);
+  learner->SetParam("predictor", predictor);
+  learner->Configure();
+
   HostDeviceVector<float> from_full;
   learner->Predict(p_full, false, &from_full, 0, 0);
diff --git a/tests/cpp/tree/hist/test_histogram.cc b/tests/cpp/tree/hist/test_histogram.cc
index 3ce6c9989db1..3c0728c38a37 100644
--- a/tests/cpp/tree/hist/test_histogram.cc
+++ b/tests/cpp/tree/hist/test_histogram.cc
@@ -419,9 +419,8 @@ void TestHistogramExternalMemory(BatchParam batch_param, bool is_approx) {
     auto cut = common::SketchOnDMatrix(m.get(), batch_param.max_bin, common::OmpGetNumThreads(0),
                                        false, hess);
-    GHistIndexMatrix gmat;
-    gmat.Init(concat, {}, cut, batch_param.max_bin, false, std::numeric_limits<float>::quiet_NaN(),
-              common::OmpGetNumThreads(0));
+    GHistIndexMatrix gmat(concat, {}, cut, batch_param.max_bin, false,
+                          std::numeric_limits<float>::quiet_NaN(), common::OmpGetNumThreads(0));
     single_build.BuildHist(0, gmat, &tree, row_set_collection, nodes, {}, h_gpair);
     single_page = single_build.Histogram()[0];
   }
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index e03933411cfd..c34f63b46b1d 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -34,8 +34,7 @@ TEST(QuantileHist, Partitioner) {
   auto cuts = common::SketchOnDMatrix(Xy.get(), 64, ctx.Threads());
 
   for (auto const& page : Xy->GetBatches<SparsePage>()) {
-    GHistIndexMatrix gmat;
-    gmat.Init(page, {}, cuts, 64, true, 0.5, ctx.Threads());
+    GHistIndexMatrix gmat(page, {}, cuts, 64, true, 0.5, ctx.Threads());
     bst_feature_t const split_ind = 0;
     common::ColumnMatrix column_indices;
     column_indices.Init(page, gmat, 0.5, ctx.Threads());