diff --git a/amalgamation/xgboost-all0.cc b/amalgamation/xgboost-all0.cc
index c61135353ff2..16650d215af9 100644
--- a/amalgamation/xgboost-all0.cc
+++ b/amalgamation/xgboost-all0.cc
@@ -38,6 +38,7 @@
 #include "../src/data/sparse_page_raw_format.cc"
 #include "../src/data/ellpack_page.cc"
 #include "../src/data/ellpack_page_source.cc"
+#include "../src/data/gradient_index.cc"
 
 // prediction
 #include "../src/predictor/predictor.cc"
diff --git a/include/xgboost/data.h b/include/xgboost/data.h
index 27aa81577ed5..ada292e42f27 100644
--- a/include/xgboost/data.h
+++ b/include/xgboost/data.h
@@ -385,6 +385,8 @@ class EllpackPage {
   std::unique_ptr<EllpackPageImpl> impl_;
 };
 
+class GHistIndexMatrix;
+
 template <typename T>
 class BatchIteratorImpl {
  public:
@@ -553,6 +555,7 @@ class DMatrix {
   virtual BatchSet<CSCPage> GetColumnBatches() = 0;
   virtual BatchSet<SortedCSCPage> GetSortedColumnBatches() = 0;
   virtual BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) = 0;
+  virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;
 
   virtual bool EllpackExists() const = 0;
   virtual bool SparsePageExists() const = 0;
@@ -587,6 +590,11 @@ template<>
 inline BatchSet<EllpackPage> DMatrix::GetBatches(const BatchParam& param) {
   return GetEllpackBatches(param);
 }
+
+template<>
+inline BatchSet<GHistIndexMatrix> DMatrix::GetBatches(const BatchParam& param) {
+  return GetGradientIndex(param);
+}
 }  // namespace xgboost
 
 namespace dmlc {
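Note: with the new specialization in place, the quantized matrix is fetched through the same templated batch interface as SparsePage and EllpackPage; GetBatches<GHistIndexMatrix> simply forwards to the virtual GetGradientIndex. A minimal caller-side sketch, not part of this patch (the include path and the bin count 256 are illustrative assumptions):

#include "xgboost/data.h"
#include "xgboost/generic_parameters.h"
#include "../src/data/gradient_index.h"  // complete type for GHistIndexMatrix; path illustrative

// Count the total number of bins via the new batch API.
size_t TotalBins(xgboost::DMatrix* dmat) {
  xgboost::BatchParam param{xgboost::GenericParameter::kCpuId, /*max_bin=*/256};
  size_t n_bins = 0;
  for (auto const& page : dmat->GetBatches<xgboost::GHistIndexMatrix>(param)) {
    n_bins = page.cut.Ptrs().back();  // cut pointers are cumulative over features
  }
  return n_bins;
}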
diff --git a/src/common/column_matrix.h b/src/common/column_matrix.h
index 56cc89c05998..804d7a568116 100644
--- a/src/common/column_matrix.h
+++ b/src/common/column_matrix.h
@@ -12,6 +12,7 @@
 #include <limits>
 #include <vector>
 #include "hist_util.h"
+#include "../data/gradient_index.h"
 
 namespace xgboost {
 namespace common {
@@ -262,9 +263,10 @@ class ColumnMatrix {
     return res;
   }
 
-  template <typename T>
-  inline void SetIndexAllDense(T* index, const GHistIndexMatrix& gmat, const size_t nrow,
-                               const size_t nfeature, const bool noMissingValues) {
+  template <typename T>
+  inline void SetIndexAllDense(T *index, const GHistIndexMatrix &gmat,
+                               const size_t nrow, const size_t nfeature,
+                               const bool noMissingValues) {
     T* local_index = reinterpret_cast<T*>(&index_[0]);
 
     /* missing values make sense only for column with type kDenseColumn,
diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc
index edb22b613c26..78754a6b7c95 100644
--- a/src/common/hist_util.cc
+++ b/src/common/hist_util.cc
@@ -16,6 +16,7 @@
 #include "column_matrix.h"
 #include "quantile.h"
 #include "./../tree/updater_quantile_hist.h"
+#include "../data/gradient_index.h"
 
 #if defined(XGBOOST_MM_PREFETCH_PRESENT)
   #include <xmmintrin.h>
@@ -29,164 +30,10 @@ namespace xgboost {
 namespace common {
 
-void GHistIndexMatrix::ResizeIndex(const size_t n_index,
-                                   const bool isDense) {
-  if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
-    index.SetBinTypeSize(kUint8BinsTypeSize);
-    index.Resize((sizeof(uint8_t)) * n_index);
-  } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
-              max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
-    index.SetBinTypeSize(kUint16BinsTypeSize);
-    index.Resize((sizeof(uint16_t)) * n_index);
-  } else {
-    index.SetBinTypeSize(kUint32BinsTypeSize);
-    index.Resize((sizeof(uint32_t)) * n_index);
-  }
-}
-
 HistogramCuts::HistogramCuts() {
   cut_ptrs_.HostVector().emplace_back(0);
 }
 
-void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) {
-  cut = SketchOnDMatrix(p_fmat, max_bins);
-
-  max_num_bins = max_bins;
-  const int32_t nthread = omp_get_max_threads();
-  const uint32_t nbins = cut.Ptrs().back();
-  hit_count.resize(nbins, 0);
-  hit_count_tloc_.resize(nthread * nbins, 0);
-
-  this->p_fmat = p_fmat;
-  size_t new_size = 1;
-  for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
-    new_size += batch.Size();
-  }
-
-  row_ptr.resize(new_size);
-  row_ptr[0] = 0;
-
-  size_t rbegin = 0;
-  size_t prev_sum = 0;
-  const bool isDense = p_fmat->IsDense();
-  this->isDense_ = isDense;
-
-  for (const auto &batch : p_fmat->GetBatches<SparsePage>()) {
-    // The number of threads is pegged to the batch size. If the OMP
-    // block is parallelized on anything other than the batch/block size,
-    // it should be reassigned
-    const size_t batch_threads = std::max(
-        size_t(1),
-        std::min(batch.Size(), static_cast<size_t>(omp_get_max_threads())));
-    auto page = batch.GetView();
-    MemStackAllocator<size_t, 128> partial_sums(batch_threads);
-    size_t* p_part = partial_sums.Get();
-
-    size_t block_size = batch.Size() / batch_threads;
-
-    dmlc::OMPException exc;
-    #pragma omp parallel num_threads(batch_threads)
-    {
-      #pragma omp for
-      for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
-        exc.Run([&]() {
-          size_t ibegin = block_size * tid;
-          size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
-
-          size_t sum = 0;
-          for (size_t i = ibegin; i < iend; ++i) {
-            sum += page[i].size();
-            row_ptr[rbegin + 1 + i] = sum;
-          }
-        });
-      }
-
-      #pragma omp single
-      {
-        exc.Run([&]() {
-          p_part[0] = prev_sum;
-          for (size_t i = 1; i < batch_threads; ++i) {
-            p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size];
-          }
-        });
-      }
-
-      #pragma omp for
-      for (omp_ulong tid = 0; tid < batch_threads; ++tid) {
-        exc.Run([&]() {
-          size_t ibegin = block_size * tid;
-          size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1)));
-
-          for (size_t i = ibegin; i < iend; ++i) {
-            row_ptr[rbegin + 1 + i] += p_part[tid];
-          }
-        });
-      }
-    }
-    exc.Rethrow();
-
-    const size_t n_offsets = cut.Ptrs().size() - 1;
-    const size_t n_index = row_ptr[rbegin + batch.Size()];
-    ResizeIndex(n_index, isDense);
-
-    CHECK_GT(cut.Values().size(), 0U);
-
-    uint32_t* offsets = nullptr;
-    if (isDense) {
-      index.ResizeOffset(n_offsets);
-      offsets = index.Offset();
-      for (size_t i = 0; i < n_offsets; ++i) {
-        offsets[i] = cut.Ptrs()[i];
-      }
-    }
-
-    if (isDense) {
-      BinTypeSize curent_bin_size = index.GetBinTypeSize();
-      if (curent_bin_size == kUint8BinsTypeSize) {
-        common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
-                                                 n_index};
-        SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
-                     [offsets](auto idx, auto j) {
-                       return static_cast<uint8_t>(idx - offsets[j]);
-                     });
-
-      } else if (curent_bin_size == kUint16BinsTypeSize) {
-        common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
-                                                  n_index};
-        SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
-                     [offsets](auto idx, auto j) {
-                       return static_cast<uint16_t>(idx - offsets[j]);
-                     });
-      } else {
-        CHECK_EQ(curent_bin_size, kUint32BinsTypeSize);
-        common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
-                                                  n_index};
-        SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
-                     [offsets](auto idx, auto j) {
-                       return static_cast<uint32_t>(idx - offsets[j]);
-                     });
-      }
-
-      /* For sparse DMatrix we have to store index of feature for each bin
-         in index field to chose right offset.
-         So offset is nullptr and index is not reduced */
-    } else {
-      common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
-      SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
-                   [](auto idx, auto) { return idx; });
-    }
-
-    ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
-      for (int32_t tid = 0; tid < nthread; ++tid) {
-        hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
-        hit_count_tloc_[tid * nbins + idx] = 0;  // reset for next batch
-      }
-    });
-
-    prev_sum = row_ptr[rbegin + batch.Size()];
-    rbegin += batch.Size();
-  }
-}
-
 /*!
  * \brief fill a histogram by zeros in range [begin, end)
  */
@@ -289,9 +136,9 @@ constexpr size_t Prefetch::kNoPrefetchSize;
 
 template <typename FPType, bool do_prefetch, typename BinIdxType, bool any_missing>
 void BuildHistKernel(const std::vector<GradientPair>& gpair,
-                      const RowSetCollection::Elem row_indices,
-                      const GHistIndexMatrix& gmat,
-                      GHistRow<FPType> hist) {
+                     const RowSetCollection::Elem row_indices,
+                     const GHistIndexMatrix& gmat,
+                     GHistRow<FPType> hist) {
   const size_t size = row_indices.Size();
   const size_t* rid = row_indices.begin;
   const float* pgh = reinterpret_cast<const float*>(gpair.data());
@@ -337,8 +184,8 @@ void BuildHistKernel(const std::vector<GradientPair>& gpair,
 
 template <typename GradientSumT, bool do_prefetch, bool any_missing>
 void BuildHistDispatch(const std::vector<GradientPair>& gpair,
-                        const RowSetCollection::Elem row_indices,
-                        const GHistIndexMatrix& gmat, GHistRow<GradientSumT> hist) {
+                       const RowSetCollection::Elem row_indices,
+                       const GHistIndexMatrix& gmat, GHistRow<GradientSumT> hist) {
   switch (gmat.index.GetBinTypeSize()) {
     case kUint8BinsTypeSize:
       BuildHistKernel<GradientSumT, do_prefetch, uint8_t, any_missing>(gpair, row_indices,
@@ -382,26 +229,26 @@ void GHistBuilder<GradientSumT>::BuildHist(
     BuildHistDispatch<GradientSumT, false, any_missing>(gpair, span2, gmat, hist);
   }
 }
-template
-void GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair>& gpair,
-                                          const RowSetCollection::Elem row_indices,
-                                          const GHistIndexMatrix& gmat,
-                                          GHistRow<float> hist);
-template
-void GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair>& gpair,
-                                           const RowSetCollection::Elem row_indices,
-                                           const GHistIndexMatrix& gmat,
-                                           GHistRow<float> hist);
-template
-void GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair>& gpair,
-                                           const RowSetCollection::Elem row_indices,
-                                           const GHistIndexMatrix& gmat,
-                                           GHistRow<double> hist);
-template
-void GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair>& gpair,
-                                            const RowSetCollection::Elem row_indices,
-                                            const GHistIndexMatrix& gmat,
-                                            GHistRow<double> hist);
+template void
+GHistBuilder<float>::BuildHist<true>(const std::vector<GradientPair> &gpair,
+                                     const RowSetCollection::Elem row_indices,
+                                     const GHistIndexMatrix &gmat,
+                                     GHistRow<float> hist);
+template void
+GHistBuilder<float>::BuildHist<false>(const std::vector<GradientPair> &gpair,
+                                      const RowSetCollection::Elem row_indices,
+                                      const GHistIndexMatrix &gmat,
+                                      GHistRow<float> hist);
+template void
+GHistBuilder<double>::BuildHist<true>(const std::vector<GradientPair> &gpair,
+                                      const RowSetCollection::Elem row_indices,
+                                      const GHistIndexMatrix &gmat,
+                                      GHistRow<double> hist);
+template void
+GHistBuilder<double>::BuildHist<false>(const std::vector<GradientPair> &gpair,
+                                       const RowSetCollection::Elem row_indices,
+                                       const GHistIndexMatrix &gmat,
+                                       GHistRow<double> hist);
 
 template <typename GradientSumT>
 void GHistBuilder<GradientSumT>::SubtractionTrick(GHistRowT self,
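Note: because Index stores bin values as uint8_t, uint16_t, or uint32_t depending on max_bin, each histogram kernel is compiled per storage width and selected by one runtime switch, which is the shape of BuildHistDispatch above. A simplified, self-contained rendering of that pattern (all names here are illustrative, not the library's):

#include <cstddef>
#include <cstdint>

enum BinTypeSize : uint8_t { kUint8 = 1, kUint16 = 2, kUint32 = 4 };

// One kernel instantiated per storage width.
template <typename BinIdxType>
size_t SumBins(const BinIdxType* bins, size_t n) {
  size_t total = 0;
  for (size_t i = 0; i < n; ++i) total += bins[i];
  return total;
}

// Runtime width -> compile-time kernel, as in BuildHistDispatch.
size_t Dispatch(BinTypeSize t, const void* bins, size_t n) {
  switch (t) {
    case kUint8:  return SumBins(static_cast<const uint8_t*>(bins), n);
    case kUint16: return SumBins(static_cast<const uint16_t*>(bins), n);
    case kUint32: return SumBins(static_cast<const uint32_t*>(bins), n);
  }
  return 0;
}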
diff --git a/src/common/hist_util.h b/src/common/hist_util.h
index b4af15ae147e..041faf2a1b53 100644
--- a/src/common/hist_util.h
+++ b/src/common/hist_util.h
@@ -25,6 +25,8 @@
 #include "../include/rabit/rabit.h"
 
 namespace xgboost {
+class GHistIndexMatrix;
+
 namespace common {
 /*!
  * \brief A single row in global histogram index.
@@ -226,74 +228,6 @@ struct Index {
   Func func_;
 };
 
-
-/*!
- * \brief preprocessed global index matrix, in CSR format
- *
- * Transform floating values to integer index in histogram This is a global histogram
- * index for CPU histogram.  On GPU ellpack page is used.
- */
-struct GHistIndexMatrix {
-  /*! \brief row pointer to rows by element position */
-  std::vector<size_t> row_ptr;
-  /*! \brief The index data */
-  Index index;
-  /*! \brief hit count of each index */
-  std::vector<size_t> hit_count;
-  /*! \brief The corresponding cuts */
-  HistogramCuts cut;
-  DMatrix* p_fmat;
-  size_t max_num_bins;
-  // Create a global histogram matrix, given cut
-  void Init(DMatrix* p_fmat, int max_num_bins);
-
-  // specific method for sparse data as no possibility to reduce allocated memory
-  template <typename BinIdxType, typename GetOffset>
-  void SetIndexData(common::Span<BinIdxType> index_data_span,
-                    size_t batch_threads, const SparsePage &batch,
-                    size_t rbegin, size_t nbins, GetOffset get_offset) {
-    const xgboost::Entry *data_ptr = batch.data.HostVector().data();
-    const std::vector<bst_row_t> &offset_vec = batch.offset.HostVector();
-    const size_t batch_size = batch.Size();
-    CHECK_LT(batch_size, offset_vec.size());
-    BinIdxType* index_data = index_data_span.data();
-    ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) {
-      const int tid = omp_get_thread_num();
-      size_t ibegin = row_ptr[rbegin + i];
-      size_t iend = row_ptr[rbegin + i + 1];
-      const size_t size = offset_vec[i + 1] - offset_vec[i];
-      SparsePage::Inst inst = {data_ptr + offset_vec[i], size};
-      CHECK_EQ(ibegin + inst.size(), iend);
-      for (bst_uint j = 0; j < inst.size(); ++j) {
-        uint32_t idx = cut.SearchBin(inst[j]);
-        index_data[ibegin + j] = get_offset(idx, j);
-        ++hit_count_tloc_[tid * nbins + idx];
-      }
-    });
-  }
-
-  void ResizeIndex(const size_t n_index,
-                   const bool isDense);
-
-  inline void GetFeatureCounts(size_t* counts) const {
-    auto nfeature = cut.Ptrs().size() - 1;
-    for (unsigned fid = 0; fid < nfeature; ++fid) {
-      auto ibegin = cut.Ptrs()[fid];
-      auto iend = cut.Ptrs()[fid + 1];
-      for (auto i = ibegin; i < iend; ++i) {
-        counts[fid] += hit_count[i];
-      }
-    }
-  }
-  inline bool IsDense() const {
-    return isDense_;
-  }
-
- private:
-  std::vector<size_t> hit_count_tloc_;
-  bool isDense_;
-};
-
 template <typename GradientIndex>
 int32_t XGBOOST_HOST_DEV_INLINE BinarySearchBin(size_t begin, size_t end,
                                                 GradientIndex const &data,
@@ -647,6 +581,42 @@ class GHistBuilder {
   /*! \brief number of all bins over all features */
   uint32_t nbins_ { 0 };
 };
+
+/*!
+ * \brief A C-style array with in-stack allocation. As long as the array is smaller than
+ *        MaxStackSize, it will be allocated inside the stack. Otherwise, it will be
+ *        heap-allocated.
+ */
+template <typename T, size_t MaxStackSize>
+class MemStackAllocator {
+ public:
+  explicit MemStackAllocator(size_t required_size) : required_size_(required_size) {
+  }
+
+  T* Get() {
+    if (!ptr_) {
+      if (MaxStackSize >= required_size_) {
+        ptr_ = stack_mem_;
+      } else {
+        ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
+        do_free_ = true;
+      }
+    }
+
+    return ptr_;
+  }
+
+  ~MemStackAllocator() {
+    if (do_free_) free(ptr_);
+  }
+
+ private:
+  T* ptr_ = nullptr;
+  bool do_free_ = false;
+  size_t required_size_;
+  T stack_mem_[MaxStackSize];
+};
 }  // namespace common
 }  // namespace xgboost
 #endif  // XGBOOST_COMMON_HIST_UTIL_H_
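Note: MemStackAllocator moves into hist_util.h (it is removed from updater_quantile_hist.h further down) so that the relocated GHistIndexMatrix::Init can keep using it. A usage sketch under stated assumptions (include path illustrative):

#include <cstddef>
#include "hist_util.h"  // as seen from within src/common

void FillPartialSums(size_t n_threads) {
  // Up to 128 elements live in the in-object stack buffer; larger requests
  // fall back to malloc, and the destructor frees the heap block.
  xgboost::common::MemStackAllocator<size_t, 128> partial_sums(n_threads);
  size_t* p = partial_sums.Get();
  for (size_t i = 0; i < n_threads; ++i) p[i] = 0;
}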
+ */ +#include +#include +#include "gradient_index.h" +#include "../common/hist_util.h" + +namespace xgboost { +void GHistIndexMatrix::Init(DMatrix* p_fmat, int max_bins) { + cut = common::SketchOnDMatrix(p_fmat, max_bins); + + max_num_bins = max_bins; + const int32_t nthread = omp_get_max_threads(); + const uint32_t nbins = cut.Ptrs().back(); + hit_count.resize(nbins, 0); + hit_count_tloc_.resize(nthread * nbins, 0); + + this->p_fmat = p_fmat; + size_t new_size = 1; + for (const auto &batch : p_fmat->GetBatches()) { + new_size += batch.Size(); + } + + row_ptr.resize(new_size); + row_ptr[0] = 0; + + size_t rbegin = 0; + size_t prev_sum = 0; + const bool isDense = p_fmat->IsDense(); + this->isDense_ = isDense; + + for (const auto &batch : p_fmat->GetBatches()) { + // The number of threads is pegged to the batch size. If the OMP + // block is parallelized on anything other than the batch/block size, + // it should be reassigned + const size_t batch_threads = std::max( + size_t(1), + std::min(batch.Size(), static_cast(omp_get_max_threads()))); + auto page = batch.GetView(); + common::MemStackAllocator partial_sums(batch_threads); + size_t* p_part = partial_sums.Get(); + + size_t block_size = batch.Size() / batch_threads; + + dmlc::OMPException exc; + #pragma omp parallel num_threads(batch_threads) + { + #pragma omp for + for (omp_ulong tid = 0; tid < batch_threads; ++tid) { + exc.Run([&]() { + size_t ibegin = block_size * tid; + size_t iend = (tid == (batch_threads-1) ? batch.Size() : (block_size * (tid+1))); + + size_t sum = 0; + for (size_t i = ibegin; i < iend; ++i) { + sum += page[i].size(); + row_ptr[rbegin + 1 + i] = sum; + } + }); + } + + #pragma omp single + { + exc.Run([&]() { + p_part[0] = prev_sum; + for (size_t i = 1; i < batch_threads; ++i) { + p_part[i] = p_part[i - 1] + row_ptr[rbegin + i*block_size]; + } + }); + } + + #pragma omp for + for (omp_ulong tid = 0; tid < batch_threads; ++tid) { + exc.Run([&]() { + size_t ibegin = block_size * tid; + size_t iend = (tid == (batch_threads-1) ? 
+              batch.Size() : (block_size * (tid+1)));
+
+          for (size_t i = ibegin; i < iend; ++i) {
+            row_ptr[rbegin + 1 + i] += p_part[tid];
+          }
+        });
+      }
+    }
+    exc.Rethrow();
+
+    const size_t n_offsets = cut.Ptrs().size() - 1;
+    const size_t n_index = row_ptr[rbegin + batch.Size()];
+    ResizeIndex(n_index, isDense);
+
+    CHECK_GT(cut.Values().size(), 0U);
+
+    uint32_t* offsets = nullptr;
+    if (isDense) {
+      index.ResizeOffset(n_offsets);
+      offsets = index.Offset();
+      for (size_t i = 0; i < n_offsets; ++i) {
+        offsets[i] = cut.Ptrs()[i];
+      }
+    }
+
+    if (isDense) {
+      common::BinTypeSize current_bin_size = index.GetBinTypeSize();
+      if (current_bin_size == common::kUint8BinsTypeSize) {
+        common::Span<uint8_t> index_data_span = {index.data<uint8_t>(),
+                                                 n_index};
+        SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+                     [offsets](auto idx, auto j) {
+                       return static_cast<uint8_t>(idx - offsets[j]);
+                     });
+
+      } else if (current_bin_size == common::kUint16BinsTypeSize) {
+        common::Span<uint16_t> index_data_span = {index.data<uint16_t>(),
+                                                  n_index};
+        SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+                     [offsets](auto idx, auto j) {
+                       return static_cast<uint16_t>(idx - offsets[j]);
+                     });
+      } else {
+        CHECK_EQ(current_bin_size, common::kUint32BinsTypeSize);
+        common::Span<uint32_t> index_data_span = {index.data<uint32_t>(),
+                                                  n_index};
+        SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+                     [offsets](auto idx, auto j) {
+                       return static_cast<uint32_t>(idx - offsets[j]);
+                     });
+      }
+
+      /* For a sparse DMatrix we have to store the index of the feature for each
+         bin in the index field to choose the right offset.  So offset is nullptr
+         and index is not reduced. */
+    } else {
+      common::Span<uint32_t> index_data_span = {index.data<uint32_t>(), n_index};
+      SetIndexData(index_data_span, batch_threads, batch, rbegin, nbins,
+                   [](auto idx, auto) { return idx; });
+    }
+
+    common::ParallelFor(bst_omp_uint(nbins), nthread, [&](bst_omp_uint idx) {
+      for (int32_t tid = 0; tid < nthread; ++tid) {
+        hit_count[idx] += hit_count_tloc_[tid * nbins + idx];
+        hit_count_tloc_[tid * nbins + idx] = 0;  // reset for next batch
+      }
+    });
+
+    prev_sum = row_ptr[rbegin + batch.Size()];
+    rbegin += batch.Size();
+  }
+}
+
+void GHistIndexMatrix::ResizeIndex(const size_t n_index,
+                                   const bool isDense) {
+  if ((max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint8_t>::max())) && isDense) {
+    index.SetBinTypeSize(common::kUint8BinsTypeSize);
+    index.Resize((sizeof(uint8_t)) * n_index);
+  } else if ((max_num_bins - 1 > static_cast<int>(std::numeric_limits<uint8_t>::max()) &&
+              max_num_bins - 1 <= static_cast<int>(std::numeric_limits<uint16_t>::max())) && isDense) {
+    index.SetBinTypeSize(common::kUint16BinsTypeSize);
+    index.Resize((sizeof(uint16_t)) * n_index);
+  } else {
+    index.SetBinTypeSize(common::kUint32BinsTypeSize);
+    index.Resize((sizeof(uint32_t)) * n_index);
+  }
+}
+}  // namespace xgboost
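Note: ResizeIndex picks the narrowest storage width that can represent max_num_bins - 1, and only for dense data; sparse rows keep global bin indices (no per-feature offset subtraction), so they always need the full 32-bit range. A standalone restatement of that rule, mirroring the branch logic above (names illustrative):

#include <cstddef>
#include <cstdint>
#include <limits>

size_t BinIdxBytes(size_t max_num_bins, bool is_dense) {
  if (!is_dense) return sizeof(uint32_t);  // sparse rows store global bin ids
  if (max_num_bins - 1 <= std::numeric_limits<uint8_t>::max())  return sizeof(uint8_t);
  if (max_num_bins - 1 <= std::numeric_limits<uint16_t>::max()) return sizeof(uint16_t);
  return sizeof(uint32_t);
}
// e.g. BinIdxBytes(256, true) == 1, BinIdxBytes(257, true) == 2,
//      BinIdxBytes(256, false) == 4.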
+ */ +class GHistIndexMatrix { + public: + /*! \brief row pointer to rows by element position */ + std::vector row_ptr; + /*! \brief The index data */ + common::Index index; + /*! \brief hit count of each index */ + std::vector hit_count; + /*! \brief The corresponding cuts */ + common::HistogramCuts cut; + DMatrix* p_fmat; + size_t max_num_bins; + + GHistIndexMatrix(DMatrix* x, int32_t max_bin) { + this->Init(x, max_bin); + } + // Create a global histogram matrix, given cut + void Init(DMatrix* p_fmat, int max_num_bins); + + // specific method for sparse data as no possibility to reduce allocated memory + template + void SetIndexData(common::Span index_data_span, + size_t batch_threads, const SparsePage &batch, + size_t rbegin, size_t nbins, GetOffset get_offset) { + const xgboost::Entry *data_ptr = batch.data.HostVector().data(); + const std::vector &offset_vec = batch.offset.HostVector(); + const size_t batch_size = batch.Size(); + CHECK_LT(batch_size, offset_vec.size()); + BinIdxType* index_data = index_data_span.data(); + common::ParallelFor(omp_ulong(batch_size), batch_threads, [&](omp_ulong i) { + const int tid = omp_get_thread_num(); + size_t ibegin = row_ptr[rbegin + i]; + size_t iend = row_ptr[rbegin + i + 1]; + const size_t size = offset_vec[i + 1] - offset_vec[i]; + SparsePage::Inst inst = {data_ptr + offset_vec[i], size}; + CHECK_EQ(ibegin + inst.size(), iend); + for (bst_uint j = 0; j < inst.size(); ++j) { + uint32_t idx = cut.SearchBin(inst[j]); + index_data[ibegin + j] = get_offset(idx, j); + ++hit_count_tloc_[tid * nbins + idx]; + } + }); + } + + void ResizeIndex(const size_t n_index, + const bool isDense); + + inline void GetFeatureCounts(size_t* counts) const { + auto nfeature = cut.Ptrs().size() - 1; + for (unsigned fid = 0; fid < nfeature; ++fid) { + auto ibegin = cut.Ptrs()[fid]; + auto iend = cut.Ptrs()[fid + 1]; + for (auto i = ibegin; i < iend; ++i) { + counts[fid] += hit_count[i]; + } + } + } + inline bool IsDense() const { + return isDense_; + } + + private: + std::vector hit_count_tloc_; + bool isDense_; +}; +} // namespace xgboost +#endif // XGBOOST_DATA_GRADIENT_INDEX_H_ diff --git a/src/data/iterative_device_dmatrix.h b/src/data/iterative_device_dmatrix.h index bc73d7c85f4d..a9923552cc21 100644 --- a/src/data/iterative_device_dmatrix.h +++ b/src/data/iterative_device_dmatrix.h @@ -58,6 +58,10 @@ class IterativeDeviceDMatrix : public DMatrix { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } + BatchSet GetGradientIndex(const BatchParam&) override { + LOG(FATAL) << "Not implemented."; + return BatchSet(BatchIterator(nullptr)); + } BatchSet GetEllpackBatches(const BatchParam& param) override; diff --git a/src/data/proxy_dmatrix.h b/src/data/proxy_dmatrix.h index 8d4ae8777a19..cb5cc6e02219 100644 --- a/src/data/proxy_dmatrix.h +++ b/src/data/proxy_dmatrix.h @@ -97,6 +97,10 @@ class DMatrixProxy : public DMatrix { LOG(FATAL) << "Not implemented."; return BatchSet(BatchIterator(nullptr)); } + BatchSet GetGradientIndex(const BatchParam&) override { + LOG(FATAL) << "Not implemented."; + return BatchSet(BatchIterator(nullptr)); + } dmlc::any Adapter() const { return batch_; diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc index 85e38a52d70f..bcab52e48898 100644 --- a/src/data/simple_dmatrix.cc +++ b/src/data/simple_dmatrix.cc @@ -17,6 +17,7 @@ #include "../common/random.h" #include "../common/threading_utils.h" #include "adapter.h" +#include "gradient_index.h" namespace xgboost { namespace data { @@ -89,6 +90,20 @@ 
diff --git a/src/data/simple_dmatrix.cc b/src/data/simple_dmatrix.cc
index 85e38a52d70f..bcab52e48898 100644
--- a/src/data/simple_dmatrix.cc
+++ b/src/data/simple_dmatrix.cc
@@ -17,6 +17,7 @@
 #include "../common/random.h"
 #include "../common/threading_utils.h"
 #include "adapter.h"
+#include "gradient_index.h"
 
 namespace xgboost {
 namespace data {
@@ -89,6 +90,20 @@ BatchSet<EllpackPage> SimpleDMatrix::GetEllpackBatches(const BatchParam& param)
   return BatchSet<EllpackPage>(begin_iter);
 }
 
+BatchSet<GHistIndexMatrix> SimpleDMatrix::GetGradientIndex(const BatchParam& param) {
+  if (!(batch_param_ != BatchParam{})) {
+    CHECK(param != BatchParam{}) << "Batch parameter is not initialized.";
+  }
+  if (!gradient_index_ || (batch_param_ != param && param != BatchParam{})) {
+    CHECK_GE(param.max_bin, 2);
+    gradient_index_.reset(new GHistIndexMatrix(this, param.max_bin));
+    batch_param_ = param;
+  }
+  auto begin_iter = BatchIterator<GHistIndexMatrix>(
+      new SimpleBatchIteratorImpl<GHistIndexMatrix>(gradient_index_.get()));
+  return BatchSet<GHistIndexMatrix>(begin_iter);
+}
+
 template <typename AdapterT>
 SimpleDMatrix::SimpleDMatrix(AdapterT* adapter, float missing, int nthread) {
   std::vector<uint64_t> qids;
diff --git a/src/data/simple_dmatrix.h b/src/data/simple_dmatrix.h
index 9d2130b4195e..aa555b212af2 100644
--- a/src/data/simple_dmatrix.h
+++ b/src/data/simple_dmatrix.h
@@ -13,6 +13,7 @@
 #include <memory>
 #include <string>
 
+#include "gradient_index.h"
 
 namespace xgboost {
 namespace data {
@@ -43,12 +44,14 @@ class SimpleDMatrix : public DMatrix {
   BatchSet<CSCPage> GetColumnBatches() override;
   BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
   BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
+  BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) override;
 
   MetaInfo info_;
   SparsePage sparse_page_;  // Primary storage type
   std::unique_ptr<CSCPage> column_page_;
   std::unique_ptr<SortedCSCPage> sorted_column_page_;
   std::unique_ptr<EllpackPage> ellpack_page_;
+  std::unique_ptr<GHistIndexMatrix> gradient_index_;
   BatchParam batch_param_;
 
   bool EllpackExists() const override {
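Note: GetGradientIndex caches a single GHistIndexMatrix per SimpleDMatrix; it rebuilds only when nothing is cached yet, or when the caller passes a concrete (non-default) BatchParam that differs from the cached one, while a default BatchParam{} always reuses the cache. The invalidation rule in isolation (hypothetical stand-alone types, same logic as above):

struct Param {
  int gpu_id{-1};
  int max_bin{0};
  bool operator!=(Param const& that) const {
    return gpu_id != that.gpu_id || max_bin != that.max_bin;
  }
};

bool NeedRebuild(bool has_cache, Param cached, Param requested) {
  // Matches: !gradient_index_ || (batch_param_ != param && param != BatchParam{})
  return !has_cache || (cached != requested && requested != Param{});
}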
diff --git a/src/data/sparse_page_dmatrix.h b/src/data/sparse_page_dmatrix.h
index 00cbdbfbe82a..a4eaeed712ff 100644
--- a/src/data/sparse_page_dmatrix.h
+++ b/src/data/sparse_page_dmatrix.h
@@ -47,6 +47,10 @@ class SparsePageDMatrix : public DMatrix {
   BatchSet<CSCPage> GetColumnBatches() override;
   BatchSet<SortedCSCPage> GetSortedColumnBatches() override;
   BatchSet<EllpackPage> GetEllpackBatches(const BatchParam& param) override;
+  BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam&) override {
+    LOG(FATAL) << "Not implemented.";
+    return BatchSet<GHistIndexMatrix>(BatchIterator<GHistIndexMatrix>(nullptr));
+  }
 
   // source data pointers.
   std::unique_ptr<SparsePageSource> row_source_;
diff --git a/src/tree/updater_quantile_hist.cc b/src/tree/updater_quantile_hist.cc
index 58c0e7bbfe71..e430d95f5329 100644
--- a/src/tree/updater_quantile_hist.cc
+++ b/src/tree/updater_quantile_hist.cc
@@ -69,18 +69,22 @@ template <typename GradientSumT>
 void QuantileHistMaker::CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
                                           HostDeviceVector<GradientPair> *gpair,
                                           DMatrix *dmat,
+                                          GHistIndexMatrix const& gmat,
                                           const std::vector<RegTree *> &trees) {
   for (auto tree : trees) {
-    builder->Update(gmat_, column_matrix_, gpair, dmat, tree);
+    builder->Update(gmat, column_matrix_, gpair, dmat, tree);
   }
 }
 
 void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
                                DMatrix *dmat,
                                const std::vector<RegTree *> &trees) {
+  auto const &gmat =
+      *(dmat->GetBatches<GHistIndexMatrix>(
+              BatchParam{GenericParameter::kCpuId, param_.max_bin})
+            .begin());
   if (dmat != p_last_dmat_ || is_gmat_initialized_ == false) {
     updater_monitor_.Start("GmatInitialization");
-    gmat_.Init(dmat, static_cast<uint32_t>(param_.max_bin));
-    column_matrix_.Init(gmat_, param_.sparse_threshold);
+    column_matrix_.Init(gmat, param_.sparse_threshold);
     updater_monitor_.Stop("GmatInitialization");
     // A proper solution is puting cut matrix in DMatrix, see:
     // https://github.com/dmlc/xgboost/issues/5143
@@ -96,12 +100,12 @@ void QuantileHistMaker::Update(HostDeviceVector<GradientPair> *gpair,
     if (!float_builder_) {
       SetBuilder(n_trees, &float_builder_, dmat);
     }
-    CallBuilderUpdate(float_builder_, gpair, dmat, trees);
+    CallBuilderUpdate(float_builder_, gpair, dmat, gmat, trees);
   } else {
     if (!double_builder_) {
       SetBuilder(n_trees, &double_builder_, dmat);
     }
-    CallBuilderUpdate(double_builder_, gpair, dmat, trees);
+    CallBuilderUpdate(double_builder_, gpair, dmat, gmat, trees);
   }
 
   param_.learning_rate = lr;
@@ -678,7 +682,7 @@ void QuantileHistMaker::Builder<GradientSumT>::InitData(const GHistIndexMatrix&
     // We should check that the partitioning was done correctly
     // and each row of the dataset fell into exactly one of the categories
   }
-  MemStackAllocator<bool, 128> buff(this->nthread_);
+  common::MemStackAllocator<bool, 128> buff(this->nthread_);
   bool* p_buff = buff.Get();
   std::fill(p_buff, p_buff + this->nthread_, false);
diff --git a/src/tree/updater_quantile_hist.h b/src/tree/updater_quantile_hist.h
index 0ce106222881..6c55a5bb1af2 100644
--- a/src/tree/updater_quantile_hist.h
+++ b/src/tree/updater_quantile_hist.h
@@ -75,43 +75,9 @@ struct RandomReplace {
   }
 };
 
-/*!
- * \brief A C-style array with in-stack allocation. As long as the array is smaller than MaxStackSize, it will be allocated inside the stack. Otherwise, it will be heap-allocated.
- */
-template <typename T, size_t MaxStackSize>
-class MemStackAllocator {
- public:
-  explicit MemStackAllocator(size_t required_size): required_size_(required_size) {
-  }
-
-  T* Get() {
-    if (!ptr_) {
-      if (MaxStackSize >= required_size_) {
-        ptr_ = stack_mem_;
-      } else {
-        ptr_ = reinterpret_cast<T*>(malloc(required_size_ * sizeof(T)));
-        do_free_ = true;
-      }
-    }
-
-    return ptr_;
-  }
-
-  ~MemStackAllocator() {
-    if (do_free_) free(ptr_);
-  }
-
- private:
-  T* ptr_ = nullptr;
-  bool do_free_ = false;
-  size_t required_size_;
-  T stack_mem_[MaxStackSize];
-};
-
 namespace tree {
 
-using xgboost::common::GHistIndexMatrix;
+using xgboost::GHistIndexMatrix;
 using xgboost::common::GHistIndexRow;
 using xgboost::common::HistCollection;
 using xgboost::common::RowSetCollection;
@@ -243,8 +209,6 @@ class QuantileHistMaker: public TreeUpdater {
   CPUHistMakerTrainParam hist_maker_param_;
   // training parameter
   TrainParam param_;
-  // quantized data matrix
-  GHistIndexMatrix gmat_;
   // column accessor
   ColumnMatrix column_matrix_;
   DMatrix const* p_last_dmat_ {nullptr};
@@ -466,6 +430,7 @@ class QuantileHistMaker: public TreeUpdater {
   void CallBuilderUpdate(const std::unique_ptr<Builder<GradientSumT>>& builder,
                          HostDeviceVector<GradientPair> *gpair,
                          DMatrix *dmat,
+                         GHistIndexMatrix const& gmat,
                          const std::vector<RegTree *> &trees);
 
  protected:
diff --git a/tests/cpp/common/test_column_matrix.cc b/tests/cpp/common/test_column_matrix.cc
index 2cb88855420c..c3eaf100865b 100644
--- a/tests/cpp/common/test_column_matrix.cc
+++ b/tests/cpp/common/test_column_matrix.cc
@@ -14,8 +14,7 @@ TEST(DenseColumn, Test) {
                                   static_cast<size_t>(std::numeric_limits<uint16_t>::max()) + 2};
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 10, 0.0).GenerateDMatrix();
-    GHistIndexMatrix gmat;
-    gmat.Init(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.2);
 
@@ -62,8 +61,7 @@ TEST(SparseColumn, Test) {
                                   static_cast<size_t>(std::numeric_limits<uint16_t>::max()) + 2};
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 1, 0.85).GenerateDMatrix();
-    GHistIndexMatrix gmat;
-    gmat.Init(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.5);
     switch (column_matrix.GetTypeSize()) {
@@ -103,8 +101,7 @@ TEST(DenseColumnWithMissing, Test) {
                                   static_cast<size_t>(std::numeric_limits<uint16_t>::max()) + 2 };
   for (size_t max_num_bin : max_num_bins) {
     auto dmat = RandomDataGenerator(100, 1, 0.5).GenerateDMatrix();
-    GHistIndexMatrix gmat;
-    gmat.Init(dmat.get(), max_num_bin);
+    GHistIndexMatrix gmat(dmat.get(), max_num_bin);
     ColumnMatrix column_matrix;
     column_matrix.Init(gmat, 0.2);
     switch (column_matrix.GetTypeSize()) {
@@ -135,8 +132,7 @@ void TestGHistIndexMatrixCreation(size_t nthreads) {
   /* This should create multiple sparse pages */
   std::unique_ptr<DMatrix> dmat{ CreateSparsePageDMatrix(kEntries, kPageSize, filename) };
   omp_set_num_threads(nthreads);
-  GHistIndexMatrix gmat;
-  gmat.Init(dmat.get(), 256);
+  GHistIndexMatrix gmat(dmat.get(), 256);
 }
 
 TEST(HistIndexCreationWithExternalMemory, Test) {
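Note: the tests now use the new converting constructor instead of default-construct-plus-Init. Since the constructor simply forwards to Init, the two forms are equivalent, but the one-liner means the matrix can never be observed uninitialized. Sketch (include path illustrative):

#include "../src/data/gradient_index.h"

void BuildIndex(xgboost::DMatrix* dmat) {
  // Before this patch: GHistIndexMatrix gmat; gmat.Init(dmat, 256);
  xgboost::GHistIndexMatrix gmat(dmat, /*max_bin=*/256);
  // gmat.cut, gmat.index and gmat.hit_count are now populated.
}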
diff --git a/tests/cpp/common/test_hist_util.cc b/tests/cpp/common/test_hist_util.cc
index 31bbb38f9c52..5a467fc316a4 100644
--- a/tests/cpp/common/test_hist_util.cc
+++ b/tests/cpp/common/test_hist_util.cc
@@ -4,6 +4,7 @@
 #include <gtest/gtest.h>
 
 #include "../../../src/common/hist_util.h"
+#include "../../../src/data/gradient_index.h"
 #include "../helpers.h"
 #include "test_hist_util.h"
 
@@ -255,8 +256,7 @@ TEST(HistUtil, IndexBinBound) {
   for (auto max_bin : bin_sizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
-    common::GHistIndexMatrix hmat;
-    hmat.Init(p_fmat.get(), max_bin);
+    GHistIndexMatrix hmat(p_fmat.get(), max_bin);
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     EXPECT_EQ(expected_bin_type_sizes[bin_id++], hmat.index.GetBinTypeSize());
   }
@@ -264,7 +264,7 @@ TEST(HistUtil, IndexBinBound) {
 
 template <typename T>
 void CheckIndexData(T* data_ptr, uint32_t* offsets,
-                    const common::GHistIndexMatrix& hmat, size_t n_cols) {
+                    const GHistIndexMatrix& hmat, size_t n_cols) {
   for (size_t i = 0; i < hmat.index.Size(); ++i) {
     EXPECT_EQ(data_ptr[i] + offsets[i % n_cols], hmat.index[i]);
   }
@@ -279,8 +279,7 @@ TEST(HistUtil, IndexBinData) {
 
   for (auto max_bin : kBinSizes) {
     auto p_fmat = RandomDataGenerator(kRows, kCols, 0).GenerateDMatrix();
-    common::GHistIndexMatrix hmat;
-    hmat.Init(p_fmat.get(), max_bin);
+    GHistIndexMatrix hmat(p_fmat.get(), max_bin);
     uint32_t* offsets = hmat.index.Offset();
     EXPECT_EQ(hmat.index.Size(), kRows*kCols);
     switch (max_bin) {
diff --git a/tests/cpp/tree/test_quantile_hist.cc b/tests/cpp/tree/test_quantile_hist.cc
index 79ebfc0c1e6a..33772025735f 100644
--- a/tests/cpp/tree/test_quantile_hist.cc
+++ b/tests/cpp/tree/test_quantile_hist.cc
@@ -344,8 +344,7 @@ class QuantileHistMock : public QuantileHistMaker {
     auto dmat = RandomDataGenerator(kNRows, kNCols, 0).Seed(3).GenerateDMatrix();
     // dense, no missing values
 
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat.get(), kMaxBins);
 
     RealImpl::InitData(gmat, *dmat, tree, &row_gpairs);
     this->hist_.AddHistRow(0);
@@ -434,8 +433,7 @@ class QuantileHistMock : public QuantileHistMaker {
     // kNRows samples with kNCols features
     auto dmat = RandomDataGenerator(kNRows, kNCols, sparsity).Seed(3).GenerateDMatrix();
 
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat.get(), kMaxBins);
     ColumnMatrix cm;
 
     // treat everything as dense, as this is what we intend to test here
@@ -546,8 +544,7 @@ class QuantileHistMock : public QuantileHistMaker {
 
   void TestInitData() {
     size_t constexpr kMaxBins = 4;
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
 
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -564,8 +561,7 @@ class QuantileHistMock : public QuantileHistMaker {
 
   void TestInitDataSampling() {
     size_t constexpr kMaxBins = 4;
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
 
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -582,8 +578,7 @@ class QuantileHistMock : public QuantileHistMaker {
 
   void TestAddHistRows() {
     size_t constexpr kMaxBins = 4;
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
 
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -599,8 +594,7 @@ class QuantileHistMock : public QuantileHistMaker {
 
   void TestSyncHistograms() {
     size_t constexpr kMaxBins = 4;
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
 
     RegTree tree = RegTree();
     tree.param.UpdateAllowUnknown(cfg_);
@@ -620,8 +614,7 @@ class QuantileHistMock : public QuantileHistMaker {
     tree.param.UpdateAllowUnknown(cfg_);
 
     size_t constexpr kMaxBins = 4;
-    common::GHistIndexMatrix gmat;
-    gmat.Init(dmat_.get(), kMaxBins);
+    GHistIndexMatrix gmat(dmat_.get(), kMaxBins);
     if (double_builder_) {
       double_builder_->TestBuildHist(0, gmat, *dmat_, tree);
} else {