Implement iterative DMatrix for CPU. (#8116)

trivialfis committed Jul 26, 2022
1 parent 546de5e commit 2c70751
Showing 20 changed files with 634 additions and 188 deletions.
1 change: 1 addition & 0 deletions amalgamation/xgboost-all0.cc
@@ -43,6 +43,7 @@
#include "../src/data/gradient_index_format.cc"
#include "../src/data/sparse_page_dmatrix.cc"
#include "../src/data/proxy_dmatrix.cc"
#include "../src/data/iterative_dmatrix.cc"

// prediction
#include "../src/predictor/predictor.cc"
21 changes: 13 additions & 8 deletions include/xgboost/data.h
@@ -559,6 +559,7 @@ class DMatrix {
*
* \param iter External data iterator
* \param proxy A handle to ProxyDMatrix
* \param ref Reference Quantile DMatrix.
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
@@ -567,13 +568,11 @@
*
* \return A created quantile based DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread,
int max_bin);
template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
static DMatrix* Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin);

/**
* \brief Create an external memory DMatrix with callbacks.
@@ -613,6 +612,7 @@ class DMatrix {
virtual BatchSet<GHistIndexMatrix> GetGradientIndex(const BatchParam& param) = 0;

virtual bool EllpackExists() const = 0;
virtual bool GHistIndexExists() const = 0;
virtual bool SparsePageExists() const = 0;
};

@@ -621,11 +621,16 @@ inline BatchSet<SparsePage> DMatrix::GetBatches() {
return GetRowBatches();
}

template<>
template <>
inline bool DMatrix::PageExists<EllpackPage>() const {
return this->EllpackExists();
}

template <>
inline bool DMatrix::PageExists<GHistIndexMatrix>() const {
return this->GHistIndexExists();
}

template<>
inline bool DMatrix::PageExists<SparsePage>() const {
return this->SparsePageExists();
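Taken together, the new `ref` argument to `DMatrix::Create` and the `GHistIndexMatrix` specialization of `PageExists` let a second quantile DMatrix reuse the histogram cuts of one that already exists instead of sketching them again. Below is a minimal sketch of the intended call pattern; the handles and callbacks (`train_iter`, `train_proxy`, `valid_iter`, `valid_proxy`, `Reset`, `Next`) are illustrative assumptions, not part of this commit.

#include <limits>
#include <memory>
#include <utility>

#include "xgboost/c_api.h"  // DataIterHandle, DMatrixHandle, callback typedefs
#include "xgboost/data.h"

std::pair<std::shared_ptr<xgboost::DMatrix>, std::shared_ptr<xgboost::DMatrix>>
MakeQuantileMatrices(DataIterHandle train_iter, DMatrixHandle train_proxy,
                     DataIterHandle valid_iter, DMatrixHandle valid_proxy,
                     DataIterResetCallback* Reset, XGDMatrixCallbackNext* Next) {
  float missing = std::numeric_limits<float>::quiet_NaN();
  // Training matrix: no reference, so the quantile cuts are sketched from the data.
  std::shared_ptr<xgboost::DMatrix> train{xgboost::DMatrix::Create(
      train_iter, train_proxy, /*ref=*/nullptr, Reset, Next, missing,
      /*nthread=*/0, /*max_bin=*/256)};
  // Validation matrix: passing `train` as `ref` copies its cuts through
  // GetCutsFromRef (src/data/iterative_dmatrix.cc below) instead of re-sketching.
  std::shared_ptr<xgboost::DMatrix> valid{xgboost::DMatrix::Create(
      valid_iter, valid_proxy, /*ref=*/train, Reset, Next, missing,
      /*nthread=*/0, /*max_bin=*/256)};
  return {train, valid};
}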
11 changes: 6 additions & 5 deletions src/c_api/c_api.cc
Expand Up @@ -275,13 +275,14 @@ XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy
API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin, DMatrixHandle *out) {
XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread, int max_bin,
DMatrixHandle *out) {
API_BEGIN();
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, nthread, max_bin)};
xgboost::DMatrix::Create(iter, proxy, nullptr, reset, next, missing, nthread, max_bin)};
API_END();
}

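The C API entry point keeps its original argument list and simply forwards `nullptr` for the new `ref` parameter. For context, here is a hedged sketch of the `reset`/`next` callbacks a caller might supply when streaming dense float batches through the proxy; the `Batches` struct, the buffer size, and the helper names are assumptions for illustration.

#include <cstdio>
#include <vector>

#include "xgboost/c_api.h"

struct Batches {
  std::vector<std::vector<float>> data;  // one dense row-major batch per entry
  size_t cols{0};
  size_t current{0};
  DMatrixHandle proxy{nullptr};  // created beforehand with XGProxyDMatrixCreate
};

void Reset(DataIterHandle self) { static_cast<Batches*>(self)->current = 0; }

int Next(DataIterHandle self) {
  auto* it = static_cast<Batches*>(self);
  if (it->current == it->data.size()) {
    return 0;  // 0 tells XGBoost the iterator is exhausted
  }
  auto const& batch = it->data[it->current++];
  // Describe the current batch with an array-interface JSON and hand it to the proxy.
  char array_interface[256];
  std::snprintf(
      array_interface, sizeof(array_interface),
      R"({"data": [%zu, false], "shape": [%zu, %zu], "typestr": "<f4", "version": 3})",
      reinterpret_cast<size_t>(batch.data()), batch.size() / it->cols, it->cols);
  XGProxyDMatrixSetDataDense(it->proxy, array_interface);
  return 1;  // another batch is available
}

With those callbacks in place, `XGDeviceQuantileDMatrixCreateFromCallback(&batches, batches.proxy, Reset, Next, missing, /*nthread=*/0, /*max_bin=*/256, &out)` drives the multi-pass construction implemented in src/data/iterative_dmatrix.cc below.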
27 changes: 13 additions & 14 deletions src/data/data.cc
@@ -931,15 +931,13 @@ DMatrix* DMatrix::Load(const std::string& uri,
}
return dmat;
}
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int nthread,
int max_bin) {
return new data::IterativeDMatrix(iter, proxy, reset, next, missing,
nthread, max_bin);

template <typename DataIterHandle, typename DMatrixHandle, typename DataIterResetCallback,
typename XGDMatrixCallbackNext>
DMatrix* DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy, std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset, XGDMatrixCallbackNext* next, float missing,
int nthread, bst_bin_t max_bin) {
return new data::IterativeDMatrix(iter, proxy, ref, reset, next, missing, nthread, max_bin);
}

template <typename DataIterHandle, typename DMatrixHandle,
@@ -953,11 +951,12 @@ DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
cache);
}

template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin);
template DMatrix* DMatrix::Create<DataIterHandle, DMatrixHandle, DataIterResetCallback,
XGDMatrixCallbackNext>(DataIterHandle iter, DMatrixHandle proxy,
std::shared_ptr<DMatrix> ref,
DataIterResetCallback* reset,
XGDMatrixCallbackNext* next, float missing,
int nthread, int max_bin);

template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
214 changes: 214 additions & 0 deletions src/data/iterative_dmatrix.cc
@@ -0,0 +1,214 @@
/*!
* Copyright 2022 XGBoost contributors
*/
#include "iterative_dmatrix.h"

#include <rabit/rabit.h>

#include "../common/column_matrix.h"
#include "../common/hist_util.h"
#include "gradient_index.h"
#include "proxy_dmatrix.h"
#include "simple_batch_iterator.h"

namespace xgboost {
namespace data {

void GetCutsFromRef(std::shared_ptr<DMatrix> ref_, bst_feature_t n_features, BatchParam p,
common::HistogramCuts* p_cuts) {
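// Share bin boundaries with an already-constructed DMatrix: copy its
// histogram cuts, preferring whichever page type the reference has already
// materialized (GHistIndexMatrix on CPU, EllpackPage on GPU). If neither
// page exists yet, pick one based on the requested device so that only one
// page type needs to be built.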
CHECK(ref_);
CHECK(p_cuts);
auto csr = [&]() {
for (auto const& page : ref_->GetBatches<GHistIndexMatrix>(p)) {
*p_cuts = page.cut;
break;
}
};
auto ellpack = [&]() {
for (auto const& page : ref_->GetBatches<EllpackPage>(p)) {
GetCutsFromEllpack(page, p_cuts);
break;
}
};

if (ref_->PageExists<GHistIndexMatrix>()) {
csr();
} else if (ref_->PageExists<EllpackPage>()) {
ellpack();
} else {
if (p.gpu_id == Context::kCpuId) {
csr();
} else {
ellpack();
}
}
CHECK_EQ(ref_->Info().num_col_, n_features)
<< "Invalid ref DMatrix, different number of features.";
}

void IterativeDMatrix::InitFromCPU(DataIterHandle iter_handle, float missing,
std::shared_ptr<DMatrix> ref) {
DMatrixProxy* proxy = MakeProxy(proxy_);
CHECK(proxy);
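// The CPU path makes several passes over the external iterator:
//   1. scan every batch to learn the shape and the per-column counts of
//      valid entries;
//   2. sketch the quantile cuts (skipped when `ref` already supplies them);
//   3. push every batch into the gradient index (ghist_);
//   4. push every batch once more to build the column matrix.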

// The external iterator
auto iter =
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext>{iter_handle, reset_, next_};
common::HistogramCuts cuts;

auto num_rows = [&]() {
return HostAdapterDispatch(proxy, [](auto const& value) { return value.Size(); });
};
auto num_cols = [&]() {
return HostAdapterDispatch(proxy, [](auto const& value) { return value.NumCols(); });
};

std::vector<size_t> column_sizes;
auto const is_valid = data::IsValidFunctor{missing};
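// Count the valid (non-missing) entries in the batch currently held by the
// proxy. Each thread accumulates per-column counts into its own row of
// column_sizes_tloc to avoid atomics; those rows are then reduced into
// column_sizes, and the batch's total is returned.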
auto nnz_cnt = [&]() {
return HostAdapterDispatch(proxy, [&](auto const& value) {
size_t n_threads = ctx_.Threads();
size_t n_features = column_sizes.size();
linalg::Tensor<size_t, 2> column_sizes_tloc({n_threads, n_features}, Context::kCpuId);
auto view = column_sizes_tloc.HostView();
common::ParallelFor(value.Size(), n_threads, common::Sched::Static(256), [&](auto i) {
auto const& line = value.GetLine(i);
for (size_t j = 0; j < line.Size(); ++j) {
data::COOTuple const& elem = line.GetElement(j);
if (is_valid(elem)) {
view(omp_get_thread_num(), elem.column_idx)++;
}
}
});
auto ptr = column_sizes_tloc.Data()->HostPointer();
auto result = std::accumulate(ptr, ptr + column_sizes_tloc.Size(), static_cast<size_t>(0));
for (size_t tidx = 0; tidx < n_threads; ++tidx) {
for (size_t fidx = 0; fidx < n_features; ++fidx) {
column_sizes[fidx] += view(tidx, fidx);
}
}
return result;
});
};

size_t n_features = 0;
size_t n_batches = 0;
size_t accumulated_rows{0};
size_t nnz{0};

/**
* CPU impl needs an additional loop for accumulating the column size.
*/
std::unique_ptr<common::HostSketchContainer> p_sketch;
std::vector<size_t> batch_nnz;
do {
// We use a do-while loop here as the first batch is fetched in the ctor
if (n_features == 0) {
n_features = num_cols();
rabit::Allreduce<rabit::op::Max>(&n_features, 1);
column_sizes.resize(n_features);
info_.num_col_ = n_features;
} else {
CHECK_EQ(n_features, num_cols()) << "Inconsistent number of columns.";
}

size_t batch_size = num_rows();
batch_nnz.push_back(nnz_cnt());
nnz += batch_nnz.back();
accumulated_rows += batch_size;
n_batches++;
} while (iter.Next());
iter.Reset();

// From here on Info() has the correct data shape
Info().num_row_ = accumulated_rows;
Info().num_nonzero_ = nnz;
rabit::Allreduce<rabit::op::Max>(&info_.num_col_, 1);
CHECK(std::none_of(column_sizes.cbegin(), column_sizes.cend(), [&](auto f) {
return f > accumulated_rows;
})) << "Something went wrong during iteration.";

/**
* Generate quantiles
*/
accumulated_rows = 0;
if (ref) {
GetCutsFromRef(ref, Info().num_col_, batch_param_, &cuts);
} else {
size_t i = 0;
while (iter.Next()) {
if (!p_sketch) {
p_sketch.reset(new common::HostSketchContainer{batch_param_.max_bin,
proxy->Info().feature_types.ConstHostSpan(),
column_sizes, false, ctx_.Threads()});
}
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
// We don't need the base row idx here as Info() comes from the proxy and
// its number of rows is consistent with the data batch.
p_sketch->PushAdapterBatch(batch, 0, proxy->Info(), missing);
});
accumulated_rows += num_rows();
++i;
}
iter.Reset();
CHECK_EQ(accumulated_rows, Info().num_row_);

CHECK(p_sketch);
p_sketch->MakeCuts(&cuts);
}

/**
* Generate gradient index.
*/
this->ghist_ = std::make_unique<GHistIndexMatrix>(Info(), std::move(cuts), batch_param_.max_bin);
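// rbegin is the row offset of the current batch in the final matrix, while
// prev_sum carries the number of entries committed so far (read back from
// the gradient index's row pointer), so each batch fills its own slice.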
size_t rbegin = 0;
size_t prev_sum = 0;
size_t i = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
proxy->Info().num_nonzero_ = batch_nnz[i];
this->ghist_->PushAdapterBatch(&ctx_, rbegin, prev_sum, batch, missing,
proxy->Info().feature_types.ConstHostSpan(),
batch_param_.sparse_thresh, Info().num_row_);
});
if (n_batches != 1) {
this->info_.Extend(std::move(proxy->Info()), false, true);
}
size_t batch_size = num_rows();
prev_sum = this->ghist_->row_ptr[rbegin + batch_size];
rbegin += batch_size;
++i;
}
iter.Reset();
CHECK_EQ(rbegin, Info().num_row_);

/**
* Generate column matrix
*/
accumulated_rows = 0;
while (iter.Next()) {
HostAdapterDispatch(proxy, [&](auto const& batch) {
this->ghist_->PushAdapterBatchColumns(&ctx_, batch, missing, accumulated_rows);
});
accumulated_rows += num_rows();
}
iter.Reset();
CHECK_EQ(accumulated_rows, Info().num_row_);

if (n_batches == 1) {
this->info_ = std::move(proxy->Info());
this->info_.num_nonzero_ = nnz;
CHECK_EQ(proxy->Info().labels.Size(), 0);
}
}

BatchSet<GHistIndexMatrix> IterativeDMatrix::GetGradientIndex(BatchParam const& param) {
CheckParam(param);
CHECK(ghist_) << "Not initialized with CPU data";
auto begin_iter =
BatchIterator<GHistIndexMatrix>(new SimpleBatchIteratorImpl<GHistIndexMatrix>(ghist_));
return BatchSet<GHistIndexMatrix>(begin_iter);
}
} // namespace data
} // namespace xgboost
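Because the full gradient index is materialized in memory, `GetGradientIndex` wraps the single `ghist_` page in a one-element batch set. A hedged consumption sketch follows; `dmat` and `param` are assumed, and `param` must be consistent with the parameters used at construction (`CheckParam`, not shown in this excerpt, enforces that).

// `dmat` is an IterativeDMatrix built through the iterator path above.
for (auto const& page : dmat->GetBatches<xgboost::GHistIndexMatrix>(param)) {
  // Exactly one in-memory page is yielded on the CPU path.
  auto const& cuts = page.cut;  // the quantile cuts, possibly shared via `ref`
  (void)cuts;  // ... use the gradient index ...
}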
