Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dispatcher for template parameters of BuildHist Kernels #8259

Merged
merged 18 commits into from Oct 6, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
178 changes: 111 additions & 67 deletions src/common/hist_util.cc
Expand Up @@ -139,10 +139,72 @@ struct Prefetch {

constexpr size_t Prefetch::kNoPrefetchSize;

template <bool do_prefetch, typename BinIdxType, bool first_page, bool any_missing = true>
// Run-time counterparts of the compile-time kernel options. An instance is
// passed to GHistBuildingManager::DispatchAndExecute, which compares each field
// with the manager's template parameters. Field order matters: BuildHist
// brace-initialises this struct positionally.
struct RuntimeFlags {
// Compared with kFirstPage; BuildHist passes `gmat.base_rowid == 0`.
const bool first_page;
// Compared with kReadByColumn; selects the column-wise kernel when true.
const bool read_by_column;
// Compared with sizeof(BinIdxType); BuildHist passes gmat.index.GetBinTypeSize().
const BinTypeSize bin_type_size;
};

/* Compile-time description of one BuildHist kernel configuration, together
 * with a dispatcher that picks the configuration matching a set of run-time
 * flags.
 *
 * Template parameters:
 *   _any_missing     exposed as kAnyMissing; never altered by the dispatcher.
 *   _first_page      exposed as kFirstPage.
 *   _read_by_column  exposed as kReadByColumn.
 *   _BinIdxType      exposed as BinIdxType; matched against the run-time bin
 *                    size through sizeof.
 */
template <bool _any_missing,
          bool _first_page = false,
          bool _read_by_column = false,
          typename _BinIdxType = uint8_t>
class GHistBuildingManager {
 public:
  constexpr static bool kAnyMissing = _any_missing;
  constexpr static bool kFirstPage = _first_page;
  constexpr static bool kReadByColumn = _read_by_column;
  using BinIdxType = _BinIdxType;

 private:
  // Each helper below names the manager obtained by replacing exactly one
  // template parameter and keeping the other three as they are.
  template <bool new_first_page>
  struct set_first_page {
    using type = GHistBuildingManager<kAnyMissing, new_first_page, kReadByColumn, BinIdxType>;
  };

  template <bool new_read_by_column>
  struct set_read_by_column {
    using type = GHistBuildingManager<kAnyMissing, kFirstPage, new_read_by_column, BinIdxType>;
  };

  template <typename new_bin_idx_type>
  struct set_bin_idx_type {
    using type = GHistBuildingManager<kAnyMissing, kFirstPage, kReadByColumn, new_bin_idx_type>;
  };

  // The manager type itself; an instance of it is what `fn` finally receives.
  using type = GHistBuildingManager<kAnyMissing, kFirstPage, kReadByColumn, BinIdxType>;

 public:
  /* Entry point to the dispatcher.
   *
   * Compares the run-time flags with this manager's compile-time parameters.
   * On the first mismatch found (first_page, then read_by_column, then bin
   * index size) the call is forwarded to the manager that differs in that one
   * parameter; that manager repeats the check. Once every flag matches,
   * `fn` is invoked with a default-constructed instance of the matching
   * manager type, so the callee can recover the type via decltype.
   *
   * flags: run-time values to be matched.
   * fn:    generic callable accepting any GHistBuildingManager instance.
   */
  template <typename Fn>
  static void DispatchAndExecute(const RuntimeFlags& flags, Fn&& fn) {
    if (flags.first_page != kFirstPage) {
      // Flip the current value instead of hard-coding `true`: otherwise a
      // manager instantiated with kFirstPage == true and a run-time flag of
      // false would forward to itself without ever terminating.
      set_first_page<!kFirstPage>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
    } else if (flags.read_by_column != kReadByColumn) {
      // Same reasoning as for first_page above.
      set_read_by_column<!kReadByColumn>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
    } else if (flags.bin_type_size != sizeof(BinIdxType)) {
      // DispatchBinType calls the lambda with a value whose type is the bin
      // index type for the requested size; only that type is used here.
      DispatchBinType(flags.bin_type_size, [&](auto t) {
        using NewBinIdxType = decltype(t);
        set_bin_idx_type<NewBinIdxType>::type::DispatchAndExecute(flags, std::forward<Fn>(fn));
      });
    } else {
      fn(type());
    }
  }
};

template <bool do_prefetch, class BuildingManager>
void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
constexpr bool kFirstPage = BuildingManager::kFirstPage;
using BinIdxType = typename BuildingManager::BinIdxType;

const size_t size = row_indices.Size();
const size_t *rid = row_indices.begin;
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
Expand All @@ -152,10 +214,10 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
auto base_rowid = gmat.base_rowid;
const uint32_t *offsets = gmat.index.Offset();
auto get_row_ptr = [&](size_t ridx) {
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
};
auto get_rid = [&](size_t ridx) {
return first_page ? ridx : (ridx - base_rowid);
return kFirstPage ? ridx : (ridx - base_rowid);
};

const size_t n_features =
Expand All @@ -168,20 +230,20 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,

for (size_t i = 0; i < size; ++i) {
const size_t icol_start =
any_missing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
kAnyMissing ? get_row_ptr(rid[i]) : get_rid(rid[i]) * n_features;
const size_t icol_end =
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;

const size_t row_size = icol_end - icol_start;
const size_t idx_gh = two * rid[i];

if (do_prefetch) {
const size_t icol_start_prefetch =
any_missing
kAnyMissing
? get_row_ptr(rid[i + Prefetch::kPrefetchOffset])
: get_rid(rid[i + Prefetch::kPrefetchOffset]) * n_features;
const size_t icol_end_prefetch =
any_missing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
kAnyMissing ? get_row_ptr(rid[i + Prefetch::kPrefetchOffset] + 1)
: icol_start_prefetch + n_features;

PREFETCH_READ_T0(pgh + two * rid[i + Prefetch::kPrefetchOffset]);
Expand All @@ -196,18 +258,21 @@ void RowsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
const float pgh_t[] = {pgh[idx_gh], pgh[idx_gh + 1]};
for (size_t j = 0; j < row_size; ++j) {
const uint32_t idx_bin = two * (static_cast<uint32_t>(gr_index_local[j]) +
(any_missing ? 0 : offsets[j]));
(kAnyMissing ? 0 : offsets[j]));
auto hist_local = hist_data + idx_bin;
*(hist_local) += pgh_t[0];
*(hist_local + 1) += pgh_t[1];
}
}
}

