Skip to content

Commit

Permalink
Remove omp_get_max_threads in CPU predictor. (#7519)
Browse files Browse the repository at this point in the history
This is part of the on going effort to remove the dependency on global omp variables.
  • Loading branch information
trivialfis committed Jan 4, 2022
1 parent 5516281 commit 68cdbc9
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 68 deletions.
2 changes: 1 addition & 1 deletion include/xgboost/generic_parameters.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct GenericParameter : public XGBoostParameter<GenericParameter> {
bool seed_per_iteration;
// number of threads to use if OpenMP is enabled
// if equals 0, use system default
int nthread;
int nthread{0};
// primary device, -1 means no gpu.
int gpu_id;
// fail when gpu_id is invalid
Expand Down
6 changes: 3 additions & 3 deletions include/xgboost/predictor.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,11 +105,11 @@ class Predictor {
/*
* \brief Runtime parameters.
*/
GenericParameter const* generic_param_;
GenericParameter const* ctx_;

public:
explicit Predictor(GenericParameter const* generic_param) :
generic_param_{generic_param} {}
explicit Predictor(GenericParameter const* ctx) : ctx_{ctx} {}

virtual ~Predictor() = default;

/**
Expand Down
53 changes: 27 additions & 26 deletions src/predictor/cpu_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,12 @@ class AdapterView {
static size_t constexpr kUnroll = kUnrollLen;

public:
explicit AdapterView(Adapter *adapter, float missing,
common::Span<Entry> workplace)
: adapter_{adapter}, missing_{missing}, workspace_{workplace},
current_unroll_(omp_get_max_threads() > 0 ? omp_get_max_threads() : 1, 0) {}
explicit AdapterView(Adapter *adapter, float missing, common::Span<Entry> workplace,
int32_t n_threads)
: adapter_{adapter},
missing_{missing},
workspace_{workplace},
current_unroll_(n_threads > 0 ? n_threads : 1, 0) {}
SparsePage::Inst operator[](size_t i) {
bst_feature_t columns = adapter_->NumColumns();
auto const &batch = adapter_->Value();
Expand Down Expand Up @@ -186,7 +188,7 @@ template <typename DataView, size_t block_of_rows_size>
void PredictBatchByBlockOfRowsKernel(
DataView batch, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end,
std::vector<RegTree::FVec> *p_thread_temp) {
std::vector<RegTree::FVec> *p_thread_temp, int32_t n_threads) {
auto &thread_temp = *p_thread_temp;
int32_t const num_group = model.learner_model_param->num_output_group;

Expand All @@ -197,7 +199,7 @@ void PredictBatchByBlockOfRowsKernel(
const int num_feature = model.learner_model_param->num_feature;
omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size);

common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) {
common::ParallelFor(n_blocks, n_threads, [&](bst_omp_uint block_id) {
const size_t batch_offset = block_id * block_of_rows_size;
const size_t block_size =
std::min(nsize - batch_offset, block_of_rows_size);
Expand Down Expand Up @@ -252,7 +254,7 @@ class CPUPredictor : public Predictor {
void PredictDMatrix(DMatrix *p_fmat, std::vector<bst_float> *out_preds,
gbm::GBTreeModel const &model, int32_t tree_begin,
int32_t tree_end) const {
const int threads = omp_get_max_threads();
auto const n_threads = this->ctx_->Threads();
constexpr double kDensityThresh = .5;
size_t total = std::max(p_fmat->Info().num_row_ * p_fmat->Info().num_col_,
static_cast<uint64_t>(1));
Expand All @@ -261,23 +263,22 @@ class CPUPredictor : public Predictor {
bool blocked = density > kDensityThresh;

std::vector<RegTree::FVec> feat_vecs;
InitThreadTemp(threads * (blocked ? kBlockOfRowsSize : 1),
InitThreadTemp(n_threads * (blocked ? kBlockOfRowsSize : 1),
model.learner_model_param->num_feature, &feat_vecs);
for (auto const &batch : p_fmat->GetBatches<SparsePage>()) {
CHECK_EQ(out_preds->size(),
p_fmat->Info().num_row_ *
model.learner_model_param->num_output_group);
size_t constexpr kUnroll = 8;
if (blocked) {
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>,
kBlockOfRowsSize>(
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
tree_end, &feat_vecs);
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, kBlockOfRowsSize>(
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
n_threads);

} else {
PredictBatchByBlockOfRowsKernel<SparsePageView<kUnroll>, 1>(
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin,
tree_end, &feat_vecs);
SparsePageView<kUnroll>{&batch}, out_preds, model, tree_begin, tree_end, &feat_vecs,
n_threads);
}
}
}
Expand All @@ -304,7 +305,7 @@ class CPUPredictor : public Predictor {
const gbm::GBTreeModel &model, float missing,
PredictionCacheEntry *out_preds,
uint32_t tree_begin, uint32_t tree_end) const {
auto threads = omp_get_max_threads();
auto const n_threads = this->ctx_->Threads();
auto m = dmlc::get<std::shared_ptr<Adapter>>(x);
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model.";
Expand All @@ -316,14 +317,14 @@ class CPUPredictor : public Predictor {
info.num_row_ = m->NumRows();
this->InitOutPredictions(info, &(out_preds->predictions), model);
}
std::vector<Entry> workspace(m->NumColumns() * 8 * threads);
std::vector<Entry> workspace(m->NumColumns() * 8 * n_threads);
auto &predictions = out_preds->predictions.HostVector();
std::vector<RegTree::FVec> thread_temp;
InitThreadTemp(threads * kBlockSize, model.learner_model_param->num_feature,
InitThreadTemp(n_threads * kBlockSize, model.learner_model_param->num_feature,
&thread_temp);
PredictBatchByBlockOfRowsKernel<AdapterView<Adapter>, kBlockSize>(
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}),
&predictions, model, tree_begin, tree_end, &thread_temp);
AdapterView<Adapter>(m.get(), missing, common::Span<Entry>{workspace}, n_threads),
&predictions, model, tree_begin, tree_end, &thread_temp, n_threads);
}

bool InplacePredict(dmlc::any const &x, std::shared_ptr<DMatrix> p_m,
Expand Down Expand Up @@ -370,10 +371,10 @@ class CPUPredictor : public Predictor {

void PredictLeaf(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_preds,
const gbm::GBTreeModel& model, unsigned ntree_limit) const override {
const int nthread = omp_get_max_threads();
auto const n_threads = this->ctx_->Threads();
std::vector<RegTree::FVec> feat_vecs;
const int num_feature = model.learner_model_param->num_feature;
InitThreadTemp(nthread, num_feature, &feat_vecs);
InitThreadTemp(n_threads, num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info();
// number of valid trees
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
Expand All @@ -386,7 +387,7 @@ class CPUPredictor : public Predictor {
// parallel over local batch
auto page = batch.GetView();
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
common::ParallelFor(nsize, [&](bst_omp_uint i) {
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
const int tid = omp_get_thread_num();
auto ridx = static_cast<size_t>(batch.base_rowid + i);
RegTree::FVec &feats = feat_vecs[tid];
Expand All @@ -411,10 +412,10 @@ class CPUPredictor : public Predictor {
std::vector<bst_float> const *tree_weights,
bool approximate, int condition,
unsigned condition_feature) const override {
const int nthread = omp_get_max_threads();
auto const n_threads = this->ctx_->Threads();
const int num_feature = model.learner_model_param->num_feature;
std::vector<RegTree::FVec> feat_vecs;
InitThreadTemp(nthread, num_feature, &feat_vecs);
InitThreadTemp(n_threads, num_feature, &feat_vecs);
const MetaInfo& info = p_fmat->Info();
// number of valid trees
if (ntree_limit == 0 || ntree_limit > model.trees.size()) {
Expand All @@ -432,7 +433,7 @@ class CPUPredictor : public Predictor {
std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values
std::vector<std::vector<float>> mean_values(ntree_limit);
common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
common::ParallelFor(ntree_limit, n_threads, [&](bst_omp_uint i) {
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
auto base_margin = info.base_margin_.View(GenericParameter::kCpuId);
Expand All @@ -441,7 +442,7 @@ class CPUPredictor : public Predictor {
auto page = batch.GetView();
// parallel over local batch
const auto nsize = static_cast<bst_omp_uint>(batch.Size());
common::ParallelFor(nsize, [&](bst_omp_uint i) {
common::ParallelFor(nsize, n_threads, [&](bst_omp_uint i) {
auto row_idx = static_cast<size_t>(batch.base_rowid + i);
RegTree::FVec &feats = feat_vecs[omp_get_thread_num()];
if (feats.Size() == 0) {
Expand Down
64 changes: 32 additions & 32 deletions src/predictor/gpu_predictor.cu
Original file line number Diff line number Diff line change
Expand Up @@ -633,12 +633,12 @@ class GPUPredictor : public xgboost::Predictor {
size_t num_features,
HostDeviceVector<bst_float>* predictions,
size_t batch_offset, bool is_dense) const {
batch.offset.SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(ctx_->gpu_id);
batch.data.SetDevice(ctx_->gpu_id);
const uint32_t BLOCK_THREADS = 128;
size_t num_rows = batch.Size();
auto GRID_SIZE = static_cast<uint32_t>(common::DivRoundUp(num_rows, BLOCK_THREADS));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);
size_t shared_memory_bytes =
SharedMemoryBytes<BLOCK_THREADS>(num_features, max_shared_memory_bytes);
bool use_shared = shared_memory_bytes != 0;
Expand Down Expand Up @@ -694,10 +694,10 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_end - tree_begin == 0) {
return;
}
out_preds->SetDevice(generic_param_->gpu_id);
out_preds->SetDevice(ctx_->gpu_id);
auto const& info = dmat->Info();
DeviceModel d_model;
d_model.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
d_model.Init(model, tree_begin, tree_end, ctx_->gpu_id);

if (dmat->PageExists<SparsePage>()) {
size_t batch_offset = 0;
Expand All @@ -709,10 +709,10 @@ class GPUPredictor : public xgboost::Predictor {
} else {
size_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
dmat->Info().feature_types.SetDevice(generic_param_->gpu_id);
dmat->Info().feature_types.SetDevice(ctx_->gpu_id);
auto feature_types = dmat->Info().feature_types.ConstDeviceSpan();
this->PredictInternal(
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id, feature_types),
page.Impl()->GetDeviceAccessor(ctx_->gpu_id, feature_types),
d_model,
out_preds,
batch_offset);
Expand All @@ -726,15 +726,15 @@ class GPUPredictor : public xgboost::Predictor {
Predictor::Predictor{generic_param} {}

~GPUPredictor() override {
if (generic_param_->gpu_id >= 0 && generic_param_->gpu_id < common::AllVisibleGPUs()) {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
if (ctx_->gpu_id >= 0 && ctx_->gpu_id < common::AllVisibleGPUs()) {
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
}
}

void PredictBatch(DMatrix* dmat, PredictionCacheEntry* predts,
const gbm::GBTreeModel& model, uint32_t tree_begin,
uint32_t tree_end = 0) const override {
int device = generic_param_->gpu_id;
int device = ctx_->gpu_id;
CHECK_GE(device, 0) << "Set `gpu_id' to positive value for processing GPU data.";
auto* out_preds = &predts->predictions;
if (tree_end == 0) {
Expand All @@ -754,7 +754,7 @@ class GPUPredictor : public xgboost::Predictor {
CHECK_EQ(m->NumColumns(), model.learner_model_param->num_feature)
<< "Number of columns in data must equal to trained model.";
CHECK_EQ(dh::CurrentDevice(), m->DeviceIdx())
<< "XGBoost is running on device: " << this->generic_param_->gpu_id << ", "
<< "XGBoost is running on device: " << this->ctx_->gpu_id << ", "
<< "but data is on: " << m->DeviceIdx();
if (p_m) {
p_m->Info().num_row_ = m->NumRows();
Expand Down Expand Up @@ -821,8 +821,8 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
out_contribs->SetDevice(ctx_->gpu_id);
if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
Expand All @@ -840,12 +840,12 @@ class GPUPredictor : public xgboost::Predictor {
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
device_paths;
DeviceModel d_model;
d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
d_model.Init(model, 0, tree_end, ctx_->gpu_id);
dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(ctx_->gpu_id);
batch.offset.SetDevice(ctx_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
Expand All @@ -854,7 +854,7 @@ class GPUPredictor : public xgboost::Predictor {
dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
dh::LaunchN(
Expand All @@ -879,8 +879,8 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
out_contribs->SetDevice(ctx_->gpu_id);
if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
Expand All @@ -899,12 +899,12 @@ class GPUPredictor : public xgboost::Predictor {
dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
device_paths;
DeviceModel d_model;
d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
d_model.Init(model, 0, tree_end, ctx_->gpu_id);
dh::device_vector<uint32_t> categories;
ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
ExtractPaths(&device_paths, &d_model, &categories, ctx_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(ctx_->gpu_id);
batch.offset.SetDevice(ctx_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
Expand All @@ -913,7 +913,7 @@ class GPUPredictor : public xgboost::Predictor {
dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
p_fmat->Info().base_margin_.SetDevice(ctx_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.Data()->ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
size_t n_features = model.learner_model_param->num_feature;
Expand All @@ -938,8 +938,8 @@ class GPUPredictor : public xgboost::Predictor {
void PredictLeaf(DMatrix *p_fmat, HostDeviceVector<bst_float> *predictions,
const gbm::GBTreeModel &model,
unsigned tree_end) const override {
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(generic_param_->gpu_id);
dh::safe_cuda(cudaSetDevice(ctx_->gpu_id));
auto max_shared_memory_bytes = ConfigureDevice(ctx_->gpu_id);

const MetaInfo& info = p_fmat->Info();
constexpr uint32_t kBlockThreads = 128;
Expand All @@ -953,15 +953,15 @@ class GPUPredictor : public xgboost::Predictor {
if (tree_end == 0 || tree_end > model.trees.size()) {
tree_end = static_cast<uint32_t>(model.trees.size());
}
predictions->SetDevice(generic_param_->gpu_id);
predictions->SetDevice(ctx_->gpu_id);
predictions->Resize(num_rows * tree_end);
DeviceModel d_model;
d_model.Init(model, 0, tree_end, this->generic_param_->gpu_id);
d_model.Init(model, 0, tree_end, this->ctx_->gpu_id);

if (p_fmat->PageExists<SparsePage>()) {
for (auto const& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
batch.data.SetDevice(ctx_->gpu_id);
batch.offset.SetDevice(ctx_->gpu_id);
bst_row_t batch_offset = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature};
Expand All @@ -986,7 +986,7 @@ class GPUPredictor : public xgboost::Predictor {
} else {
for (auto const& batch : p_fmat->GetBatches<EllpackPage>()) {
bst_row_t batch_offset = 0;
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(generic_param_->gpu_id)};
EllpackDeviceAccessor data{batch.Impl()->GetDeviceAccessor(ctx_->gpu_id)};
size_t num_rows = batch.Size();
auto grid =
static_cast<uint32_t>(common::DivRoundUp(num_rows, kBlockThreads));
Expand Down
4 changes: 2 additions & 2 deletions src/predictor/predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ void Predictor::InitOutPredictions(const MetaInfo& info, HostDeviceVector<bst_fl
size_t n_classes = model.learner_model_param->num_output_group;
size_t n = n_classes * info.num_row_;
const HostDeviceVector<bst_float>* base_margin = info.base_margin_.Data();
if (generic_param_->gpu_id >= 0) {
out_preds->SetDevice(generic_param_->gpu_id);
if (ctx_->gpu_id >= 0) {
out_preds->SetDevice(ctx_->gpu_id);
}
if (base_margin->Size() != 0) {
out_preds->Resize(n);
Expand Down
7 changes: 3 additions & 4 deletions tests/cpp/predictor/test_predictor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -217,10 +217,9 @@ void TestCategoricalPrediction(std::string name) {
gbm::GBTreeModel model(&param);
GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight);

GenericParameter runtime;
runtime.gpu_id = 0;
std::unique_ptr<Predictor> predictor{
Predictor::Create(name.c_str(), &runtime)};
GenericParameter ctx;
ctx.UpdateAllowUnknown(Args{{"gpu_id", "0"}});
std::unique_ptr<Predictor> predictor{Predictor::Create(name.c_str(), &ctx)};

std::vector<float> row(kCols);
row[split_ind] = split_cat;
Expand Down

0 comments on commit 68cdbc9

Please sign in to comment.