Skip to content

Commit

Permalink
GPUTreeShap (#6038)
Browse files Browse the repository at this point in the history
  • Loading branch information
RAMitchell committed Aug 25, 2020
1 parent b319305 commit 9a4e8b1
Show file tree
Hide file tree
Showing 9 changed files with 266 additions and 62 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Expand Up @@ -4,3 +4,6 @@
[submodule "cub"]
path = cub
url = https://github.com/NVlabs/cub
[submodule "gputreeshap"]
path = gputreeshap
url = https://github.com/rapidsai/gputreeshap.git
1 change: 1 addition & 0 deletions gputreeshap
Submodule gputreeshap added at a3d4c4
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Expand Up @@ -9,6 +9,7 @@ if (USE_CUDA)
file(GLOB_RECURSE CUDA_SOURCES *.cu *.cuh)
target_sources(objxgboost PRIVATE ${CUDA_SOURCES})
target_compile_definitions(objxgboost PRIVATE -DXGBOOST_USE_CUDA=1)
target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/gputreeshap)
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
target_include_directories(objxgboost PRIVATE ${xgboost_SOURCE_DIR}/cub/)
endif (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.0)
Expand Down
12 changes: 11 additions & 1 deletion src/common/device_helpers.cuh
Expand Up @@ -474,8 +474,18 @@ class TemporaryArray {
using AllocT = XGBCachingDeviceAllocator<T>;
using value_type = T; // NOLINT
explicit TemporaryArray(size_t n) : size_(n) { ptr_ = AllocT().allocate(n); }
TemporaryArray(size_t n, T val) : size_(n) {
ptr_ = AllocT().allocate(n);
this->fill(val);
}
~TemporaryArray() { AllocT().deallocate(ptr_, this->size()); }

void fill(T val) // NOLINT
{
int device = 0;
dh::safe_cuda(cudaGetDevice(&device));
auto d_data = ptr_.get();
LaunchN(device, this->size(), [=] __device__(size_t idx) { d_data[idx] = val; });
}
thrust::device_ptr<T> data() { return ptr_; } // NOLINT
size_t size() { return size_; } // NOLINT

Expand Down
8 changes: 4 additions & 4 deletions src/gbm/gbtree.h
Expand Up @@ -238,11 +238,11 @@ class GBTree : public GradientBooster {

void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
unsigned ntree_limit, bool approximate, int condition,
unsigned condition_feature) override {
unsigned ntree_limit, bool approximate,
int condition, unsigned condition_feature) override {
CHECK(configured_);
cpu_predictor_->PredictContribution(p_fmat, out_contribs, model_,
ntree_limit, nullptr, approximate);
this->GetPredictor()->PredictContribution(
p_fmat, out_contribs, model_, ntree_limit, nullptr, approximate);
}

void PredictInteractionContributions(DMatrix* p_fmat,
Expand Down
198 changes: 147 additions & 51 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -5,6 +5,7 @@
#include <thrust/device_ptr.h>
#include <thrust/device_vector.h>
#include <thrust/fill.h>
#include <GPUTreeShap/gpu_treeshap.h>
#include <memory>

#include "xgboost/data.h"
Expand All @@ -27,72 +28,79 @@ DMLC_REGISTRY_FILE_TAG(gpu_predictor);
struct SparsePageView {
common::Span<const Entry> d_data;
common::Span<const bst_row_t> d_row_ptr;
bst_feature_t num_features;

XGBOOST_DEVICE SparsePageView(common::Span<const Entry> data,
common::Span<const bst_row_t> row_ptr) :
d_data{data}, d_row_ptr{row_ptr} {}
common::Span<const bst_row_t> row_ptr,
bst_feature_t num_features)
: d_data{data}, d_row_ptr{row_ptr}, num_features(num_features) {}
__device__ float GetElement(size_t ridx, size_t fidx) const {
// Binary search
auto begin_ptr = d_data.begin() + d_row_ptr[ridx];
auto end_ptr = d_data.begin() + d_row_ptr[ridx + 1];
if (end_ptr - begin_ptr == this->NumCols()) {
// Bypass span check for dense data
return d_data.data()[d_row_ptr[ridx] + fidx].fvalue;
}
common::Span<const Entry>::iterator previous_middle;
while (end_ptr != begin_ptr) {
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
if (middle == previous_middle) {
break;
} else {
previous_middle = middle;
}

if (middle->index == fidx) {
return middle->fvalue;
} else if (middle->index < fidx) {
begin_ptr = middle;
} else {
end_ptr = middle;
}
}
// Value is missing
return nanf("");
}
XGBOOST_DEVICE size_t NumRows() const { return d_row_ptr.size() - 1; }
XGBOOST_DEVICE size_t NumCols() const { return num_features; }
};

struct SparsePageLoader {
bool use_shared;
common::Span<const bst_row_t> d_row_ptr;
common::Span<const Entry> d_data;
bst_feature_t num_features;
SparsePageView data;
float* smem;
size_t entry_start;

__device__ SparsePageLoader(SparsePageView data, bool use_shared, bst_feature_t num_features,
bst_row_t num_rows, size_t entry_start)
: use_shared(use_shared),
d_row_ptr(data.d_row_ptr),
d_data(data.d_data),
num_features(num_features),
data(data),
entry_start(entry_start) {
extern __shared__ float _smem[];
smem = _smem;
// Copy instances
if (use_shared) {
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
int shared_elements = blockDim.x * num_features;
int shared_elements = blockDim.x * data.num_features;
dh::BlockFill(smem, shared_elements, nanf(""));
__syncthreads();
if (global_idx < num_rows) {
bst_uint elem_begin = d_row_ptr[global_idx];
bst_uint elem_end = d_row_ptr[global_idx + 1];
bst_uint elem_begin = data.d_row_ptr[global_idx];
bst_uint elem_end = data.d_row_ptr[global_idx + 1];
for (bst_uint elem_idx = elem_begin; elem_idx < elem_end; elem_idx++) {
Entry elem = d_data[elem_idx - entry_start];
smem[threadIdx.x * num_features + elem.index] = elem.fvalue;
Entry elem = data.d_data[elem_idx - entry_start];
smem[threadIdx.x * data.num_features + elem.index] = elem.fvalue;
}
}
__syncthreads();
}
}
__device__ float GetFvalue(int ridx, int fidx) const {
__device__ float GetElement(size_t ridx, size_t fidx) const {
if (use_shared) {
return smem[threadIdx.x * num_features + fidx];
return smem[threadIdx.x * data.num_features + fidx];
} else {
// Binary search
auto begin_ptr = d_data.begin() + (d_row_ptr[ridx] - entry_start);
auto end_ptr = d_data.begin() + (d_row_ptr[ridx + 1] - entry_start);
common::Span<const Entry>::iterator previous_middle;
while (end_ptr != begin_ptr) {
auto middle = begin_ptr + (end_ptr - begin_ptr) / 2;
if (middle == previous_middle) {
break;
} else {
previous_middle = middle;
}

if (middle->index == fidx) {
return middle->fvalue;
} else if (middle->index < fidx) {
begin_ptr = middle;
} else {
end_ptr = middle;
}
}
// Value is missing
return nanf("");
return data.GetElement(ridx, fidx);
}
}
};
Expand All @@ -103,7 +111,7 @@ struct EllpackLoader {
bst_feature_t num_features, bst_row_t num_rows,
size_t entry_start)
: matrix{m} {}
__device__ __forceinline__ float GetFvalue(int ridx, int fidx) const {
__device__ __forceinline__ float GetElement(size_t ridx, size_t fidx) const {
auto gidx = matrix.GetBinIndex(ridx, fidx);
if (gidx == -1) {
return nan("");
Expand Down Expand Up @@ -150,7 +158,7 @@ struct DeviceAdapterLoader {
__syncthreads();
}

DEV_INLINE float GetFvalue(bst_row_t ridx, bst_feature_t fidx) const {
DEV_INLINE float GetElement(size_t ridx, size_t fidx) const {
if (use_shared) {
return smem[threadIdx.x * columns + fidx];
}
Expand All @@ -163,7 +171,7 @@ __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
Loader* loader) {
RegTree::Node n = tree[0];
while (!n.IsLeaf()) {
float fvalue = loader->GetFvalue(ridx, n.SplitIndex());
float fvalue = loader->GetElement(ridx, n.SplitIndex());
// Missing value
if (isnan(fvalue)) {
n = tree[n.DefaultChild()];
Expand Down Expand Up @@ -273,7 +281,8 @@ class GPUPredictor : public xgboost::Predictor {
use_shared = false;
}
size_t entry_start = 0;
SparsePageView data{batch.data.DeviceSpan(), batch.offset.DeviceSpan()};
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<SparsePageLoader, SparsePageView>,
data,
Expand Down Expand Up @@ -447,6 +456,60 @@ class GPUPredictor : public xgboost::Predictor {
}
}

void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate, int condition,
unsigned condition_feature) override {
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor.";
}

uint32_t real_ntree_limit =
ntree_limit * model.learner_model_param->num_output_group;
if (real_ntree_limit == 0 || real_ntree_limit > model.trees.size()) {
real_ntree_limit = static_cast<uint32_t>(model.trees.size());
}

const int ngroup = model.learner_model_param->num_output_group;
CHECK_NE(ngroup, 0);
// allocate space for (number of features + bias) times the number of rows
std::vector<bst_float>& contribs = *out_contribs;
size_t contributions_columns =
model.learner_model_param->num_feature + 1; // +1 for bias
contribs.resize(p_fmat->Info().num_row_ * contributions_columns *
model.learner_model_param->num_output_group);
dh::TemporaryArray<float> phis(contribs.size(), 0.0);
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
const auto margin = p_fmat->Info().base_margin_.ConstDeviceSpan();
float base_score = model.learner_model_param->base_score;
auto d_phis = phis.data().get();
// Add the base margin term to last column
dh::LaunchN(
generic_param_->gpu_id,
p_fmat->Info().num_row_ * model.learner_model_param->num_output_group,
[=] __device__(size_t idx) {
d_phis[(idx + 1) * contributions_columns - 1] =
margin.empty() ? base_score : margin[idx];
});

const auto& paths = this->ExtractPaths(model, real_ntree_limit);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShap(
X, paths, ngroup,
phis.data().get() + batch.base_rowid * contributions_columns);
}
dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
sizeof(float) * phis.size(),
cudaMemcpyDefault));
}

protected:
void InitOutPredictions(const MetaInfo& info,
HostDeviceVector<bst_float>* out_preds,
Expand Down Expand Up @@ -478,16 +541,6 @@ class GPUPredictor : public xgboost::Predictor {
<< " is not implemented in GPU Predictor.";
}

void PredictContribution(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned ntree_limit,
std::vector<bst_float>* tree_weights,
bool approximate, int condition,
unsigned condition_feature) override {
LOG(FATAL) << "[Internal error]: " << __func__
<< " is not implemented in GPU Predictor.";
}

void PredictInteractionContributions(DMatrix* p_fmat,
std::vector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
Expand All @@ -510,6 +563,49 @@ class GPUPredictor : public xgboost::Predictor {
}
}

std::vector<gpu_treeshap::PathElement> ExtractPaths(
const gbm::GBTreeModel& model, size_t tree_limit) {
std::vector<gpu_treeshap::PathElement> paths;
size_t path_idx = 0;
CHECK_LE(tree_limit, model.trees.size());
for (auto i = 0ull; i < tree_limit; i++) {
const auto& tree = *model.trees.at(i);
size_t group = model.tree_info[i];
const auto& nodes = tree.GetNodes();
for (auto j = 0ull; j < nodes.size(); j++) {
if (nodes[j].IsLeaf() && !nodes[j].IsDeleted()) {
auto child = nodes[j];
float v = child.LeafValue();
size_t child_idx = j;
const float inf = std::numeric_limits<float>::infinity();
while (!child.IsRoot()) {
float child_cover = tree.Stat(child_idx).sum_hess;
float parent_cover = tree.Stat(child.Parent()).sum_hess;
float zero_fraction = child_cover / parent_cover;
CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
auto parent = nodes[child.Parent()];
CHECK(parent.LeftChild() == child_idx ||
parent.RightChild() == child_idx);
bool is_left_path = parent.LeftChild() == child_idx;
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
(parent.DefaultLeft() && is_left_path);
float lower_bound = is_left_path ? -inf : parent.SplitCond();
float upper_bound = is_left_path ? parent.SplitCond() : inf;
paths.emplace_back(path_idx, parent.SplitIndex(), group,
lower_bound, upper_bound, is_missing_path,
zero_fraction, v);
child_idx = child.Parent();
child = parent;
}
// Root node has feature -1
paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
path_idx++;
}
}
}
return paths;
}

std::mutex lock_;
DeviceModel model_;
size_t max_shared_memory_bytes_;
Expand Down

0 comments on commit 9a4e8b1

Please sign in to comment.