template <typename BinIdxType, bool first_page, bool any_missing>
template <class BuildingManager>
void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist) {
constexpr bool kAnyMissing = BuildingManager::kAnyMissing;
constexpr bool kFirstPage = BuildingManager::kFirstPage;
using BinIdxType = typename BuildingManager::BinIdxType;
const size_t size = row_indices.Size();
const size_t *rid = row_indices.begin;
auto const *pgh = reinterpret_cast<const float *>(gpair.data());
Expand All @@ -217,10 +282,10 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
auto base_rowid = gmat.base_rowid;
const uint32_t *offsets = gmat.index.Offset();
auto get_row_ptr = [&](size_t ridx) {
return first_page ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
return kFirstPage ? row_ptr[ridx] : row_ptr[ridx - base_rowid];
};
auto get_rid = [&](size_t ridx) {
return first_page ? ridx : (ridx - base_rowid);
return kFirstPage ? ridx : (ridx - base_rowid);
};

const size_t n_features = gmat.cut.Ptrs().size() - 1;
Expand All @@ -231,13 +296,13 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
// So we need to multiply each row-index/bin-index by 2
// to work with gradient pairs as a single row FP array
for (size_t cid = 0; cid < n_columns; ++cid) {
const uint32_t offset = any_missing ? 0 : offsets[cid];
const uint32_t offset = kAnyMissing ? 0 : offsets[cid];
for (size_t i = 0; i < size; ++i) {
const size_t row_id = rid[i];
const size_t icol_start =
any_missing ? get_row_ptr(row_id) : get_rid(row_id) * n_features;
kAnyMissing ? get_row_ptr(row_id) : get_rid(row_id) * n_features;
const size_t icol_end =
any_missing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;
kAnyMissing ? get_row_ptr(rid[i] + 1) : icol_start + n_features;

if (cid < icol_end - icol_start) {
const BinIdxType *gr_index_local = gradient_index + icol_start;
Expand All @@ -254,59 +319,32 @@ void ColsWiseBuildHistKernel(const std::vector<GradientPair> &gpair,
}
}

template <bool do_prefetch, typename BinIdxType, bool first_page,
bool any_missing>
void BuildHistKernel(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist, bool read_by_column) {
if (read_by_column) {
ColsWiseBuildHistKernel<BinIdxType, first_page, any_missing>
(gpair, row_indices, gmat, hist);
} else {
RowsWiseBuildHistKernel<do_prefetch, BinIdxType, first_page, any_missing>
(gpair, row_indices, gmat, hist);
}
}

template <bool do_prefetch, bool any_missing>
template <class BuildingManager>
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist, bool read_by_column) {
auto first_page = gmat.base_rowid == 0;
DispatchBinType(gmat.index.GetBinTypeSize(), [&](auto t) {
using BinIdxType = decltype(t);
if (first_page) {
BuildHistKernel<do_prefetch, BinIdxType, true, any_missing>
(gpair, row_indices, gmat, hist, read_by_column);
GHistRow hist) {
if (BuildingManager::kReadByColumn) {
ColsWiseBuildHistKernel<BuildingManager>(gpair, row_indices, gmat, hist);
} else {
const size_t nrows = row_indices.Size();
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
// if need to work with all rows from bin-matrix (e.g. root node)
const bool contiguousBlock =
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);

if (contiguousBlock) {
// contiguous memory access, built-in HW prefetching is enough
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, row_indices, gmat, hist);
} else {
BuildHistKernel<do_prefetch, BinIdxType, false, any_missing>
(gpair, row_indices, gmat, hist, read_by_column);
const RowSetCollection::Elem span1(row_indices.begin,
row_indices.end - no_prefetch_size);
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
row_indices.end);

RowsWiseBuildHistKernel<true, BuildingManager>(gpair, span1, gmat, hist);
// no prefetching to avoid loading extra memory
RowsWiseBuildHistKernel<false, BuildingManager>(gpair, span2, gmat, hist);
}
});
}

template <bool any_missing>
void BuildHistDispatch(const std::vector<GradientPair> &gpair,
const RowSetCollection::Elem row_indices, const GHistIndexMatrix &gmat,
GHistRow hist, bool read_by_column) {
const size_t nrows = row_indices.Size();
const size_t no_prefetch_size = Prefetch::NoPrefetchSize(nrows);
// if need to work with all rows from bin-matrix (e.g. root node)
const bool contiguousBlock =
(row_indices.begin[nrows - 1] - row_indices.begin[0]) == (nrows - 1);

if (contiguousBlock) {
// contiguous memory access, built-in HW prefetching is enough
BuildHistDispatch<false, any_missing>(gpair, row_indices, gmat, hist, read_by_column);
} else {
const RowSetCollection::Elem span1(row_indices.begin,
row_indices.end - no_prefetch_size);
const RowSetCollection::Elem span2(row_indices.end - no_prefetch_size,
row_indices.end);

BuildHistDispatch<true, any_missing>(gpair, span1, gmat, hist, read_by_column);
// no prefetching to avoid loading extra memory
BuildHistDispatch<false, any_missing>(gpair, span2, gmat, hist, read_by_column);
}
}

Expand All @@ -320,10 +358,16 @@ void GHistBuilder::BuildHist(const std::vector<GradientPair> &gpair,
*/
constexpr double kAdhocL2Size = 1024 * 1024 * 0.8;
const bool hist_fit_to_l2 = kAdhocL2Size > 2*sizeof(float)*gmat.cut.Ptrs().back();
const bool read_by_column = !hist_fit_to_l2 && !any_missing;

BuildHistDispatch<any_missing>(gpair, row_indices, gmat, hist, read_by_column ||
force_read_by_column);
bool first_page = gmat.base_rowid == 0;
bool read_by_column = !hist_fit_to_l2 && !any_missing;
auto bin_type_size = gmat.index.GetBinTypeSize();

GHistBuildingManager<any_missing>::DispatchAndExecute(
{first_page, read_by_column || force_read_by_column, bin_type_size},
[&](auto t) {
using BuildingManager = decltype(t);
BuildHistDispatch<BuildingManager>(gpair, row_indices, gmat, hist);
});
}

template void GHistBuilder::BuildHist<true>(const std::vector<GradientPair> &gpair,
Expand Down