Skip to content

Commit

Permalink
Fix race condition in CPU shap. (#7050)
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 21, 2021
1 parent 29f8fd6 commit bbfffb4
Show file tree
Hide file tree
Showing 5 changed files with 69 additions and 65 deletions.
28 changes: 12 additions & 16 deletions include/xgboost/predictor.h
Expand Up @@ -206,22 +206,18 @@ class Predictor {
* \param condition_feature Feature to condition on (i.e. fix) during calculations.
*/

virtual void PredictContribution(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false,
int condition = 0,
unsigned condition_feature = 0) const = 0;

virtual void PredictInteractionContributions(DMatrix* dmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end = 0,
std::vector<bst_float>* tree_weights = nullptr,
bool approximate = false) const = 0;

virtual void
PredictContribution(DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false, int condition = 0,
unsigned condition_feature = 0) const = 0;

virtual void PredictInteractionContributions(
DMatrix *dmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned tree_end = 0,
std::vector<bst_float> const *tree_weights = nullptr,
bool approximate = false) const = 0;

/**
* \brief Creates a new Predictor*.
Expand Down
8 changes: 2 additions & 6 deletions include/xgboost/tree_model.h
Expand Up @@ -550,6 +550,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix
*/
void CalculateContributions(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs, int condition = 0,
unsigned condition_feature = 0) const;
/*!
Expand Down Expand Up @@ -578,6 +579,7 @@ class RegTree : public Model {
* \param out_contribs output vector to hold the contributions
*/
void CalculateContributionsApprox(const RegTree::FVec& feat,
std::vector<float>* mean_values,
bst_float* out_contribs) const;
/*!
* \brief dump the model in the requested format as a text string
Expand All @@ -589,10 +591,6 @@ class RegTree : public Model {
std::string DumpModel(const FeatureMap& fmap,
bool with_stats,
std::string format) const;
/*!
* \brief calculate the mean value for each node, required for feature contributions
*/
void FillNodeMeanValues();
/*!
* \brief Get split type for a node.
* \param nidx Index of node.
Expand Down Expand Up @@ -639,7 +637,6 @@ class RegTree : public Model {
std::vector<int> deleted_nodes_;
// stats of nodes
std::vector<RTreeNodeStat> stats_;
std::vector<bst_float> node_mean_values_;
std::vector<FeatureType> split_types_;

// Categories for each internal node.
Expand Down Expand Up @@ -680,7 +677,6 @@ class RegTree : public Model {
nodes_[nid].MarkDelete();
++param.num_deleted;
}
bst_float FillNodeMeanValue(int nid);
};

inline void RegTree::FVec::Init(size_t size) {
Expand Down
61 changes: 47 additions & 14 deletions src/predictor/cpu_predictor.cc
Expand Up @@ -213,6 +213,32 @@ void PredictBatchByBlockOfRowsKernel(
});
}

// Recursively compute, for the subtree rooted at `nidx`, the mean prediction
// of each node (a leaf's value, or the hessian-weighted average of its
// children's means), storing every result into `mean_values`.  Returns the
// mean for `nidx` itself.
float FillNodeMeanValues(RegTree const *tree, bst_node_t nidx, std::vector<float> *mean_values) {
  auto const &node = (*tree)[nidx];
  bst_float mean{0};
  if (node.IsLeaf()) {
    mean = node.LeafValue();
  } else {
    auto const left = node.LeftChild();
    auto const right = node.RightChild();
    // Weight each child's mean by its hessian sum (instance coverage).
    bst_float const left_mean = FillNodeMeanValues(tree, left, mean_values);
    bst_float const right_mean = FillNodeMeanValues(tree, right, mean_values);
    mean = (left_mean * tree->Stat(left).sum_hess +
            right_mean * tree->Stat(right).sum_hess) /
           tree->Stat(nidx).sum_hess;
  }
  (*mean_values)[nidx] = mean;
  return mean;
}

// Ensure `mean_values` holds one entry per node of `tree`, computing the
// per-node mean predictions on first use.  A vector already sized to the
// tree is treated as filled and left untouched (caching behaviour).
void FillNodeMeanValues(RegTree const *tree, std::vector<float> *mean_values) {
  auto const n_nodes = static_cast<size_t>(tree->param.num_nodes);
  if (mean_values->size() != n_nodes) {
    mean_values->resize(n_nodes);
    // Populate the whole vector starting from the root node.
    FillNodeMeanValues(tree, 0, mean_values);
  }
}

class CPUPredictor : public Predictor {
protected:
// init thread buffers
Expand Down Expand Up @@ -396,9 +422,10 @@ class CPUPredictor : public Predictor {
}
}

void PredictContribution(DMatrix* p_fmat, HostDeviceVector<float>* out_contribs,
const gbm::GBTreeModel& model, uint32_t ntree_limit,
std::vector<bst_float>* tree_weights,
void PredictContribution(DMatrix *p_fmat,
HostDeviceVector<float> *out_contribs,
const gbm::GBTreeModel &model, uint32_t ntree_limit,
std::vector<bst_float> const *tree_weights,
bool approximate, int condition,
unsigned condition_feature) const override {
const int nthread = omp_get_max_threads();
Expand All @@ -421,8 +448,9 @@ class CPUPredictor : public Predictor {
// allocated one
std::fill(contribs.begin(), contribs.end(), 0);
// initialize tree node mean values
std::vector<std::vector<float>> mean_values(ntree_limit);
common::ParallelFor(bst_omp_uint(ntree_limit), [&](bst_omp_uint i) {
model.trees[i]->FillNodeMeanValues();
FillNodeMeanValues(model.trees[i].get(), &(mean_values[i]));
});
const std::vector<bst_float>& base_margin = info.base_margin_.HostVector();
// start collecting the contributions
Expand All @@ -443,19 +471,23 @@ class CPUPredictor : public Predictor {
feats.Fill(page[i]);
// calculate contributions
for (unsigned j = 0; j < ntree_limit; ++j) {
auto *tree_mean_values = &mean_values.at(j);
std::fill(this_tree_contribs.begin(), this_tree_contribs.end(), 0);
if (model.tree_info[j] != gid) {
continue;
}
if (!approximate) {
model.trees[j]->CalculateContributions(feats, &this_tree_contribs[0],
condition, condition_feature);
model.trees[j]->CalculateContributions(
feats, tree_mean_values, &this_tree_contribs[0], condition,
condition_feature);
} else {
model.trees[j]->CalculateContributionsApprox(feats, &this_tree_contribs[0]);
model.trees[j]->CalculateContributionsApprox(
feats, tree_mean_values, &this_tree_contribs[0]);
}
for (size_t ci = 0 ; ci < ncolumns ; ++ci) {
p_contribs[ci] += this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
for (size_t ci = 0; ci < ncolumns; ++ci) {
p_contribs[ci] +=
this_tree_contribs[ci] *
(tree_weights == nullptr ? 1 : (*tree_weights)[j]);
}
}
feats.Drop(page[i]);
Expand All @@ -470,10 +502,11 @@ class CPUPredictor : public Predictor {
}
}

void PredictInteractionContributions(DMatrix* p_fmat, HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) const override {
void PredictInteractionContributions(
DMatrix *p_fmat, HostDeviceVector<bst_float> *out_contribs,
const gbm::GBTreeModel &model, unsigned ntree_limit,
std::vector<bst_float> const *tree_weights,
bool approximate) const override {
const MetaInfo& info = p_fmat->Info();
const int ngroup = model.learner_model_param->num_output_group;
size_t const ncolumns = model.learner_model_param->num_feature;
Expand Down
4 changes: 2 additions & 2 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -696,7 +696,7 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float>*,
std::vector<bst_float> const*,
bool approximate, int,
unsigned) const override {
if (approximate) {
Expand Down Expand Up @@ -746,7 +746,7 @@ class GPUPredictor : public xgboost::Predictor {
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end,
std::vector<bst_float>*,
std::vector<bst_float> const*,
bool approximate) const override {
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
Expand Down
33 changes: 6 additions & 27 deletions src/tree/tree_model.cc
Expand Up @@ -1128,36 +1128,14 @@ void RegTree::SaveModel(Json* p_out) const {
out["default_left"] = std::move(default_left);
}

void RegTree::FillNodeMeanValues() {
size_t num_nodes = this->param.num_nodes;
if (this->node_mean_values_.size() == num_nodes) {
return;
}
this->node_mean_values_.resize(num_nodes);
this->FillNodeMeanValue(0);
}

bst_float RegTree::FillNodeMeanValue(int nid) {
bst_float result;
auto& node = (*this)[nid];
if (node.IsLeaf()) {
result = node.LeafValue();
} else {
result = this->FillNodeMeanValue(node.LeftChild()) * this->Stat(node.LeftChild()).sum_hess;
result += this->FillNodeMeanValue(node.RightChild()) * this->Stat(node.RightChild()).sum_hess;
result /= this->Stat(nid).sum_hess;
}
this->node_mean_values_[nid] = result;
return result;
}

void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
std::vector<float>* mean_values,
bst_float *out_contribs) const {
CHECK_GT(this->node_mean_values_.size(), 0U);
CHECK_GT(mean_values->size(), 0U);
// this follows the idea of http://blog.datadive.net/interpreting-random-forests/
unsigned split_index = 0;
// update bias value
bst_float node_value = this->node_mean_values_[0];
bst_float node_value = (*mean_values)[0];
out_contribs[feat.Size()] += node_value;
if ((*this)[0].IsLeaf()) {
// nothing to do anymore
Expand All @@ -1172,7 +1150,7 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat,
nid = predictor::GetNextNode<true, true>((*this)[nid], nid,
feat.GetFvalue(split_index),
feat.IsMissing(split_index), cats);
bst_float new_value = this->node_mean_values_[nid];
bst_float new_value = (*mean_values)[nid];
// update feature weight
out_contribs[split_index] += new_value - node_value;
node_value = new_value;
Expand Down Expand Up @@ -1352,12 +1330,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
}

void RegTree::CalculateContributions(const RegTree::FVec &feat,
std::vector<float>* mean_values,
bst_float *out_contribs,
int condition,
unsigned condition_feature) const {
// find the expected value of the tree's predictions
if (condition == 0) {
bst_float node_value = this->node_mean_values_[0];
bst_float node_value = (*mean_values)[0];
out_contribs[feat.Size()] += node_value;
}

Expand Down

0 comments on commit bbfffb4

Please sign in to comment.