Skip to content

Commit

Permalink
Define nthreads.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jul 4, 2021
1 parent 3208467 commit 72e5b35
Show file tree
Hide file tree
Showing 5 changed files with 61 additions and 38 deletions.
25 changes: 22 additions & 3 deletions include/xgboost/data.h
Expand Up @@ -541,11 +541,30 @@ class DMatrix {
int nthread,
int max_bin);

/**
* \brief Create an external memory DMatrix with callbacks.
*
* \tparam DataIterHandle External iterator type, defined in C API.
* \tparam DMatrixHandle DMatrix handle, defined in C API.
* \tparam DataIterResetCallback Callback for reset, prototype defined in C API.
* \tparam XGDMatrixCallbackNext Callback for next, prototype defined in C API.
*
* \param iter External data iterator
* \param proxy A hanlde to ProxyDMatrix
* \param reset Callback for reset
* \param next Callback for next
* \param missing Value that should be treated as missing.
* \param nthread number of threads used for initialization.
* \param cache Prefix of cache file path.
*
* \return A created external memory DMatrix.
*/
template <typename DataIterHandle, typename DMatrixHandle,
typename DataIterResetCallback, typename XGDMatrixCallbackNext>
static DMatrix *
Create(DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, std::string cache);
static DMatrix *Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int32_t nthread, std::string cache);

virtual DMatrix *Slice(common::Span<int32_t const> ridxs) = 0;
/*! \brief Number of rows per page in external memory. Approximately 100MB per page for
Expand Down
54 changes: 29 additions & 25 deletions src/c_api/c_api.cc
Expand Up @@ -190,6 +190,35 @@ XGB_DLL int XGDMatrixCreateFromArrayInterface(char const* c_json_strs,
#endif

// Create from data iterator
XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next,
char const* c_json_config,
DMatrixHandle *out) {
API_BEGIN();
auto config = Json::Load(StringView{c_json_config});
float missing = get<Number const>(config["missing"]);
std::string cache = get<String const>(config["cache_prefix"]);
int32_t n_threads = omp_get_max_threads();
if (!IsA<Null>(config["nthread"])) {
n_threads = get<Integer const>(config["nthread"]);
}
*out = new std::shared_ptr<xgboost::DMatrix>{xgboost::DMatrix::Create(
iter, proxy, reset, next, missing, n_threads, cache)};
API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin, DMatrixHandle *out) {
API_BEGIN();
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, nthread, max_bin)};
API_END();
}

XGB_DLL int XGProxyDMatrixCreate(DMatrixHandle* out) {
API_BEGIN();
*out = new std::shared_ptr<xgboost::DMatrix>(new xgboost::data::DMatrixProxy);;
Expand Down Expand Up @@ -221,31 +250,6 @@ XGB_DLL int XGProxyDMatrixSetDataCudaColumnar(DMatrixHandle handle,
API_END();
}

XGB_DLL int XGDMatrixCreateFromCallback(DataIterHandle iter,
DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next,
char const* c_json_config,
DMatrixHandle *out) {
API_BEGIN();
auto config = Json::Load(StringView{c_json_config});
float missing = get<Number const>(config["missing"]);
std::string cache = get<String const>(config["cache_prefix"]);
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, cache)};
API_END();
}

XGB_DLL int XGDeviceQuantileDMatrixCreateFromCallback(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, int nthread,
int max_bin, DMatrixHandle *out) {
API_BEGIN();
*out = new std::shared_ptr<xgboost::DMatrix>{
xgboost::DMatrix::Create(iter, proxy, reset, next, missing, nthread, max_bin)};
API_END();
}

XGB_DLL int XGProxyDMatrixSetDataDense(DMatrixHandle handle,
char const *c_interface_str) {
API_BEGIN();
Expand Down
5 changes: 3 additions & 2 deletions src/data/data.cc
Expand Up @@ -800,8 +800,9 @@ template <typename DataIterHandle, typename DMatrixHandle,
DMatrix *DMatrix::Create(DataIterHandle iter, DMatrixHandle proxy,
DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing,
int32_t n_threads,
std::string cache) {
return new data::SparsePageDMatrix(iter, proxy, reset, next, missing, 1,
return new data::SparsePageDMatrix(iter, proxy, reset, next, missing, n_threads,
cache);
}

Expand All @@ -814,7 +815,7 @@ template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
template DMatrix *DMatrix::Create<DataIterHandle, DMatrixHandle,
DataIterResetCallback, XGDMatrixCallbackNext>(
DataIterHandle iter, DMatrixHandle proxy, DataIterResetCallback *reset,
XGDMatrixCallbackNext *next, float missing, std::string);
XGDMatrixCallbackNext *next, float missing, int32_t n_threads, std::string);

template <typename AdapterT>
DMatrix* DMatrix::Create(AdapterT* adapter, float missing, int nthread,
Expand Down
9 changes: 4 additions & 5 deletions src/data/sparse_page_source.h
Expand Up @@ -198,9 +198,8 @@ class SparsePageSourceImpl : public BatchIteratorImpl<S> {
virtual void Fetch() = 0;

public:
SparsePageSourceImpl(float missing, int nthreads,
bst_feature_t n_features, size_t n_batches,
std::shared_ptr<Cache> cache)
SparsePageSourceImpl(float missing, int nthreads, bst_feature_t n_features,
size_t n_batches, std::shared_ptr<Cache> cache)
: missing_{missing}, nthreads_{nthreads}, n_features_{n_features},
n_batches_{n_batches}, cache_info_{std::move(cache)} {}

Expand Down Expand Up @@ -276,8 +275,8 @@ class SparsePageSource : public SparsePageSourceImpl<SparsePage> {
DataIterProxy<DataIterResetCallback, XGDMatrixCallbackNext> iter,
DMatrixProxy *proxy, float missing, int nthreads,
bst_feature_t n_features, size_t n_batches, std::shared_ptr<Cache> cache)
: SparsePageSourceImpl(missing, nthreads, n_features,
n_batches, cache), iter_{iter}, proxy_{proxy} {
: SparsePageSourceImpl(missing, nthreads, n_features, n_batches, cache),
iter_{iter}, proxy_{proxy} {
if (!cache_info_->written) {
iter_.Reset();
iter_.Next();
Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/helpers.cc
Expand Up @@ -355,9 +355,9 @@ std::unique_ptr<DMatrix> CreateSparsePageDMatrix(
size_t n_rows = n_entries / n_columns;
ArrayIterForTest iter(0, n_rows, n_columns, 2);

std::unique_ptr<DMatrix> dmat{
DMatrix::Create(static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset,
Next, std::numeric_limits<float>::quiet_NaN(), tmp_file)};
std::unique_ptr<DMatrix> dmat{DMatrix::Create(
static_cast<DataIterHandle>(&iter), iter.Proxy(), Reset, Next,
std::numeric_limits<float>::quiet_NaN(), 1, tmp_file)};
auto row_page_path =
data::MakeId(tmp_file,
dynamic_cast<data::SparsePageDMatrix *>(dmat.get())) +
Expand Down

0 comments on commit 72e5b35

Please sign in to comment.