From 136287abadc467782fc4a748bf8f02ae4f3a96cb Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 5 Mar 2021 18:14:03 +0800
Subject: [PATCH 1/4] Calculate feature scores in native XGBoost.

* Support categorical data.
---
 include/xgboost/c_api.h        |  26 +++++++-
 include/xgboost/gbm.h          |   6 ++
 include/xgboost/learner.h      |   9 +++
 python-package/xgboost/core.py | 113 +++++++++------------------------
 src/c_api/c_api.cc             |  42 ++++++++++++
 src/c_api/c_api_utils.h        |  36 +++++++++++
 src/gbm/gbtree.h               |  53 ++++++++++++++++
 src/learner.cc                 |  24 +++++++
 tests/cpp/gbm/test_gbtree.cc   |  41 +++++++++++-
 9 files changed, 266 insertions(+), 84 deletions(-)

diff --git a/include/xgboost/c_api.h b/include/xgboost/c_api.h
index be3444252248..b0fc28825936 100644
--- a/include/xgboost/c_api.h
+++ b/include/xgboost/c_api.h
@@ -1,5 +1,5 @@
 /*!
- * Copyright (c) 2015~2020 by Contributors
+ * Copyright (c) 2015~2021 by Contributors
  * \file c_api.h
  * \author Tianqi Chen
  * \brief C API of XGBoost, used for interfacing to other languages.
@@ -1193,4 +1193,28 @@ XGB_DLL int XGBoosterSetStrFeatureInfo(BoosterHandle handle, const char *field,
 XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
                                        bst_ulong *len,
                                        const char ***out_features);
+
+/*!
+ * \brief Calculate feature scores for tree models.
+ *
+ * \param handle An instance of Booster
+ * \param json_config Parameters for computing scores.  Accepted JSON keys are:
+ *   - importance_type: A JSON string with the following possible values:
+ *       * 'weight': the number of times a feature is used to split the data across all trees.
+ *       * 'gain': the average gain across all splits the feature is used in.
+ *       * 'cover': the average coverage across all splits the feature is used in.
+ *       * 'total_gain': the total gain across all splits the feature is used in.
+ *       * 'total_cover': the total coverage across all splits the feature is used in.
+ *   - feature_map: An optional JSON string with URI or path to the feature map file.
+ *
+ * \param out_length Length of output arrays.
+ * \param out_features An array of strings as feature names, ordered the same as output scores.
+ * \param out_scores An array of floating point values as feature scores.
+ *
+ * \return 0 when success, -1 when failure happens
+ */
+XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle, const char *json_config,
+                                  bst_ulong *out_length,
+                                  const char ***out_features,
+                                  float **out_scores);
 #endif  // XGBOOST_C_API_H_
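
For reference, a minimal usage sketch of the new entry point from C (not part of the patch; it assumes `booster` is an already trained handle, and error handling is reduced to printing `XGBGetLastError()`):

#include <stdio.h>
#include <xgboost/c_api.h>

/* Sketch only: `booster` is created and trained elsewhere. */
void print_feature_scores(BoosterHandle booster) {
  char const *config = "{\"importance_type\": \"gain\", \"feature_map\": \"\"}";
  bst_ulong n = 0;
  char const **names = NULL;
  float *scores = NULL;
  if (XGBoosterFeatureScore(booster, config, &n, &names, &scores) != 0) {
    fprintf(stderr, "%s\n", XGBGetLastError());
    return;
  }
  for (bst_ulong i = 0; i < n; ++i) {
    printf("%s: %f\n", names[i], scores[i]);
  }
  /* `names` and `scores` point into booster-owned, thread-local storage
     (the ret_vec_* holders below); they stay valid only until the next
     API call on this booster and must not be freed by the caller. */
}
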
diff --git a/include/xgboost/gbm.h b/include/xgboost/gbm.h
index c49fe4747e1e..fde861f13921 100644
--- a/include/xgboost/gbm.h
+++ b/include/xgboost/gbm.h
@@ -181,6 +181,12 @@ class GradientBooster : public Model, public Configurable {
   virtual std::vector<std::string> DumpModel(const FeatureMap& fmap,
                                              bool with_stats,
                                              std::string format) const = 0;
+
+  virtual void FeatureScore(std::string const &importance_type,
+                            std::vector<bst_feature_t> *features,
+                            std::vector<float> *scores) const {
+    LOG(FATAL) << "`feature_score` is not implemented for the current booster.";
+  }
   /*!
    * \brief Whether the current booster uses GPU.
    */
diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index b8f18022501f..e0fc6073e185 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -46,6 +46,8 @@ struct XGBAPIThreadLocalEntry {
   std::string ret_str;
   /*! \brief result holder for returning strings */
   std::vector<std::string> ret_vec_str;
+  /*! \brief result holder for returning unsigned integers */
+  std::vector<uint32_t> ret_vec_uint32;
   /*! \brief result holder for returning string pointers */
   std::vector<const char *> ret_vec_charp;
   /*! \brief returning float vector. */
@@ -152,6 +154,13 @@ class Learner : public Model, public Configurable, public dmlc::Serializable {
                        HostDeviceVector<bst_float> **out_preds,
                        uint32_t layer_begin, uint32_t layer_end) = 0;
+  /*!
+   * \brief Calculate feature score.  See doc in C API for outputs.
+   */
+  virtual void CalcFeatureScore(std::string const &importance_type,
+                                std::vector<bst_feature_t> *features,
+                                std::vector<float> *scores) = 0;
+
   /*
    * \brief Get number of boosted rounds from gradient booster.
    */
diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py
index cf415b9e9afc..3aa0883f1235 100644
--- a/python-package/xgboost/core.py
+++ b/python-package/xgboost/core.py
@@ -2191,7 +2191,9 @@ def get_fscore(self, fmap=''):
         return self.get_score(fmap, importance_type='weight')
 
-    def get_score(self, fmap='', importance_type='weight'):
+    def get_score(
+        self, fmap: Union[str, os.PathLike] = '', importance_type: str = 'weight'
+    ) -> Dict[str, float]:
         """Get feature importance of each feature.
 
         Importance type can be defined as:
@@ -2203,9 +2205,9 @@ def get_score(self, fmap='', importance_type='weight'):
 
         .. note:: Feature importance is defined only for tree boosters
 
-        Feature importance is only defined when the decision tree model is chosen as base
-        learner (`booster=gbtree`). It is not defined for other base learner types, such
-        as linear learners (`booster=gblinear`).
+        Feature importance is only defined when the decision tree model is chosen as
+        base learner (`booster=gbtree` or `booster=dart`). It is not defined for other
+        base learner types, such as linear learners (`booster=gblinear`).
 
         Parameters
         ----------
@@ -2213,86 +2215,33 @@ def get_score(self, fmap='', importance_type='weight'):
             The name of feature map file.
         importance_type: str, default 'weight'
             One of the importance types defined above.
+
+        Returns
+        -------
+        A map between feature names and their scores.
""" fmap = os.fspath(os.path.expanduser(fmap)) - if getattr(self, 'booster', None) is not None and self.booster not in {'gbtree', 'dart'}: - raise ValueError('Feature importance is not defined for Booster type {}' - .format(self.booster)) - - allowed_importance_types = ['weight', 'gain', 'cover', 'total_gain', 'total_cover'] - if importance_type not in allowed_importance_types: - msg = ("importance_type mismatch, got '{}', expected one of " + - repr(allowed_importance_types)) - raise ValueError(msg.format(importance_type)) - - # if it's weight, then omap stores the number of missing values - if importance_type == 'weight': - # do a simpler tree dump to save time - trees = self.get_dump(fmap, with_stats=False) - fmap = {} - for tree in trees: - for line in tree.split('\n'): - # look for the opening square bracket - arr = line.split('[') - # if no opening bracket (leaf node), ignore this line - if len(arr) == 1: - continue - - # extract feature name from string between [] - fid = arr[1].split(']')[0].split('<')[0] - - if fid not in fmap: - # if the feature hasn't been seen yet - fmap[fid] = 1 - else: - fmap[fid] += 1 - - return fmap - - average_over_splits = True - if importance_type == 'total_gain': - importance_type = 'gain' - average_over_splits = False - elif importance_type == 'total_cover': - importance_type = 'cover' - average_over_splits = False - - trees = self.get_dump(fmap, with_stats=True) - - importance_type += '=' - fmap = {} - gmap = {} - for tree in trees: - for line in tree.split('\n'): - # look for the opening square bracket - arr = line.split('[') - # if no opening bracket (leaf node), ignore this line - if len(arr) == 1: - continue - - # look for the closing bracket, extract only info within that bracket - fid = arr[1].split(']') - - # extract gain or cover from string after closing bracket - g = float(fid[1].split(importance_type)[1].split(',')[0]) - - # extract feature name from string before closing bracket - fid = fid[0].split('<')[0] - - if fid not in fmap: - # if the feature hasn't been seen yet - fmap[fid] = 1 - gmap[fid] = g - else: - fmap[fid] += 1 - gmap[fid] += g - - # calculate average value (gain/cover) for each feature - if average_over_splits: - for fid in gmap: - gmap[fid] = gmap[fid] / fmap[fid] - - return gmap + args = from_pystr_to_cstr( + json.dumps({"importance_type": importance_type, "feature_map": fmap}) + ) + features = ctypes.POINTER(ctypes.c_char_p)() + scores = ctypes.POINTER(ctypes.c_float)() + length = c_bst_ulong() + _check_call( + _LIB.XGBoosterFeatureScore( + self.handle, + args, + ctypes.byref(length), + ctypes.byref(features), + ctypes.byref(scores) + ) + ) + features_arr = from_cstr_to_pystr(features, length) + scores_arr = ctypes2numpy(scores, length.value, np.float32) + results = {} + for feat, score in zip(features_arr, scores_arr): + results[feat] = score + return results def trees_to_dataframe(self, fmap=''): """Parse a boosted tree model text dump into a pandas DataFrame structure. 
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 0b048c988e27..9709a6b056e2 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1098,5 +1098,47 @@ XGB_DLL int XGBoosterGetStrFeatureInfo(BoosterHandle handle, const char *field,
   API_END();
 }
 
+XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
+                                  const char *json_config,
+                                  xgboost::bst_ulong* out_length,
+                                  const char ***out_features,
+                                  float **out_scores) {
+  API_BEGIN();
+  CHECK_HANDLE();
+  auto *learner = static_cast<Learner *>(handle);
+  auto config = Json::Load(StringView{json_config});
+  auto importance = get<String const>(config["importance_type"]);
+  std::string feature_map_uri;
+  if (!IsA<Null>(config["feature_map"])) {
+    feature_map_uri = get<String const>(config["feature_map"]);
+  }
+  FeatureMap feature_map = LoadFeatureMap(feature_map_uri);
+
+  auto& scores = learner->GetThreadLocal().ret_vec_float;
+  auto& features = learner->GetThreadLocal().ret_vec_uint32;
+  learner->CalcFeatureScore(importance, &features, &scores);
+
+  auto n_features = learner->GetNumFeature();
+  GenerateFeatureMap(learner, n_features, &feature_map);
+  CHECK_LE(features.size(), n_features);
+
+  auto& feature_names = learner->GetThreadLocal().ret_vec_str;
+  feature_names.resize(features.size());
+  auto& feature_names_c = learner->GetThreadLocal().ret_vec_charp;
+  feature_names_c.resize(features.size());
+
+  for (bst_feature_t i = 0; i < features.size(); ++i) {
+    feature_names[i] = feature_map.Name(features[i]);
+    feature_names_c[i] = feature_names[i].data();
+  }
+
+  CHECK_EQ(scores.size(), features.size());
+  CHECK_EQ(scores.size(), feature_names.size());
+  *out_length = scores.size();
+  *out_scores = scores.data();
+  *out_features = dmlc::BeginPtr(feature_names_c);
+  API_END();
+}
+
 // force link rabit
 static DMLC_ATTRIBUTE_UNUSED int XGBOOST_LINK_RABIT_C_API_ = RabitLinkTag();
diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h
index 3a8b130597f8..11f54cfb9549 100644
--- a/src/c_api/c_api_utils.h
+++ b/src/c_api/c_api_utils.h
@@ -7,6 +7,8 @@
 #include <algorithm>
 #include <functional>
 #include <vector>
+#include <memory>
+#include <string>
 
 #include "xgboost/logging.h"
 #include "xgboost/json.h"
@@ -181,5 +183,39 @@ class XGBoostAPIGuard {
     RestoreGPUAttribute();
   }
 };
+
+// Load a feature map from `uri` if one is provided.  Each line of the file is
+// `<index> <name> <type>`, e.g. `0 age q`.
+inline FeatureMap LoadFeatureMap(std::string const& uri) {
+  FeatureMap feat;
+  if (uri.size() != 0) {
+    std::unique_ptr<dmlc::Stream> fs(dmlc::Stream::Create(uri.c_str(), "r"));
+    dmlc::istream is(fs.get());
+    feat.LoadText(is);
+  }
+  return feat;
+}
+
+// FIXME(jiamingy): Use this for model dump.
+inline void GenerateFeatureMap(Learner const *learner,
+                               size_t n_features, FeatureMap *out_feature_map) {
+  auto &feature_map = *out_feature_map;
+  auto maybe = [&](std::vector<std::string> const &values, size_t i,
+                   std::string const &dft) {
+    return values.empty() ? dft : values[i];
+  };
+  if (feature_map.Size() == 0) {
+    // Use the feature names and types from booster.
+    std::vector<std::string> feature_names;
+    learner->GetFeatureNames(&feature_names);
+    std::vector<std::string> feature_types;
+    learner->GetFeatureTypes(&feature_types);
+    for (size_t i = 0; i < n_features; ++i) {
+      feature_map.PushBack(
+          i,
+          maybe(feature_names, i, "f" + std::to_string(i)).data(),
+          maybe(feature_types, i, "q").data());
+    }
+  }
+  CHECK_EQ(feature_map.Size(), n_features);
+}
 }  // namespace xgboost
 #endif  // XGBOOST_C_API_C_API_UTILS_H_
diff --git a/src/gbm/gbtree.h b/src/gbm/gbtree.h
index d948c731e63b..872e22fd4bc0 100644
--- a/src/gbm/gbtree.h
+++ b/src/gbm/gbtree.h
@@ -9,6 +9,7 @@
 #include <dmlc/omp.h>
 
+#include <algorithm>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -299,6 +300,58 @@ class GBTree : public GradientBooster {
     }
   }
 
+  void FeatureScore(std::string const &importance_type,
+                    std::vector<bst_feature_t> *features,
+                    std::vector<float> *scores) const override {
+    // Features with no importance don't appear in the return value, so a
+    // separate pair of dense vectors is used to accumulate the values during
+    // computation.
+    std::vector<size_t> split_counts(this->model_.learner_model_param->num_feature, 0);
+    std::vector<float> gain_map(this->model_.learner_model_param->num_feature, 0);
+    auto add_score = [&](auto fn) {
+      for (auto const &p_tree : model_.trees) {
+        p_tree->WalkTree([&](bst_node_t nidx) {
+          auto const &node = (*p_tree)[nidx];
+          if (!node.IsLeaf()) {
+            split_counts[node.SplitIndex()]++;
+            fn(p_tree, nidx, node.SplitIndex());
+          }
+          return true;
+        });
+      }
+    };
+
+    if (importance_type == "weight") {
+      add_score([&](auto const &p_tree, bst_node_t, bst_feature_t split) {
+        gain_map[split] = split_counts[split];
+      });
+    }
+    if (importance_type == "gain" || importance_type == "total_gain") {
+      add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
+        gain_map[split] += p_tree->Stat(nidx).loss_chg;
+      });
+    }
+    if (importance_type == "cover" || importance_type == "total_cover") {
+      add_score([&](auto const &p_tree, bst_node_t nidx, bst_feature_t split) {
+        gain_map[split] += p_tree->Stat(nidx).sum_hess;
+      });
+    }
+    if (importance_type == "gain" || importance_type == "cover") {
+      // Average the accumulated total over the number of splits per feature.
+      for (size_t i = 0; i < gain_map.size(); ++i) {
+        gain_map[i] /= std::max(1.0f, static_cast<float>(split_counts[i]));
+      }
+    }
+
+    features->clear();
+    scores->clear();
+    for (size_t i = 0; i < split_counts.size(); ++i) {
+      if (split_counts[i] != 0) {
+        features->push_back(i);
+        scores->push_back(gain_map[i]);
+      }
+    }
+  }
+
   void PredictInstance(const SparsePage::Inst& inst,
                        std::vector<bst_float>* out_preds,
                        uint32_t layer_begin, uint32_t layer_end) override {
diff --git a/src/learner.cc b/src/learner.cc
index b7d2026eb74f..a3086aa7227d 100644
--- a/src/learner.cc
+++ b/src/learner.cc
@@ -1193,6 +1193,30 @@ class LearnerImpl : public LearnerIO {
     *out_preds = &out_predictions.predictions;
   }
 
+  void CalcFeatureScore(std::string const &importance_type,
+                        std::vector<bst_feature_t> *features,
+                        std::vector<float> *scores) override {
+    this->Configure();
+    std::vector<std::string> allowed_importance_type = {
+      "weight", "total_gain", "total_cover", "gain", "cover"
+    };
+    if (std::find(allowed_importance_type.begin(),
+                  allowed_importance_type.end(),
+                  importance_type) == allowed_importance_type.end()) {
+      std::stringstream ss;
+      ss << "importance_type mismatch, got: `" << importance_type
+         << "`, expected one of ";
+      for (size_t i = 0; i < allowed_importance_type.size(); ++i) {
+        ss << "`" << allowed_importance_type[i] << "`";
+        if (i != allowed_importance_type.size() - 1) {
+          ss << ", ";
+        }
+      }
+      LOG(FATAL) << ss.str();
+    }
+    gbm_->FeatureScore(importance_type, features, scores);
+  }
+
   const std::map<std::string, std::string>&
   GetConfigurationArguments() const override {
     return cfg_;
   }
diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index 8a3650bdfd53..0cd68765542c 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -1,5 +1,5 @@
 /*!
- * Copyright 2019-2020 XGBoost contributors
+ * Copyright 2019-2021 XGBoost contributors
  */
 #include <gtest/gtest.h>
 #include <xgboost/c_api.h>
@@ -410,4 +410,43 @@ TEST(Dart, Slice) {
   auto const& trees =
       get<Array const>(model["learner"]["gradient_booster"]["gbtree"]["model"]["trees"]);
   ASSERT_EQ(weights.size(), trees.size());
 }
+
+TEST(GBTree, FeatureScore) {
+  size_t n_samples = 1000, n_features = 10;
+  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, 4);
+
+  std::unique_ptr<Learner> learner{ Learner::Create({m}) };
+  learner->SetParam("num_class", "4");
+
+  learner->Configure();
+  for (size_t i = 0; i < 2; ++i) {
+    learner->UpdateOneIter(i, m);
+  }
+
+  Json model {Object{}};
+  learner->SaveModel(&model);
+  std::vector<bst_feature_t> features_weight;
+  std::vector<float> scores_weight;
+  learner->CalcFeatureScore("weight", &features_weight, &scores_weight);
+  ASSERT_EQ(features_weight.size(), scores_weight.size());
+  ASSERT_LE(features_weight.size(), learner->GetNumFeature());
+  ASSERT_TRUE(std::is_sorted(features_weight.begin(), features_weight.end()));
+
+  auto test_eq = [&learner, &scores_weight](std::string type) {
+    std::vector<bst_feature_t> features;
+    std::vector<float> scores;
+    learner->CalcFeatureScore(type, &features, &scores);
+
+    std::vector<bst_feature_t> features_total;
+    std::vector<float> scores_total;
+    learner->CalcFeatureScore("total_" + type, &features_total, &scores_total);
+
+    // The average is total / weight, so total / average recovers the split count.
+    for (size_t i = 0; i < scores_weight.size(); ++i) {
+      ASSERT_LE(RelError(scores_total[i] / scores[i], scores_weight[i]), kRtEps);
+    }
+  };
+
+  test_eq("gain");
+  test_eq("cover");
+}
 }  // namespace xgboost

From a6c4a9812d08261c3f0556f60d20d3dcffe1e3c3 Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 15 Jun 2021 20:35:45 +0800
Subject: [PATCH 2/4] Cleanup.

---
 include/xgboost/learner.h | 2 --
 src/c_api/c_api.cc        | 2 +-
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/include/xgboost/learner.h b/include/xgboost/learner.h
index e0fc6073e185..09c16eff6cfa 100644
--- a/include/xgboost/learner.h
+++ b/include/xgboost/learner.h
@@ -46,8 +46,6 @@ struct XGBAPIThreadLocalEntry {
   std::string ret_str;
   /*! \brief result holder for returning strings */
   std::vector<std::string> ret_vec_str;
-  /*! \brief result holder for returning unsigned integers */
-  std::vector<uint32_t> ret_vec_uint32;
   /*! \brief result holder for returning string pointers */
   std::vector<const char *> ret_vec_charp;
   /*! \brief returning float vector. */
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index 9709a6b056e2..5354e78b1e74 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -1115,7 +1115,7 @@ XGB_DLL int XGBoosterFeatureScore(BoosterHandle handle,
   FeatureMap feature_map = LoadFeatureMap(feature_map_uri);
 
   auto& scores = learner->GetThreadLocal().ret_vec_float;
-  auto& features = learner->GetThreadLocal().ret_vec_uint32;
+  std::vector<bst_feature_t> features;
   learner->CalcFeatureScore(importance, &features, &scores);
 
   auto n_features = learner->GetNumFeature();

From d097ed5957a84ad236798f58f5ee7fa07aec1f94 Mon Sep 17 00:00:00 2001
From: fis
Date: Tue, 15 Jun 2021 21:40:57 +0800
Subject: [PATCH 3/4] Cleanup.
---
 tests/cpp/gbm/test_gbtree.cc | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/tests/cpp/gbm/test_gbtree.cc b/tests/cpp/gbm/test_gbtree.cc
index 0cd68765542c..749dc4b66d93 100644
--- a/tests/cpp/gbm/test_gbtree.cc
+++ b/tests/cpp/gbm/test_gbtree.cc
@@ -412,19 +412,17 @@ TEST(Dart, Slice) {
 }
 
 TEST(GBTree, FeatureScore) {
-  size_t n_samples = 1000, n_features = 10;
-  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, 4);
+  size_t n_samples = 1000, n_features = 10, n_classes = 4;
+  auto m = RandomDataGenerator{n_samples, n_features, 0.5}.GenerateDMatrix(true, false, n_classes);
 
   std::unique_ptr<Learner> learner{ Learner::Create({m}) };
-  learner->SetParam("num_class", "4");
+  learner->SetParam("num_class", std::to_string(n_classes));
 
   learner->Configure();
   for (size_t i = 0; i < 2; ++i) {
     learner->UpdateOneIter(i, m);
   }
 
-  Json model {Object{}};
-  learner->SaveModel(&model);
   std::vector<bst_feature_t> features_weight;
   std::vector<float> scores_weight;
   learner->CalcFeatureScore("weight", &features_weight, &scores_weight);

From 3d82f2b2b574545478d40cf8336803e85cd29b46 Mon Sep 17 00:00:00 2001
From: fis
Date: Fri, 18 Jun 2021 00:48:20 +0800
Subject: [PATCH 4/4] Make sure the number of features is correct.

---
 src/c_api/c_api_utils.h    |  6 ++++++
 tests/python/test_basic.py | 17 +++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/src/c_api/c_api_utils.h b/src/c_api/c_api_utils.h
index 11f54cfb9549..7c1538cb132c 100644
--- a/src/c_api/c_api_utils.h
+++ b/src/c_api/c_api_utils.h
@@ -206,8 +206,14 @@ inline void GenerateFeatureMap(Learner const *learner,
     // Use the feature names and types from booster.
     std::vector<std::string> feature_names;
     learner->GetFeatureNames(&feature_names);
+    if (!feature_names.empty()) {
+      CHECK_EQ(feature_names.size(), n_features) << "Incorrect number of feature names.";
+    }
     std::vector<std::string> feature_types;
     learner->GetFeatureTypes(&feature_types);
+    if (!feature_types.empty()) {
+      CHECK_EQ(feature_types.size(), n_features) << "Incorrect number of feature types.";
+    }
     for (size_t i = 0; i < n_features; ++i) {
       feature_map.PushBack(
           i,
diff --git a/tests/python/test_basic.py b/tests/python/test_basic.py
index 7ce87b208642..2d2bec51867c 100644
--- a/tests/python/test_basic.py
+++ b/tests/python/test_basic.py
@@ -154,6 +154,23 @@ def test_dump(self):
         dump4j = json.loads(dump4[0])
         assert 'gain' in dump4j, "Expected 'gain' to be dumped in JSON."
 
+    def test_feature_score(self):
+        rng = np.random.RandomState(0)
+        data = rng.randn(100, 2)
+        target = np.array([0, 1] * 50)
+        features = ["F0"]
+        with pytest.raises(ValueError):
+            xgb.DMatrix(data, label=target, feature_names=features)
+
+        params = {"objective": "binary:logistic"}
+        dm = xgb.DMatrix(data, label=target, feature_names=["F0", "F1"])
+        booster = xgb.train(params, dm, num_boost_round=1)
+        # No error here, since feature names may be assigned before the booster
+        # sees any data, so the booster doesn't know the actual number of features.
+        booster.feature_names = ["F0"]
+        with pytest.raises(ValueError):
+            booster.get_fscore()
+
     def test_load_file_invalid(self):
         with pytest.raises(xgb.core.XGBoostError):
             xgb.Booster(model_file='incorrect_path')
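
As a companion to the Python test above, a minimal C sketch of how the new validation in `GenerateFeatureMap` surfaces through the C API (not part of the patch; `booster` is assumed to be a handle trained on two features):

#include <stdio.h>
#include <xgboost/c_api.h>

/* Sketch only: `booster` is assumed to be trained on two features. */
void feature_name_mismatch(BoosterHandle booster) {
  const char *names[] = {"F0"};  /* one name for a two-feature model */
  /* Setting the names succeeds; the count is not validated at this point. */
  XGBoosterSetStrFeatureInfo(booster, "feature_name", names, 1);

  bst_ulong n = 0;
  const char **out_names = NULL;
  float *scores = NULL;
  /* ...but computing scores now fails with "Incorrect number of feature names." */
  if (XGBoosterFeatureScore(booster,
                            "{\"importance_type\": \"weight\", \"feature_map\": \"\"}",
                            &n, &out_names, &scores) != 0) {
    fprintf(stderr, "%s\n", XGBGetLastError());
  }
}

Validation is deferred until the names are actually consumed, matching the behaviour pinned down in `test_feature_score`.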