From ca6f9809329103449e821a1f1efd9ddac5d80897 Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Fri, 12 Nov 2021 18:20:23 +0800 Subject: [PATCH] Check number of trees in inplace predict. (#7409) --- src/gbm/gbtree.h | 1 + tests/cpp/gbm/test_gbtree.cc | 43 ++++++++++++++++++++++++++++++++++++ 2 files changed, 44 insertions(+) diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h index a0b15603ea67..4e508bbaec2f 100644 --- a/src/gbm/gbtree.h +++ b/src/gbm/gbtree.h @@ -273,6 +273,7 @@ class GBTree : public GradientBooster { uint32_t tree_begin, tree_end; std::tie(tree_begin, tree_end) = detail::LayerToTree(model_, tparam_, layer_begin, layer_end); + CHECK_LE(tree_end, model_.trees.size()) << "Invalid number of trees."; std::vector predictors{ cpu_predictor_.get(), #if defined(XGBOOST_USE_CUDA) diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 3c307594b8de..6a454f96df16 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -452,4 +452,47 @@ TEST(GBTree, FeatureScore) { test_eq("gain"); test_eq("cover"); } + +TEST(GBTree, PredictRange) { + size_t n_samples = 1000, n_features = 10, n_classes = 4; + auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes); + + std::unique_ptr learner{Learner::Create({m})}; + learner->SetParam("num_class", std::to_string(n_classes)); + + learner->Configure(); + for (size_t i = 0; i < 2; ++i) { + learner->UpdateOneIter(i, m); + } + HostDeviceVector out_predt; + ASSERT_THROW(learner->Predict(m, false, &out_predt, 0, 3), dmlc::Error); + + auto m_1 = + RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes); + HostDeviceVector out_predt_full; + learner->Predict(m_1, false, &out_predt_full, 0, 0); + ASSERT_TRUE(std::equal(out_predt.HostVector().begin(), out_predt.HostVector().end(), + out_predt_full.HostVector().begin())); + + { + // inplace predict + HostDeviceVector raw_storage; + auto raw = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateArrayInterface(&raw_storage); + std::shared_ptr x{new data::ArrayAdapter{StringView{raw}}}; + + HostDeviceVector* out_predt; + learner->InplacePredict(x, nullptr, PredictionType::kValue, + std::numeric_limits::quiet_NaN(), &out_predt, 0, 2); + auto h_out_predt = out_predt->HostVector(); + learner->InplacePredict(x, nullptr, PredictionType::kValue, + std::numeric_limits::quiet_NaN(), &out_predt, 0, 0); + auto h_out_predt_full = out_predt->HostVector(); + + ASSERT_TRUE(std::equal(h_out_predt.begin(), h_out_predt.end(), h_out_predt_full.begin())); + + ASSERT_THROW(learner->InplacePredict(x, nullptr, PredictionType::kValue, + std::numeric_limits::quiet_NaN(), &out_predt, 0, 3), + dmlc::Error); + } +} } // namespace xgboost