diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc index 0cd68765542c..749dc4b66d93 100644 --- a/tests/cpp/gbm/test_gbtree.cc +++ b/tests/cpp/gbm/test_gbtree.cc @@ -412,19 +412,17 @@ TEST(Dart, Slice) { } TEST(GBTree, FeatureScore) { - size_t n_samples = 1000, n_features = 10; - auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, 4); + size_t n_samples = 1000, n_features = 10, n_classes = 4; + auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes); std::unique_ptr learner{ Learner::Create({m}) }; - learner->SetParam("num_class", "4"); + learner->SetParam("num_class", std::to_string(n_classes)); learner->Configure(); for (size_t i = 0; i < 2; ++i) { learner->UpdateOneIter(i, m); } - Json model {Object{}}; - learner->SaveModel(&model); std::vector features_weight; std::vector scores_weight; learner->CalcFeatureScore("weight", &features_weight, &scores_weight);