Implement iterative DMatrix for CPU. #8116

Merged 1 commit on Jul 26, 2022
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -43,6 +43,7 @@
#include "../src/data/gradient_index_format.cc"
#include "../src/data/sparse_page_dmatrix.cc"
#include "../src/data/proxy_dmatrix.cc"
#include "../src/data/iterative_dmatrix.cc"

// prediction
#include "../src/predictor/predictor.cc"
21 changes: 13 additions & 8 deletions include/xgboost/data.h
@@ -559,6 +559,7 @@ class DMatrix {
*
* \param iter External data iterator
* \param proxy A handle to ProxyDMatrix
* \param ref Reference Quantile DMatrix.
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
@@ -567,13 +568,11 @@
*
* \return A created quantile-based DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread,
int max_bin);
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin);

/**
* \brief Create an external memory DMatrix with callbacks.
@@ -613,6 +612,7 @@
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;

virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};

@@ -621,11 +621,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}

template<>
template <>
inline bool DMatrix::PageExists<EllpackPage>() const {
return this->EllpackExists();
}

template <>
inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
return this->GHistIndexExists();
}

template<>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
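An aside on the PageExists<T>() specializations above: the public template forwards to one virtual per page type, so each DMatrix subclass can state which pages it is able to materialize. A minimal self-contained sketch of the same dispatch, with hypothetical stand-in classes rather than XGBoost's real ones:

#include <iostream>

struct EllpackPage {};       // stand-in for the GPU page type
struct GHistIndexMatrix {};  // stand-in for the CPU gradient index

class Matrix {
 public:
  virtual ~Matrix() = default;
  // Generic front end: callers ask for a page type, not a backend.
  template <typename Page>
  bool PageExists() const;

 protected:
  // One virtual per page type; subclasses answer for what they hold.
  virtual bool EllpackExists() const = 0;
  virtual bool GHistIndexExists() const = 0;
};

template <>
inline bool Matrix::PageExists<EllpackPage>() const {
  return this->EllpackExists();
}

template <>
inline bool Matrix::PageExists<GHistIndexMatrix>() const {
  return this->GHistIndexExists();
}

// A CPU-only matrix: it builds a gradient index but never an ELLPACK page.
class CpuMatrix : public Matrix {
 protected:
  bool EllpackExists() const override { return false; }
  bool GHistIndexExists() const override { return true; }
};

int main() {
  CpuMatrix m;
  std::cout << m.PageExists<GHistIndexMatrix>() << "\n";  // 1
  std::cout << m.PageExists<EllpackPage>() << "\n";       // 0
}

This is the shape GetCutsFromRef() relies on below: it can query a reference DMatrix for whichever page type was actually built.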
11 changes: 6 additions & 5 deletions src/c_api/c_api.cc
@@ -275,13 +275,14 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin, DMatrixHandle *out) {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread, int max_bin,
DMatrixHandle *out) {
API_BEGIN();
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, nthread, max_bin)};
xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)};
API_END();
}

27 changes: 13 additions & 14 deletions src/data/data.cc
@@ -931,15 +931,13 @@ DMatrix* DMatrix::Load(const std::string& uri,
}
return dmat;
}
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread,
int max_bin) {
return new data::IterativeDMatrix(iter, proxy, reset, next, missing,
nthread, max_bin);

template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin) {
return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin);
}

template <typename DataIterHandle, typename DMatrixHandle,
@@ -953,11 +951,12 @@ DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
cache);
}

template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin);
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing,
int nthread, int max_bin);

template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
214 changes: 214 additions & 0 deletions src/data/iterative_dmatrix.cc
@@ -0,0 +1,214 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "iterative_dmatrix.h"

#include <rabit/rabit.h>

#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"

namespace xgboost {
namespace data {

void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
common::HistogramCuts* p_cuts) {
CHECK(ref_);
CHECK(p_cuts);
auto csr = [&]() {
for (auto const& page : ref_->GetBatches<GHistIndexMatrix>(p)) {
*p_cuts = page.cut;
break;
}
};
auto ellpack = [&]() {
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
GetCutsFromEllpack(page, p_cuts);
break;
}
};

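// Prefer whichever page type the reference has already materialized; if
// neither exists yet, choose by device: CSR (GHistIndex) on CPU, ELLPACK on GPU.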
if (ref_->PageExists<GHistIndexMatrix>()) {
csr();
} else if (ref_->PageExists<EllpackPage>()) {
ellpack();
} else {
if (p.gpu_id == Context::kCpuId) {
csr();
} else {
ellpack();
}
}
CHECK_EQ(ref_->Info().num_col_, n_features)
<< "Invalid ref DMatrix, different number of features.";
}

void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
std::shared_ptr<DMatrix> ref) {
DMatrixProxy* proxy = MakeProxy(proxy_);
CHECK(proxy);

// The external iterator
auto iter =
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
common::HistogramCuts cuts;

auto num_rows = [&]() {
return HostAdapterDispatch(proxy, [](auto const& value) { return value.Size(); });
};
auto num_cols = [&]() {
return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); });
};

std::vector<size_t> column_sizes;
auto const is_valid = data::IsValidFunctor{missing};
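// Count the valid (non-missing) entries of the current batch. Each thread
// accumulates into its own row of column counts, which are then reduced into
// column_sizes, avoiding atomics.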
auto nnz_cnt = [&]() {
return HostAdapterDispatch(proxy, [&](auto const& value) {
size_t n_threads = ctx_.Threads();
size_t n_features = column_sizes.size();
linalg::Tensor<size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
auto view = column_sizes_tloc.HostView();
common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
auto const& line = value.GetLine(i);
for (size_t j = 0; j < line.Size(); ++j) {
data::COOTuple const& elem = line.GetElement(j);
if (is_valid(elem)) {
view(omp_get_thread_num(), elem.column_idx)++;
}
}
});
auto ptr = column_sizes_tloc.Data()->HostPointer();
auto result = std::accumulate(ptr, ptr + column_sizes_tloc.Size(), static_cast<size_t>(0));
for (size_t tidx = 0; tidx < n_threads; ++tidx) {
for (size_t fidx = 0; fidx < n_features; ++fidx) {
column_sizes[fidx] += view(tidx, fidx);
}
}
return result;
});
};

size_t n_features = 0;
size_t n_batches = 0;
size_t accumulated_rows{0};
size_t nnz{0};

/**
 * The CPU implementation needs an additional loop to accumulate the column sizes.
 */
std::unique_ptr<common::HostSketchContainer> p_sketch;
std::vector<size_t> batch_nnz;
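// First pass over the external iterator: fix the number of features, count
// rows, and record the per-batch NNZ.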
do {
// We use a do-while here because the first batch is already fetched in the constructor.
if (n_features == 0) {
n_features = num_cols();
rabit::Allreduce<rabit::op::Max>(&n_features, 1);
column_sizes.resize(n_features);
info_.num_col_ = n_features;
} else {
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
}

size_t batch_size = num_rows();
batch_nnz.push_back(nnz_cnt());
nnz += batch_nnz.back();
accumulated_rows += batch_size;
n_batches++;
} while (iter.Next());
iter.Reset();

// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";

/**
* Generate quantiles
*/
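// Second pass: unless a reference DMatrix already supplies the quantile cuts,
// push every batch into the sketch container, then turn the merged sketches
// into cuts.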
accumulated_rows = 0;
if (ref) {
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
} else {
size_t i = 0;
while (iter.Next()) {
if (!p_sketch) {
p_sketch.reset(new common::HostSketchContainer{batch_param_.max_bin,
proxy->Info().feature_types.ConstHostSpan(),
column_sizes, false, ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
// We don't need the base row index here, since Info comes from the proxy and
// its number of rows matches the current data batch.
p_sketch->PushAdapterBatch(batch, 0, proxy->Info(), missing);
});
accumulated_rows += num_rows();
++i;
}
iter.Reset();
CHECK_EQ(accumulated_rows, Info().num_row_);

CHECK(p_sketch);
p_sketch->MakeCuts(&cuts);
}

/**
* Generate gradient index.
*/
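// Third pass: bin each batch into the compressed gradient index using the
// cuts computed above.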
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), batch_param_.max_bin);
size_t rbegin = 0;
size_t prev_sum = 0;
size_t i = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing,
proxy->Info().feature_types.ConstHostSpan(),
batch_param_.sparse_thresh, Info().num_row_);
});
if (n_batches != 1) {
this->info_.Extend(std::move(proxy->Info()), false, true);
}
size_t batch_size = num_rows();
prev_sum = this->ghist_->row_ptr[rbegin + batch_size];
rbegin += batch_size;
++i;
}
iter.Reset();
CHECK_EQ(rbegin, Info().num_row_);

/**
* Generate column matrix
*/
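// Final pass: build the column-major view of the gradient index used by the
// CPU hist tree method.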
accumulated_rows = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
this->ghist_->PushAdapterBatchColumns(&ctx_, batch, missing, accumulated_rows);
});
accumulated_rows += num_rows();
}
iter.Reset();
CHECK_EQ(accumulated_rows, Info().num_row_);

if (n_batches == 1) {
this->info_ = std::move(proxy->Info());
this->info_.num_nonzero_ = nnz;
CHECK_EQ(proxy->Info().labels.Size(), 0);
}
}

BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
CheckParam(param);
CHECK(ghist_) << "Not initialized with CPU data";
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
} // namespace data
} // namespace xgboost
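For context on how this code path is reached from outside the library, below is a minimal sketch of driving the callback protocol through the C API. The single dense 2x2 batch, the hand-written array-interface JSON, and the omission of error checking are illustrative assumptions, not part of this patch.

#include <xgboost/c_api.h>

#include <cstdint>
#include <cstdio>
#include <limits>
#include <vector>

// One dense 2x2 float batch, served exactly once per iteration cycle.
struct Iter {
  DMatrixHandle proxy{nullptr};
  std::vector<float> data{1.f, 2.f, 3.f, 4.f};
  char array_interface[256];
  int cursor{0};
};

// XGDMatrixCallbackNext: return 1 after staging a batch, 0 when exhausted.
extern "C" int Next(DataIterHandle handle) {
  auto* self = static_cast<Iter*>(handle);
  if (self->cursor == 1) {
    return 0;
  }
  // Describe the batch as array-interface JSON and stage it in the proxy.
  std::snprintf(self->array_interface, sizeof(self->array_interface),
                "{\"data\": [%llu, false], \"shape\": [2, 2], "
                "\"typestr\": \"<f4\", \"version\": 3}",
                static_cast<unsigned long long>(
                    reinterpret_cast<std::uintptr_t>(self->data.data())));
  XGProxyDMatrixSetDataDense(self->proxy, self->array_interface);
  self->cursor++;
  return 1;
}

// DataIterResetCallback: rewind to the first batch.
extern "C" void Reset(DataIterHandle handle) {
  static_cast<Iter*>(handle)->cursor = 0;
}

int main() {
  Iter it;
  XGProxyDMatrixCreate(&it.proxy);

  DMatrixHandle dmat{nullptr};
  // With this patch the same entry point handles CPU batches; the quantile
  // cuts are sketched from the data itself (ref is nullptr internally).
  XGDeviceQuantileDMatrixCreateFromCallback(
      &it, it.proxy, Reset, Next, std::numeric_limits<float>::quiet_NaN(),
      /*nthread=*/0, /*max_bin=*/256, &dmat);

  XGDMatrixFree(dmat);
  XGDMatrixFree(it.proxy);
  return 0;
}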