From c24e9d712cd509ee9b5bc61584acdac85b564968 Mon Sep 17 00:00:00 2001 From: Dmitry Razdoburdin Date: Thu, 6 Oct 2022 13:02:29 +0200 Subject: [PATCH] Dispatcher for template parameters of BuildHist Kernels (#8259) * Introducing Column Wise Hist Building * linting * more linting * bug fixing * Removing column sampling optimization for a while to simplify the review process. * linting * Removing unnecessary changes * Use DispatchBinType in hist_util.cc * Adding force_read_by_column flag to buildhist. Adding tests for column wise buildhist. * Introducing new dispatcher for compile time flags in hist building * fixing bug with the usage of DispatchBinType * Fixing building * Merging with master branch Co-authored-by: dmitry.razdoburdin Co-authored-by: Hyunsu Cho --- src/common/hist_util.cc | 178 +++++++++++++++++++++++++--------------- 1 file changed, 111 insertions(+), 67 deletions(-) diff --git a/src/common/hist_util.cc b/src/common/hist_util.cc index ab67bc92fee7..8df650d0219d 100644 --- a/src/common/hist_util.cc +++ b/src/common/hist_util.cc @@ -134,10 +134,72 @@ struct Prefetch { constexpr size_t Prefetch::kNoPrefetchSize; -template +struct RuntimeFlags { + const bool first_page; + const bool read_by_column; + const BinTypeSize bin_type_size; +}; + +template +class GHistBuildingManager { + public: + constexpr static bool kAnyMissing = _any_missing; + constexpr static bool kFirstPage = _first_page; + constexpr static bool kReadByColumn = _read_by_column; + using BinIdxType = _BinIdxType; + + private: + template + struct set_first_page { + using type = GHistBuildingManager; + }; + + template + struct set_read_by_column { + using type = GHistBuildingManager; + }; + + template + struct set_bin_idx_type { + using type = GHistBuildingManager; + }; + + using type = GHistBuildingManager; + + public: + /* Entry point to dispatcher + * This function check matching run time flags to compile time flags. 
+ * In case of difference, it creates a Manager with different template parameters + * and forward the call there. + */ + template + static void DispatchAndExecute(const RuntimeFlags& flags, Fn&& fn) { + if (flags.first_page != kFirstPage) { + set_first_page::type::DispatchAndExecute(flags, std::forward(fn)); + } else if (flags.read_by_column != kReadByColumn) { + set_read_by_column::type::DispatchAndExecute(flags, std::forward(fn)); + } else if (flags.bin_type_size != sizeof(BinIdxType)) { + DispatchBinType(flags.bin_type_size, [&](auto t) { + using NewBinIdxType = decltype(t); + set_bin_idx_type::type::DispatchAndExecute(flags, std::forward(fn)); + }); + } else { + fn(type()); + } + } +}; + +template void RowsWiseBuildHistKernel(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, GHistRow hist) { + constexpr bool kAnyMissing = BuildingManager::kAnyMissing; + constexpr bool kFirstPage = BuildingManager::kFirstPage; + using BinIdxType = typename BuildingManager::BinIdxType; + const size_t size = row_indices.Size(); const size_t *rid = row_indices.begin; auto const *pgh = reinterpret_cast(gpair.data()); @@ -147,10 +209,10 @@ void RowsWiseBuildHistKernel(const std::vector &gpair, auto base_rowid = gmat.base_rowid; const uint32_t *offsets = gmat.index.Offset(); auto get_row_ptr = [&](size_t ridx) { - return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; + return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; }; auto get_rid = [&](size_t ridx) { - return first_page ? ridx : (ridx - base_rowid); + return kFirstPage ? ridx : (ridx - base_rowid); }; const size_t n_features = @@ -163,20 +225,20 @@ void RowsWiseBuildHistKernel(const std::vector &gpair, for (size_t i = 0; i < size; ++i) { const size_t icol_start = - any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features; + kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features; const size_t icol_end = - any_missing ? 
get_row_ptr(rid[i] + 1) : icol_start + n_features; + kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features; const size_t row_size = icol_end - icol_start; const size_t idx_gh = two * rid[i]; if (do_prefetch) { const size_t icol_start_prefetch = - any_missing + kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset]) : get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features; const size_t icol_end_prefetch = - any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1) + kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1) : icol_start_prefetch + n_features; PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]); @@ -191,7 +253,7 @@ void RowsWiseBuildHistKernel(const std::vector &gpair, const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]}; for (size_t j = 0; j < row_size; ++j) { const uint32_t idx_bin = two * (static_cast(gr_index_local[j]) + - (any_missing ? 0 : offsets[j])); + (kAnyMissing ? 0 : offsets[j])); auto hist_local = hist_data + idx_bin; *(hist_local) += pgh_t[0]; *(hist_local + 1) += pgh_t[1]; @@ -199,10 +261,13 @@ void RowsWiseBuildHistKernel(const std::vector &gpair, } } -template +template void ColsWiseBuildHistKernel(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, GHistRow hist) { + constexpr bool kAnyMissing = BuildingManager::kAnyMissing; + constexpr bool kFirstPage = BuildingManager::kFirstPage; + using BinIdxType = typename BuildingManager::BinIdxType; const size_t size = row_indices.Size(); const size_t *rid = row_indices.begin; auto const *pgh = reinterpret_cast(gpair.data()); @@ -212,10 +277,10 @@ void ColsWiseBuildHistKernel(const std::vector &gpair, auto base_rowid = gmat.base_rowid; const uint32_t *offsets = gmat.index.Offset(); auto get_row_ptr = [&](size_t ridx) { - return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid]; + return kFirstPage ? 
row_ptr[ridx] : row_ptr[ridx - base_rowid]; }; auto get_rid = [&](size_t ridx) { - return first_page ? ridx : (ridx - base_rowid); + return kFirstPage ? ridx : (ridx - base_rowid); }; const size_t n_features = gmat.cut.Ptrs().size() - 1; @@ -226,13 +291,13 @@ void ColsWiseBuildHistKernel(const std::vector &gpair, // So we need to multiply each row-index/bin-index by 2 // to work with gradient pairs as a singe row FP array for (size_t cid = 0; cid < n_columns; ++cid) { - const uint32_t offset = any_missing ? 0 : offsets[cid]; + const uint32_t offset = kAnyMissing ? 0 : offsets[cid]; for (size_t i = 0; i < size; ++i) { const size_t row_id = rid[i]; const size_t icol_start = - any_missing ? get_row_ptr(row_id) : get_rid(row_id) * n_features; + kAnyMissing ? get_row_ptr(row_id) : get_rid(row_id) * n_features; const size_t icol_end = - any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features; + kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features; if (cid < icol_end - icol_start) { const BinIdxType *gr_index_local = gradient_index + icol_start; @@ -249,59 +314,32 @@ void ColsWiseBuildHistKernel(const std::vector &gpair, } } -template -void BuildHistKernel(const std::vector &gpair, - const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist, bool read_by_column) { - if (read_by_column) { - ColsWiseBuildHistKernel - (gpair, row_indices, gmat, hist); - } else { - RowsWiseBuildHistKernel - (gpair, row_indices, gmat, hist); - } -} - -template +template void BuildHistDispatch(const std::vector &gpair, const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist, bool read_by_column) { - auto first_page = gmat.base_rowid == 0; - DispatchBinType(gmat.index.GetBinTypeSize(), [&](auto t) { - using BinIdxType = decltype(t); - if (first_page) { - BuildHistKernel - (gpair, row_indices, gmat, hist, read_by_column); + GHistRow hist) { + if (BuildingManager::kReadByColumn) { + ColsWiseBuildHistKernel(gpair, 
row_indices, gmat, hist); + } else { + const size_t nrows = row_indices.Size(); + const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); + // if need to work with all rows from bin-matrix (e.g. root node) + const bool contiguousBlock = + (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); + + if (contiguousBlock) { + // contiguous memory access, built-in HW prefetching is enough + RowsWiseBuildHistKernel(gpair, row_indices, gmat, hist); } else { - BuildHistKernel - (gpair, row_indices, gmat, hist, read_by_column); + const RowSetCollection::Elem span1(row_indices.begin, + row_indices.end - no_prefetch_size); + const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, + row_indices.end); + + RowsWiseBuildHistKernel(gpair, span1, gmat, hist); + // no prefetching to avoid loading extra memory + RowsWiseBuildHistKernel(gpair, span2, gmat, hist); } - }); -} - -template -void BuildHistDispatch(const std::vector &gpair, - const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat, - GHistRow hist, bool read_by_column) { - const size_t nrows = row_indices.Size(); - const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows); - // if need to work with all rows from bin-matrix (e.g. 
root node) - const bool contiguousBlock = - (row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1); - - if (contiguousBlock) { - // contiguous memory access, built-in HW prefetching is enough - BuildHistDispatch(gpair, row_indices, gmat, hist, read_by_column); - } else { - const RowSetCollection::Elem span1(row_indices.begin, - row_indices.end - no_prefetch_size); - const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size, - row_indices.end); - - BuildHistDispatch(gpair, span1, gmat, hist, read_by_column); - // no prefetching to avoid loading extra memory - BuildHistDispatch(gpair, span2, gmat, hist, read_by_column); } } @@ -315,10 +353,16 @@ void GHistBuilder::BuildHist(const std::vector &gpair, */ constexpr double kAdhocL2Size = 1024 * 1024 * 0.8; const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back(); - const bool read_by_column = !hist_fit_to_l2 && !any_missing; - - BuildHistDispatch(gpair, row_indices, gmat, hist, read_by_column || - force_read_by_column); + bool first_page = gmat.base_rowid == 0; + bool read_by_column = !hist_fit_to_l2 && !any_missing; + auto bin_type_size = gmat.index.GetBinTypeSize(); + + GHistBuildingManager::DispatchAndExecute( + {first_page, read_by_column || force_read_by_column, bin_type_size}, + [&](auto t) { + using BuildingManager = decltype(t); + BuildHistDispatch(gpair, row_indices, gmat, hist); + }); } template void GHistBuilder::BuildHist(const std::vector &gpair,