Update GPUTreeshap #6163

Merged: 5 commits, Sep 27, 2020
The diff below shows changes from 4 of the 5 commits.
2 changes: 1 addition & 1 deletion gputreeshap
2 changes: 1 addition & 1 deletion src/gbm/gbtree.h
@@ -249,7 +249,7 @@ class GBTree : public GradientBooster {
                                        std::vector<bst_float>* out_contribs,
                                        unsigned ntree_limit, bool approximate) override {
     CHECK(configured_);
-    cpu_predictor_->PredictInteractionContributions(p_fmat, out_contribs, model_,
+    this->GetPredictor()->PredictInteractionContributions(p_fmat, out_contribs, model_,
                                                     ntree_limit, nullptr, approximate);
   }

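With this dispatch in place, a single Booster call computes interaction contributions on whichever predictor is configured, instead of always falling back to the CPU implementation. A minimal usage sketch from the Python API (illustration only, not part of the diff; assumes a GPU-enabled build and reuses the parameter names from the tests below):

import numpy as np
import xgboost as xgb

X = np.random.rand(100, 5)
y = np.random.rand(100)
dtrain = xgb.DMatrix(X, label=y)

# With "gpu_predictor" configured, GetPredictor() selects the GPU
# implementation, so pred_interactions runs on the device.
bst = xgb.train({"tree_method": "gpu_hist", "predictor": "gpu_predictor",
                 "gpu_id": 0}, dtrain, num_boost_round=10)
gpu_phis = bst.predict(dtrain, pred_interactions=True)

# Switching the parameter routes the same call to the CPU predictor.
bst.set_param({"predictor": "cpu_predictor"})
cpu_phis = bst.predict(dtrain, pred_interactions=True)
assert np.allclose(gpu_phis, cpu_phis, 1e-3, 1e-3)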
73 changes: 62 additions & 11 deletions src/predictor/gpu_predictor.cu
@@ -602,13 +602,74 @@ class GPUPredictor : public xgboost::Predictor {
                        model.learner_model_param->num_feature);
       gpu_treeshap::GPUTreeShap(
           X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data().get() + batch.base_rowid * contributions_columns);
+          phis.data().get() + batch.base_rowid * contributions_columns, phis.size());
     }
     dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
                              sizeof(float) * phis.size(),
                              cudaMemcpyDefault));
   }

+  void PredictInteractionContributions(DMatrix* p_fmat,
+                                       std::vector<bst_float>* out_contribs,
+                                       const gbm::GBTreeModel& model,
+                                       unsigned ntree_limit,
+                                       std::vector<bst_float>* tree_weights,
+                                       bool approximate) override {
+    if (approximate) {
+      LOG(FATAL) << "[Internal error]: " << __func__
+                 << " approximate is not implemented in GPU Predictor.";
+    }

Resolved review thread on the approximate check above:

Member: We need dispatching in gbm get predictor.

Member Author: There might be a benefit to explicitly failing; maybe the user is expecting the algorithm to use the GPU, and it is silently switching to the CPU.

+    dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
+    uint32_t real_ntree_limit =
+        ntree_limit * model.learner_model_param->num_output_group;
+    if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
+      real_ntree_limit = static_cast<uint32_t>(model.trees.size());
+    }
+
+    const int ngroup = model.learner_model_param->num_output_group;
+    CHECK_NE(ngroup, 0);
+    // allocate space for (number of features + bias) times the number of rows
+    size_t contributions_columns =
+        model.learner_model_param->num_feature + 1;  // +1 for bias
+    out_contribs->resize(p_fmat->Info().num_row_ * contributions_columns *
+                         contributions_columns *
+                         model.learner_model_param->num_output_group);
+    dh::TemporaryArray<float> phis(out_contribs->size(), 0.0);
+    p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
+    const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
+    float base_score = model.learner_model_param->base_score;
+    auto d_phis = phis.data().get();
+    // Add the base margin term to the last column
+    size_t n_features = model.learner_model_param->num_feature;
+    dh::LaunchN(
+        generic_param_->gpu_id,
+        p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
+        [=] __device__(size_t idx) {
+          size_t group = idx % ngroup;
+          size_t row_idx = idx / ngroup;
+          d_phis[gpu_treeshap::IndexPhiInteractions(
+              row_idx, ngroup, group, n_features, n_features, n_features)] =
+              margin.empty() ? base_score : margin[idx];
+        });
+
+    dh::device_vector<gpu_treeshap::PathElement> device_paths;
+    ExtractPaths(&device_paths, model, real_ntree_limit,
+                 generic_param_->gpu_id);
+    for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
+      batch.data.SetDevice(generic_param_->gpu_id);
+      batch.offset.SetDevice(generic_param_->gpu_id);
+      SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
+                       model.learner_model_param->num_feature);
+      gpu_treeshap::GPUTreeShapInteractions(
+          X, device_paths.begin(), device_paths.end(), ngroup,
+          phis.data().get() + batch.base_rowid * contributions_columns, phis.size());
+    }
+    dh::safe_cuda(cudaMemcpy(out_contribs->data(), phis.data().get(),
+                             sizeof(float) * phis.size(),
+                             cudaMemcpyDefault));
+  }
+
  protected:
   void InitOutPredictions(const MetaInfo& info,
                           HostDeviceVector<bst_float>* out_preds,
@@ -640,16 +701,6 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor.";
}

void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate) override {
LOG(FATAL) << "[Internal error]: " << __func__
<< " is not implemented in GPU Predictor.";
}

void Configure(const std::vector<std::pair<std::string, std::string>>& cfg) override {
Predictor::Configure(cfg);
}
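The new PredictInteractionContributions fills one (num_feature + 1) x (num_feature + 1) matrix per row and output group; the extra row/column is the bias, which the LaunchN kernel above writes into the last diagonal cell (base_score, or the per-row base margin when one is supplied). A hedged sketch of the resulting layout, using synthetic data and the public Python API (illustration only, not part of the diff):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X = rng.rand(64, 3)
y = rng.rand(64)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train({"tree_method": "gpu_hist", "predictor": "gpu_predictor",
                 "gpu_id": 0}, dtrain, num_boost_round=8)

phis = bst.predict(dtrain, pred_interactions=True)
# 3 features + 1 bias row/column per interaction matrix.
assert phis.shape == (64, 4, 4)

# Per the LOG(FATAL) above, the approximate variant now fails loudly on
# the GPU predictor rather than silently switching to the CPU:
#   bst.predict(dtrain, pred_interactions=True, approx_contribs=True)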
1 change: 1 addition & 0 deletions tests/ci_build/conda_env/cpu_test.yml
@@ -29,6 +29,7 @@ dependencies:
 - boto3
 - awscli
 - pip:
+  - shap
   - guzzle_sphinx_theme
   - datatable
   - modin[all]
35 changes: 20 additions & 15 deletions tests/python-gpu/test_gpu_prediction.py
@@ -16,7 +16,7 @@
     'max_depth': strategies.integers(0, 11),
     'max_leaves': strategies.integers(0, 256),
     'num_parallel_tree': strategies.sampled_from([1, 10]),
-})
+}).filter(lambda x: x['max_depth'] > 0 or x['max_leaves'] > 0)
 
 
 class TestGPUPredict(unittest.TestCase):
@@ -194,26 +194,31 @@ def predict_df(x):
         for i in range(10):
             run_threaded_predict(X, rows, predict_df)
 
-    @given(strategies.integers(1, 200),
-           tm.dataset_strategy, shap_parameter_strategy, strategies.booleans())
+    @given(strategies.integers(1, 10),
+           tm.dataset_strategy, shap_parameter_strategy)
     @settings(deadline=None)
-    def test_shap(self, num_rounds, dataset, param, all_rows):
-        if param['max_depth'] == 0 and param['max_leaves'] == 0:
-            return
-
+    def test_shap(self, num_rounds, dataset, param):
         param.update({"predictor": "gpu_predictor", "gpu_id": 0})
         param = dataset.set_params(param)
         dmat = dataset.get_dmat()
         bst = xgb.train(param, dmat, num_rounds)
-        if all_rows:
-            test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
-        else:
-            test_dmat = xgb.DMatrix(dataset.X[0:1, :])
+        test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
         shap = bst.predict(test_dmat, pred_contribs=True)
         bst.set_param({"predictor": "cpu_predictor"})
         cpu_shap = bst.predict(test_dmat, pred_contribs=True)
         margin = bst.predict(test_dmat, output_margin=True)
         assert np.allclose(shap, cpu_shap, 1e-3, 1e-3)
         # feature contributions should add up to predictions
+        assume(len(dataset.y) > 0)
         assert np.allclose(np.sum(shap, axis=len(shap.shape) - 1), margin, 1e-3, 1e-3)
 
+    @given(strategies.integers(1, 10),
+           tm.dataset_strategy, shap_parameter_strategy)
+    @settings(deadline=None, max_examples=20)
+    def test_shap_interactions(self, num_rounds, dataset, param):
+        param.update({"predictor": "gpu_predictor", "gpu_id": 0})
+        param = dataset.set_params(param)
+        dmat = dataset.get_dmat()
+        bst = xgb.train(param, dmat, num_rounds)
+        test_dmat = xgb.DMatrix(dataset.X, dataset.y, dataset.w, dataset.margin)
+        shap = bst.predict(test_dmat, pred_interactions=True)
+        margin = bst.predict(test_dmat, output_margin=True)
+        assume(len(dataset.y) > 0)
+        assert np.allclose(np.sum(shap, axis=(len(shap.shape) - 1, len(shap.shape) - 2)),
+                           margin, 1e-3, 1e-3)
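Taken together, these tests pin down the standard SHAP consistency properties. A self-contained recap of all three (a sketch, not PR code; a CPU build suffices, and the final row-sum identity is a known property of SHAP interaction values rather than something asserted above):

import numpy as np
import xgboost as xgb

rng = np.random.RandomState(0)
X = rng.rand(50, 4)
y = rng.rand(50)
dtrain = xgb.DMatrix(X, label=y)
bst = xgb.train({"max_depth": 3}, dtrain, num_boost_round=5)

margin = bst.predict(dtrain, output_margin=True)
contribs = bst.predict(dtrain, pred_contribs=True)          # shape (50, 5)
interactions = bst.predict(dtrain, pred_interactions=True)  # shape (50, 5, 5)

# Contributions (bias column included) sum to the raw margin ...
assert np.allclose(contribs.sum(axis=-1), margin, 1e-3, 1e-3)
# ... as do interaction values summed over both feature axes.
assert np.allclose(interactions.sum(axis=(-1, -2)), margin, 1e-3, 1e-3)
# Row sums of each interaction matrix recover per-feature contributions.
assert np.allclose(interactions.sum(axis=-1), contribs, 1e-3, 1e-3)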
27 changes: 27 additions & 0 deletions tests/python/test_with_shap.py
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
import numpy as np
import xgboost as xgb
import testing as tm
import unittest
import pytest

try:
    import shap
except ImportError:
    shap = None
    pass

Resolved review thread on the import guard above:

Member: Can we move this into testing.py?

Member Author: I prefer it to stay here; I don't think it will be reused.

pytestmark = pytest.mark.skipif(shap is None, reason="Requires shap package")


# Check integration is not broken from xgboost side
# Changes in binary format may cause problems
def test_with_shap():
    X, y = shap.datasets.boston()
    dtrain = xgb.DMatrix(X, label=y)
    model = xgb.train({"learning_rate": 0.01}, dtrain, 10)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)
    margin = model.predict(dtrain, output_margin=True)
    assert np.allclose(np.sum(shap_values, axis=len(shap_values.shape) - 1),
                       margin - explainer.expected_value, 1e-3, 1e-3)
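For reference, the identity this test exercises: TreeExplainer's attributions sum to the raw margin minus explainer.expected_value, since shap reports the bias separately rather than as an extra column the way pred_contribs does. The same check with synthetic data (a sketch, in case shap.datasets.boston is unavailable):

import numpy as np
import xgboost as xgb
import shap

rng = np.random.RandomState(0)
X = rng.rand(100, 6)
y = rng.rand(100)
dtrain = xgb.DMatrix(X, label=y)
model = xgb.train({"learning_rate": 0.01}, dtrain, 10)

explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X)  # shape (100, 6): no bias column
margin = model.predict(dtrain, output_margin=True)
assert np.allclose(shap_values.sum(axis=-1),
                   margin - explainer.expected_value, 1e-3, 1e-3)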