Skip to content

Commit

Permalink
Check number of trees in inplace predict. (dmlc#7409)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Nov 12, 2021
1 parent e7ac248 commit a2f0242
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/gbm/gbtree.h
Expand Up @@ -273,6 +273,7 @@ class GBTree : public GradientBooster {
uint32_t tree_begin, tree_end;
std::tie(tree_begin, tree_end) =
detail::LayerToTree(model_, tparam_, layer_begin, layer_end);
CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees.";
std::vector<Predictor const *> predictors{
cpu_predictor_.get(),
#if defined(XGBOOST_USE_CUDA)
Expand Down
43 changes: 43 additions & 0 deletions tests/cpp/gbm/test_gbtree.cc
Expand Up @@ -452,4 +452,47 @@ TEST(GBTree, FeatureScore) {
test_eq("gain");
test_eq("cover");
}

TEST(GBTree, PredictRange) {
size_t n_samples = 1000, n_features = 10, n_classes = 4;
auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);

std::unique_ptr<Learner> learner{Learner::Create({m})};
learner->SetParam("num_class", std::to_string(n_classes));

learner->Configure();
for (size_t i = 0; i < 2; ++i) {
learner->UpdateOneIter(i, m);
}
HostDeviceVector<float> out_predt;
ASSERT_THROW(learner->Predict(m, false, &out_predt, 0, 3), dmlc::Error);

auto m_1 =
RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
HostDeviceVector<float> out_predt_full;
learner->Predict(m_1, false, &out_predt_full, 0, 0);
ASSERT_TRUE(std::equal(out_predt.HostVector().begin(), out_predt.HostVector().end(),
out_predt_full.HostVector().begin()));

{
// inplace predict
HostDeviceVector<float> raw_storage;
auto raw = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateArrayInterface(&raw_storage);
std::shared_ptr<data::ArrayAdapter> x{new data::ArrayAdapter{StringView{raw}}};

HostDeviceVector<float>* out_predt;
learner->InplacePredict(x, nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 2);
auto h_out_predt = out_predt->HostVector();
learner->InplacePredict(x, nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 0);
auto h_out_predt_full = out_predt->HostVector();

ASSERT_TRUE(std::equal(h_out_predt.begin(), h_out_predt.end(), h_out_predt_full.begin()));

ASSERT_THROW(learner->InplacePredict(x, nullptr, PredictionType::kValue,
std::numeric_limits<float>::quiet_NaN(), &out_predt, 0, 3),
dmlc::Error);
}
}
} // namespace xgboost

0 comments on commit a2f0242

Please sign in to comment.