Update GPUTreeShap, add docs #6281

Merged
merged 3 commits on Oct 27, 2020
4 changes: 3 additions & 1 deletion demo/gpu_acceleration/README.md
@@ -1,3 +1,5 @@
# GPU Acceleration Demo

`cover_type.py` shows how to train a model on the [forest cover type](https://archive.ics.uci.edu/ml/datasets/covertype) dataset using GPU acceleration. The forest cover type dataset has 581,012 rows and 54 features, making it time-consuming to process. We compare the run-time and accuracy of the GPU and CPU histogram algorithms.

`shap.ipynb` demonstrates using GPU acceleration to compute SHAP values for feature importance.
211 changes: 211 additions & 0 deletions demo/gpu_acceleration/shap.ipynb

Large diffs are not rendered by default.
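Since the notebook diff is not rendered, here is a minimal sketch of the kind of GPU-accelerated SHAP workflow `shap.ipynb` demonstrates (the dataset and the use of the `shap` package for plotting are assumptions for illustration, not the notebook's exact contents):

```python
import shap
import xgboost as xgb
from sklearn.datasets import fetch_california_housing  # stand-in dataset for illustration

X, y = fetch_california_housing(return_X_y=True)
dtrain = xgb.DMatrix(X, label=y)

# Train on the GPU, then compute SHAP values on the GPU via GPUTreeShap.
model = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=50)
model.set_param({"predictor": "gpu_predictor"})
shap_values = model.predict(dtrain, pred_contribs=True)

# The last column is the bias term; drop it before plotting with the shap package.
shap.summary_plot(shap_values[:, :-1], X)
```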

13 changes: 13 additions & 0 deletions doc/gpu/index.rst
@@ -85,6 +85,19 @@ The GPU algorithms currently work with CLI, Python and R packages. See :doc:`/bu
XGBRegressor(tree_method='gpu_hist', gpu_id=0)


GPU-Accelerated SHAP values
=============================
XGBoost uses `GPUTreeShap <https://github.com/rapidsai/gputreeshap>`_ as a backend for computing SHAP values when the GPU predictor is selected.

.. code-block:: python

model.set_param({"predictor": "gpu_predictor"})
shap_values = model.predict(dtrain, pred_contribs=True)
shap_interaction_values = model.predict(dtrain, pred_interactions=True)

See examples `here
<https://github.com/dmlc/xgboost/tree/master/demo/gpu_acceleration>`_.
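As a concrete illustration of the return shapes, consider this sketch (the synthetic data and parameters are assumptions for illustration): ``pred_contribs=True`` returns an array of shape ``(n_rows, n_features + 1)`` whose last column holds the bias term, and the per-row sum of SHAP values reproduces the raw margin prediction.

.. code-block:: python

    import numpy as np
    import xgboost as xgb

    # Synthetic regression data, purely for illustration.
    X = np.random.rand(500, 10).astype(np.float32)
    y = 2.0 * X[:, 0] + np.random.rand(500)
    dtrain = xgb.DMatrix(X, label=y)

    model = xgb.train({"tree_method": "gpu_hist"}, dtrain, num_boost_round=20)
    model.set_param({"predictor": "gpu_predictor"})

    shap_values = model.predict(dtrain, pred_contribs=True)
    print(shap_values.shape)  # (500, 11): one column per feature plus the bias term

    # SHAP values are additive: row sums equal the raw margin predictions.
    margin = model.predict(dtrain, output_margin=True)
    np.testing.assert_allclose(shap_values.sum(axis=1), margin, rtol=1e-3)

    shap_interaction_values = model.predict(dtrain, pred_interactions=True)
    print(shap_interaction_values.shape)  # (500, 11, 11)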

Multi-node Multi-GPU Training
=============================
XGBoost supports fully distributed GPU training using `Dask <https://dask.org/>`_. For
52 changes: 26 additions & 26 deletions src/predictor/gpu_predictor.cu
@@ -671,17 +671,6 @@ class GPUPredictor : public xgboost::Predictor {
model.learner_model_param->num_output_group);
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
// Add the base margin term to last column
dh::LaunchN(
generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) {
phis[(idx + 1) * contributions_columns - 1] =
margin.empty() ? base_score : margin[idx];
});

dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
@@ -695,6 +684,17 @@ class GPUPredictor : public xgboost::Predictor {
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data() + batch.base_rowid * contributions_columns, phis.size());
}
// Add the base margin term to last column
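// Using += (rather than =) preserves the bias contribution that the
// preceding GPUTreeShap call has already accumulated into this column.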
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
dh::LaunchN(
generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) {
phis[(idx + 1) * contributions_columns - 1] +=
margin.empty() ? base_score : margin[idx];
});
}

void PredictInteractionContributions(DMatrix* p_fmat,
@@ -726,21 +726,6 @@ class GPUPredictor : public xgboost::Predictor {
model.learner_model_param->num_output_group);
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
// Add the base margin term to last column
size_t n_features = model.learner_model_param->num_feature;
dh::LaunchN(
generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) {
size_t group = idx % ngroup;
size_t row_idx = idx / ngroup;
phis[gpu_treeshap::IndexPhiInteractions(
row_idx, ngroup, group, n_features, n_features, n_features)] =
margin.empty() ? base_score : margin[idx];
});

dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
@@ -754,6 +739,21 @@ class GPUPredictor : public xgboost::Predictor {
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data() + batch.base_rowid * contributions_columns, phis.size());
}
// Add the base margin term to last column
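// As above, += preserves the bias term GPUTreeShap has already written
// into the (bias, bias) cell indexed by IndexPhiInteractions below.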
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
size_t n_features = model.learner_model_param->num_feature;
dh::LaunchN(
generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) {
size_t group = idx % ngroup;
size_t row_idx = idx / ngroup;
phis[gpu_treeshap::IndexPhiInteractions(
row_idx, ngroup, group, n_features, n_features, n_features)] +=
margin.empty() ? base_score : margin[idx];
});
}

protected: