Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GPUTreeShap #6038

Merged
merged 5 commits into from Aug 25, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Expand Up @@ -4,3 +4,6 @@
# CUB: NVIDIA's CUDA primitives library (used when the toolkit does not bundle it)
[submodule "cub"]
path = cub
url = https://github.com/NVlabs/cub
# GPUTreeShap: CUDA implementation of SHAP values for decision-tree ensembles,
# included via `#include <GPUTreeShap/gpu_treeshap.h>` in the GPU predictor
[submodule "gputreeshap"]
path = gputreeshap
url = https://github.com/rapidsai/gputreeshap.git
1 change: 1 addition & 0 deletions gputreeshap
Submodule gputreeshap added at a3d4c4
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Expand Up @@ -9,6 +9,7 @@ if (USE_CUDA)
# Collect every CUDA translation unit and header under src/ into objxgboost.
file(GLOB_RECURSE CUDA_SOURCES *.cu *.cuh)
target_sources(objxgboost PRIVATE ${CUDA_SOURCES})
target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_CUDA=1)
# GPUTreeShap submodule root goes on the include path so
# <GPUTreeShap/gpu_treeshap.h> resolves (presumably header-only — see submodule).
target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap)
# CUDA 11+ ships CUB with the toolkit; only pre-11 needs the vendored copy.
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/)
endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
Expand Down
8 changes: 4 additions & 4 deletions src/gbm/gbtree.h
Expand Up @@ -238,11 +238,11 @@ class GBTree : public GradientBooster {

/*!
 * \brief Compute per-feature (SHAP) contribution predictions.
 *
 * Dispatches to whichever predictor is configured (CPU or GPU) rather
 * than hard-coding the CPU predictor, since the GPU predictor now
 * implements PredictContribution as well.
 *
 * NOTE: the original scraped text interleaved old and new diff lines here;
 * this is the reconstructed updated method.
 */
void PredictContribution(DMatrix* p_fmat,
                         std::vector<bst_float>* out_contribs,
                         unsigned ntree_limit, bool approximate,
                         int condition, unsigned condition_feature) override {
  CHECK(configured_);
  // tree_weights is not supported through this entry point -> nullptr.
  this->GetPredictor()->PredictContribution(
      p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate);
}

void PredictInteractionContributions(DMatrix* p_fmat,
Expand Down
200 changes: 149 additions & 51 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -5,6 +5,7 @@
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <GPUTreeShap/gpu_treeshap.h>
#include <memory>

#include "xgboost/data.h"
Expand All @@ -27,72 +28,79 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor);
/*!
 * \brief Device-side view over one CSR sparse page.
 *
 * d_data holds the CSR entries, d_row_ptr the row offsets
 * (size NumRows() + 1). GetElement returns NaN for missing features.
 * The original scraped text interleaved old and new diff lines; this is
 * the reconstructed new version. Fix: previous_middle was compared while
 * uninitialized (UB) — it is now initialized to end_ptr, which can never
 * equal a computed middle on the first iteration, preserving behavior.
 */
struct SparsePageView {
  common::Span<const Entry> d_data;        // CSR entry storage
  common::Span<const bst_row_t> d_row_ptr; // row offsets into d_data
  bst_feature_t num_features;              // logical column count

  XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
                                common::Span<const bst_row_t> row_ptr,
                                bst_feature_t num_features)
      : d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}

  // Returns feature `fidx` of row `ridx`, or NaN when the entry is absent.
  __device__ float GetElement(size_t ridx, size_t fidx) const {
    auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
    auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
    if (end_ptr - begin_ptr == this->NumCols()) {
      // Row is fully dense: entries are stored in feature order, so index
      // directly and bypass span bounds checking.
      return d_data.data()[d_row_ptr[ridx] + fidx].fvalue;
    }
    // Binary search within the row; previous_middle breaks the two-element
    // stall where the midpoint stops moving.
    common::Span<const Entry>::iterator previous_middle = end_ptr;
    while (end_ptr != begin_ptr) {
      auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
      if (middle == previous_middle) {
        break;
      }
      previous_middle = middle;

      if (middle->index == fidx) {
        return middle->fvalue;
      } else if (middle->index < fidx) {
        begin_ptr = middle;
      } else {
        end_ptr = middle;
      }
    }
    // Feature not present in this row.
    return nanf("");
  }
  XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
  XGBOOST_DEVICE size_t NumCols() const { return num_features; }
};

/*!
 * \brief Row loader used by prediction kernels.
 *
 * When use_shared is set, each thread stages its own row into dynamic
 * shared memory (dense layout, NaN-filled); otherwise lookups fall
 * through to SparsePageView's binary search in global memory.
 * The original scraped text interleaved old and new diff lines; this is
 * the reconstructed new version, which delegates the global-memory path
 * to SparsePageView::GetElement instead of duplicating the search.
 */
struct SparsePageLoader {
  bool use_shared;      // true when blockDim.x rows fit in shared memory
  SparsePageView data;  // global-memory fallback view
  float* smem;          // dynamic shared-memory staging buffer
  size_t entry_start;   // offset of this batch's first entry in d_data

  // num_features/num_rows kept in the signature for interface parity with
  // the other loaders (EllpackLoader, DeviceAdapterLoader).
  __device__ SparsePageLoader(SparsePageView data, bool use_shared,
                              bst_feature_t num_features, bst_row_t num_rows,
                              size_t entry_start)
      : use_shared(use_shared), data(data), entry_start(entry_start) {
    extern __shared__ float _smem[];
    smem = _smem;
    if (use_shared) {
      // One row per thread: fill with NaN (missing), then scatter entries.
      bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
      int shared_elements = blockDim.x * data.num_features;
      dh::BlockFill(smem, shared_elements, nanf(""));
      __syncthreads();
      if (global_idx < num_rows) {
        bst_uint elem_begin = data.d_row_ptr[global_idx];
        bst_uint elem_end = data.d_row_ptr[global_idx + 1];
        for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
          Entry elem = data.d_data[elem_idx - entry_start];
          smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
        }
      }
      __syncthreads();
    }
  }
  // Returns feature `fidx` of row `ridx`, or NaN when missing.
  __device__ float GetElement(size_t ridx, size_t fidx) const {
    if (use_shared) {
      return smem[threadIdx.x * data.num_features + fidx];
    }
    return data.GetElement(ridx, fidx);
  }
};
Expand All @@ -103,7 +111,7 @@ struct EllpackLoader {
bst_feature_t num_features, bst_row_t num_rows,
size_t entry_start)
: matrix{m} {}
__device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
__device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
auto gidx = matrix.GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
Expand Down Expand Up @@ -150,7 +158,7 @@ struct DeviceAdapterLoader {
__syncthreads();
}

DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
if (use_shared) {
return smem[threadIdx.x * columns + fidx];
}
Expand All @@ -163,7 +171,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
Loader* loader) {
RegTree::Node n = tree[0];
while (!n.IsLeaf()) {
float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
float fvalue = loader->GetElement(ridx, n.SplitIndex());
// Missing value
if (isnan(fvalue)) {
n = tree[n.DefaultChild()];
Expand Down Expand Up @@ -273,7 +281,8 @@ class GPUPredictor : public xgboost::Predictor {
use_shared = false;
}
size_t entry_start = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<SparsePageLoader, SparsePageView>,
data,
Expand Down Expand Up @@ -447,6 +456,61 @@ class GPUPredictor : public xgboost::Predictor {
}
}

/*!
 * \brief Compute exact (Tree SHAP) feature contributions on the GPU.
 *
 * Output layout: for each row and output group, (num_feature + 1) floats;
 * the final column is the bias term. `approximate` mode is not supported.
 * tree_weights / condition / condition_feature are accepted for interface
 * compatibility but not consumed here.
 * (Scraped review-UI text that had been interleaved into this method was
 * removed to restore valid code.)
 */
void PredictContribution(DMatrix* p_fmat,
                         std::vector<bst_float>* out_contribs,
                         const gbm::GBTreeModel& model, unsigned ntree_limit,
                         std::vector<bst_float>* tree_weights,
                         bool approximate, int condition,
                         unsigned condition_feature) override {
  if (approximate) {
    LOG(FATAL) << "[Internal error]: " << __func__
               << " approximate is not implemented in GPU Predictor.";
  }

  // ntree_limit counts boosting rounds; scale by groups to get tree count.
  uint32_t real_ntree_limit =
      ntree_limit * model.learner_model_param->num_output_group;
  if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
    real_ntree_limit = static_cast<uint32_t>(model.trees.size());
  }

  const int ngroup = model.learner_model_param->num_output_group;
  CHECK_NE(ngroup, 0);
  // Allocate (num_features + 1) columns per row per group; +1 for bias.
  std::vector<bst_float>& contribs = *out_contribs;
  size_t contributions_columns =
      model.learner_model_param->num_feature + 1;  // +1 for bias
  contribs.resize(p_fmat->Info().num_row_ * contributions_columns *
                  model.learner_model_param->num_output_group);
  dh::caching_device_vector<float> phis(contribs.size(), 0.0f);

  // Write the bias term into the last column of every (row, group) slice.
  p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
  const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
  // NOTE(review): this CHECK makes the margin branch in the lambda below
  // dead code — confirm whether per-row base margins should be supported.
  CHECK(margin.empty());
  float base_score = model.learner_model_param->base_score;
  auto d_phis = phis.data().get();
  dh::LaunchN(
      generic_param_->gpu_id,
      p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
      [=] __device__(size_t idx) {
        d_phis[(idx + 1) * contributions_columns - 1] =
            margin.empty() ? base_score : margin[idx];
      });

  // Accumulate SHAP values batch by batch via GPUTreeShap.
  const auto& paths = this->ExtractPaths(model, real_ntree_limit);
  for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
    batch.data.SetDevice(generic_param_->gpu_id);
    batch.offset.SetDevice(generic_param_->gpu_id);
    SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                     model.learner_model_param->num_feature);
    gpu_treeshap::GPUTreeShap(
        X, paths, ngroup,
        phis.data().get() + batch.base_rowid * contributions_columns);
  }
  // D2H copy into pageable memory: cudaMemcpyAsync is synchronous with
  // respect to the host for pageable destinations, so contribs is ready
  // on return — TODO confirm if the destination ever becomes pinned.
  dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
                                sizeof(float) * phis.size(),
                                cudaMemcpyDefault));
}

protected:
void InitOutPredictions(const MetaInfo& info,
HostDeviceVector<bst_float>* out_preds,
Expand Down Expand Up @@ -478,16 +542,6 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor.";
}

void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate, int condition,
unsigned condition_feature) override {
LOG(FATAL) << "[Internal error]: " << __func__
<< " is not implemented in GPU Predictor.";
}

void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
Expand All @@ -510,6 +564,50 @@ class GPUPredictor : public xgboost::Predictor {
}
}

/*!
 * \brief Flatten the first tree_limit trees into GPUTreeShap path elements.
 *
 * For every leaf, walks leaf -> root emitting one PathElement per edge
 * (split feature, feature-value interval, missing-branch flag, cover
 * fraction, leaf value), then a terminal root element with feature -1.
 * path_idx uniquely identifies each leaf path across all trees.
 * (Scraped review-UI text that had been interleaved into this method was
 * removed to restore valid code; tree.GetNodes() is hoisted out of the
 * per-node loop instead of being re-fetched every iteration.)
 */
std::vector<gpu_treeshap::PathElement> ExtractPaths(
    const gbm::GBTreeModel& model, size_t tree_limit) {
  std::vector<gpu_treeshap::PathElement> paths;
  size_t path_idx = 0;
  CHECK_LE(tree_limit, model.trees.size());
  for (auto i = 0ull; i < tree_limit; i++) {
    const auto& tree = *model.trees.at(i);
    size_t group = model.tree_info[i];  // output group this tree belongs to
    const auto& nodes = tree.GetNodes();
    for (auto j = 0ull; j < nodes.size(); j++) {
      if (nodes[j].IsLeaf()) {
        auto child = nodes[j];
        float v = child.LeafValue();
        size_t child_idx = j;
        const float inf = std::numeric_limits<float>::infinity();
        while (!child.IsRoot()) {
          float child_cover = tree.Stat(child_idx).sum_hess;
          float parent_cover = tree.Stat(child.Parent()).sum_hess;
          // Fraction of cover (sum_hess) flowing down this branch.
          float zero_fraction = child_cover / parent_cover;
          CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
          auto parent = nodes[child.Parent()];
          CHECK(parent.LeftChild() == child_idx ||
                parent.RightChild() == child_idx);
          bool is_left_path = parent.LeftChild() == child_idx;
          // True when missing values are routed down this edge.
          bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
                                 (parent.DefaultLeft() && is_left_path);
          // Feature-value interval selecting this branch of the split.
          float lower_bound = is_left_path ? -inf : parent.SplitCond();
          float upper_bound = is_left_path ? parent.SplitCond() : inf;
          paths.emplace_back(path_idx, parent.SplitIndex(), group,
                             lower_bound, upper_bound, is_missing_path,
                             zero_fraction, v);
          child_idx = child.Parent();
          child = parent;
        }
        // Root node has feature -1 and admits everything.
        paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
        path_idx++;
      }
    }
  }
  return paths;
}

std::mutex lock_;
DeviceModel model_;
size_t max_shared_memory_bytes_;
Expand Down
56 changes: 56 additions & 0 deletions tests/cpp/predictor/test_gpu_predictor.cu
Expand Up @@ -163,5 +163,61 @@ TEST(GPUPredictor, MGPU_InplacePredict) { // NOLINT
TEST(GpuPredictor, LesserFeatures) {
TestPredictionWithLesserFeatures("gpu_predictor");
}
// Sanity test for SHAP on an empty (single-leaf) model: every feature
// contribution must be zero and the bias column must equal base_score.
TEST(GPUPredictor, ShapStump) {
  cudaSetDevice(0);
  LearnerModelParam param;
  param.num_feature = 1;
  param.num_output_group = 1;
  param.base_score = 0.5;
  gbm::GBTreeModel model(&param);
  std::vector<std::unique_ptr<RegTree>> trees;
  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
  model.CommitModel(std::move(trees), 0);

  auto lparam = CreateEmptyGenericParam(0);
  std::unique_ptr<Predictor> predictor(
      Predictor::Create("gpu_predictor", &lparam));
  predictor->Configure({});

  auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
  std::vector<float> contribs;
  predictor->PredictContribution(dmat.get(), &contribs, model);
  // Three rows, each laid out as [feature contribution, bias].
  for (size_t row = 0; row < 3; ++row) {
    EXPECT_EQ(contribs[row * 2], 0.0);
    EXPECT_EQ(contribs[row * 2 + 1], param.base_score);
  }
}
// GPU SHAP values must agree with the CPU predictor on a one-split tree.
TEST(GPUPredictor, Shap) {
  LearnerModelParam param;
  param.num_feature = 1;
  param.num_output_group = 1;
  param.base_score = 0.5;
  gbm::GBTreeModel model(&param);
  std::vector<std::unique_ptr<RegTree>> trees;
  trees.push_back(std::unique_ptr<RegTree>(new RegTree));
  trees[0]->ExpandNode(0, 0, 0.5, true, 1.0, -1.0, 1.0, 0.0, 5.0, 2.0, 3.0);
  model.CommitModel(std::move(trees), 0);

  auto gpu_lparam = CreateEmptyGenericParam(0);
  auto cpu_lparam = CreateEmptyGenericParam(-1);
  std::unique_ptr<Predictor> gpu_predictor(
      Predictor::Create("gpu_predictor", &gpu_lparam));
  std::unique_ptr<Predictor> cpu_predictor(
      Predictor::Create("cpu_predictor", &cpu_lparam));
  gpu_predictor->Configure({});
  cpu_predictor->Configure({});

  auto dmat = RandomDataGenerator(3, 1, 0).GenerateDMatrix();
  std::vector<float> gpu_phis;
  std::vector<float> cpu_phis;
  gpu_predictor->PredictContribution(dmat.get(), &gpu_phis, model);
  cpu_predictor->PredictContribution(dmat.get(), &cpu_phis, model);
  // Element-wise comparison with a small tolerance for float accumulation.
  for (size_t i = 0; i < gpu_phis.size(); ++i) {
    EXPECT_NEAR(cpu_phis[i], gpu_phis[i], 1e-3);
  }
}

} // namespace predictor
} // namespace xgboost