Skip to content

Commit

Permalink
Updates to GPUTreeShap (#6087)
Browse files Browse the repository at this point in the history
* Extract paths on device

* Update GPUTreeShap
  • Loading branch information
RAMitchell committed Sep 6, 2020
1 parent 0e2d566 commit 2e907ab
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 52 deletions.
2 changes: 1 addition & 1 deletion gputreeshap
3 changes: 3 additions & 0 deletions include/xgboost/tree_model.h
Expand Up @@ -337,6 +337,9 @@ class RegTree : public Model {
/*!
 * \brief Read-only access to the node array of this tree.
 * \return const reference to the internal node vector.
 */
const std::vector<Node>& GetNodes() const {
  return nodes_;
}

/*!
 * \brief Read-only access to the statistics array of this tree.
 * \return const reference to the internal statistics vector (indexed by node id).
 */
const std::vector<RTreeNodeStat>& GetStats() const {
  return stats_;
}

/*! \brief get node statistics given nid */
RTreeNodeStat& Stat(int nid) {
return stats_[nid];
Expand Down
1 change: 1 addition & 0 deletions src/common/host_device_vector.cu
Expand Up @@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Node>; // tree nodes, held by the GPU predictor's DeviceModel
template class HostDeviceVector<RTreeNodeStat>; // per-node statistics, held by the GPU predictor's DeviceModel

#if defined(__APPLE__)
/*
Expand Down
156 changes: 105 additions & 51 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -223,6 +223,7 @@ class DeviceModel {
public:
// Need to lazily construct the vectors because GPU id is only known at runtime
HostDeviceVector<RegTree::Node> nodes;
HostDeviceVector<RTreeNodeStat> stats;
HostDeviceVector<size_t> tree_segments;
HostDeviceVector<int> tree_group;
size_t tree_beg_; // NOLINT
Expand All @@ -246,22 +247,116 @@ class DeviceModel {

nodes = std::move(HostDeviceVector<RegTree::Node>(h_tree_segments.back(), RegTree::Node(),
gpu_id));
auto& h_nodes = nodes.HostVector();
stats = std::move(HostDeviceVector<RTreeNodeStat>(h_tree_segments.back(),
RTreeNodeStat(), gpu_id));
auto d_nodes = nodes.DevicePointer();
auto d_stats = stats.DevicePointer();
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
std::copy(src_nodes.begin(), src_nodes.end(),
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
auto& src_stats = model.trees.at(tree_idx)->GetStats();
dh::safe_cuda(cudaMemcpyAsync(
d_nodes + h_tree_segments[tree_idx - tree_begin], src_nodes.data(),
sizeof(RegTree::Node) * src_nodes.size(), cudaMemcpyDefault));
dh::safe_cuda(cudaMemcpyAsync(
d_stats + h_tree_segments[tree_idx - tree_begin], src_stats.data(),
sizeof(RTreeNodeStat) * src_stats.size(), cudaMemcpyDefault));
}

tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
auto& h_tree_group = tree_group.HostVector();
std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
auto d_tree_group = tree_group.DevicePointer();
dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(),
sizeof(int) * model.tree_info.size(),
cudaMemcpyDefault));
this->tree_beg_ = tree_begin;
this->tree_end_ = tree_end;
this->num_group = model.learner_model_param->num_output_group;
}
};

// Per-node record produced while extracting paths for GPUTreeShap.
// ExtractPaths builds one of these for every node, then keeps only the
// entries whose leaf_position is not -1.
struct PathInfo {
// Index of the leaf in the concatenated node array of all trees;
// -1 marks a node that is not a (non-deleted) leaf.
int64_t leaf_position; // -1 not a leaf
// Number of nodes on the path from the leaf up to and including the root.
size_t length;
// Index of the tree (within the tree segments) that owns the leaf.
size_t tree_idx;
};

/*!
 * \brief Transform a tree model into the flat path-element form used by GPUTreeShap.
 *
 * Every non-deleted leaf produces one path. A path holds one element per node
 * on the walk between the root and that leaf; within a path the root element
 * is stored first and the element for the split directly above the leaf last.
 *
 * \param paths      Output vector; resized to the total number of path
 *                   elements and fully overwritten.
 * \param model      Source tree model.
 * \param tree_limit Trees [0, tree_limit) are converted.
 * \param gpu_id     Device ordinal used for the device copy of the model and
 *                   for the kernel launch.
 */
void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
                  const gbm::GBTreeModel& model, size_t tree_limit,
                  int gpu_id) {
  DeviceModel device_model;
  device_model.Init(model, 0, tree_limit, gpu_id);
  dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
  dh::XGBCachingDeviceAllocator<PathInfo> alloc;
  auto d_nodes = device_model.nodes.ConstDeviceSpan();
  auto d_tree_segments = device_model.tree_segments.ConstDeviceSpan();

  // Step 1: map every node to a PathInfo. Nodes that are not live leaves get
  // leaf_position == -1 so that they can be filtered out below.
  auto nodes_transform = dh::MakeTransformIterator<PathInfo>(
      thrust::make_counting_iterator(0ull), [=] __device__(size_t idx) {
        auto n = d_nodes[idx];
        if (!n.IsLeaf() || n.IsDeleted()) {
          return PathInfo{-1, 0, 0};
        }
        size_t tree_idx =
            dh::SegmentId(d_tree_segments.begin(), d_tree_segments.end(), idx);
        size_t tree_offset = d_tree_segments[tree_idx];
        // Count the nodes between this leaf and the root, both inclusive.
        size_t path_length = 1;
        while (!n.IsRoot()) {
          n = d_nodes[n.Parent() + tree_offset];
          path_length++;
        }
        return PathInfo{int64_t(idx), path_length, tree_idx};
      });
  auto end = thrust::copy_if(
      thrust::cuda::par(alloc), nodes_transform,
      nodes_transform + d_nodes.size(), info.begin(),
      [=] __device__(const PathInfo& e) { return e.leaf_position != -1; });
  info.resize(end - info.begin());

  // Step 2: exclusive scan over the path lengths. path_segments[i] is the
  // offset of path i inside *paths and the extra trailing entry is the total
  // number of path elements.
  //
  // The scan consumes n_paths + 1 inputs. The iterator is built on a counting
  // iterator and yields 0 for the sentinel index, so that info[n_paths] is
  // never dereferenced: that slot is stale after the resize above, and lies
  // outside the allocation when every node is a leaf or there are no nodes.
  auto d_info = info.data().get();
  const size_t n_paths = info.size();
  auto length_iterator = dh::MakeTransformIterator<size_t>(
      thrust::make_counting_iterator(0ull), [=] __device__(size_t idx) {
        return idx < n_paths ? d_info[idx].length : size_t{0};
      });
  dh::caching_device_vector<size_t> path_segments(n_paths + 1);
  thrust::exclusive_scan(thrust::cuda::par(alloc), length_iterator,
                         length_iterator + n_paths + 1,
                         path_segments.begin());

  paths->resize(path_segments.back());

  // Step 3: one thread per path walks from the leaf up to the root and fills
  // the path's segment from its last slot towards its first.
  auto d_paths = paths->data().get();
  auto d_stats = device_model.stats.ConstDeviceSpan();
  auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
  auto d_path_segments = path_segments.data().get();
  dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
    auto path_info = d_info[idx];
    size_t tree_offset = d_tree_segments[path_info.tree_idx];
    int group = d_tree_group[path_info.tree_idx];
    size_t child_idx = path_info.leaf_position;
    auto child = d_nodes[child_idx];
    float v = child.LeafValue();
    const float inf = std::numeric_limits<float>::infinity();
    // Last slot of this path's segment.
    size_t output_position = d_path_segments[idx + 1] - 1;
    while (!child.IsRoot()) {
      size_t parent_idx = tree_offset + child.Parent();
      // Ratio of the child's cover (sum_hess) to its parent's cover.
      double child_cover = d_stats[child_idx].sum_hess;
      double parent_cover = d_stats[parent_idx].sum_hess;
      double zero_fraction = child_cover / parent_cover;
      auto parent = d_nodes[parent_idx];
      bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
      // True when the parent's default direction points at this child.
      bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
                             (parent.DefaultLeft() && is_left_path);
      // The left branch is bounded above by the split condition, the right
      // branch is bounded below by it.
      float lower_bound = is_left_path ? -inf : parent.SplitCond();
      float upper_bound = is_left_path ? parent.SplitCond() : inf;
      d_paths[output_position--] = {
          idx,         parent.SplitIndex(), group,         lower_bound,
          upper_bound, is_missing_path,     zero_fraction, v};
      child_idx = parent_idx;
      child = parent;
    }
    // Root node has feature -1
    d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
  });
}


class GPUPredictor : public xgboost::Predictor {
private:
void PredictInternal(const SparsePage& batch, size_t num_features,
Expand Down Expand Up @@ -495,17 +590,19 @@ class GPUPredictor : public xgboost::Predictor {
margin.empty() ? base_score : margin[idx];
});

const auto& paths = this->ExtractPaths(model, real_ntree_limit);
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShap(
X, paths, ngroup,
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data().get() + batch.base_rowid * contributions_columns);
}
dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
sizeof(float) * phis.size(),
cudaMemcpyDefault));
}
Expand Down Expand Up @@ -563,49 +660,6 @@ class GPUPredictor : public xgboost::Predictor {
}
}

/*!
 * \brief Build on the host the GPUTreeShap path elements for the first
 *        tree_limit trees of the model.
 *
 * For every non-deleted leaf, one element is appended per split on the walk
 * from the leaf up to the root, followed by a final element for the root
 * itself (feature index -1). All elements of one walk share the same path
 * index.
 *
 * \param model      Source tree model.
 * \param tree_limit Number of leading trees to convert; checked to be no
 *                   larger than the number of trees in the model.
 * \return The path elements of all converted trees, in extraction order.
 */
std::vector<gpu_treeshap::PathElement> ExtractPaths(
    const gbm::GBTreeModel& model, size_t tree_limit) {
  std::vector<gpu_treeshap::PathElement> paths;
  size_t path_idx = 0;
  CHECK_LE(tree_limit, model.trees.size());
  const float inf = std::numeric_limits<float>::infinity();

  for (auto tree_no = 0ull; tree_no < tree_limit; tree_no++) {
    const auto& tree = *model.trees.at(tree_no);
    size_t group = model.tree_info[tree_no];
    const auto& nodes = tree.GetNodes();

    for (auto leaf_no = 0ull; leaf_no < nodes.size(); leaf_no++) {
      // Only live leaves start a path.
      if (!nodes[leaf_no].IsLeaf() || nodes[leaf_no].IsDeleted()) {
        continue;
      }

      auto child = nodes[leaf_no];
      float v = child.LeafValue();
      size_t child_idx = leaf_no;

      // Climb towards the root, emitting one element per split passed.
      while (!child.IsRoot()) {
        const auto parent_idx = child.Parent();
        float child_cover = tree.Stat(child_idx).sum_hess;
        float parent_cover = tree.Stat(parent_idx).sum_hess;
        float zero_fraction = child_cover / parent_cover;
        CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
        auto parent = nodes[parent_idx];
        CHECK(parent.LeftChild() == child_idx ||
              parent.RightChild() == child_idx);

        bool is_left_path = parent.LeftChild() == child_idx;
        // Missing values follow this branch when the parent's default
        // direction matches the side this child is on.
        bool default_left = parent.DefaultLeft();
        bool is_missing_path = default_left == is_left_path;

        // Left branch: (-inf, split]; right branch: [split, inf).
        float lower_bound = -inf;
        float upper_bound = inf;
        if (is_left_path) {
          upper_bound = parent.SplitCond();
        } else {
          lower_bound = parent.SplitCond();
        }

        paths.emplace_back(path_idx, parent.SplitIndex(), group,
                           lower_bound, upper_bound, is_missing_path,
                           zero_fraction, v);
        child_idx = parent_idx;
        child = parent;
      }

      // Root node has feature -1
      paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
      path_idx++;
    }
  }
  return paths;
}

std::mutex lock_;
DeviceModel model_;
size_t max_shared_memory_bytes_;
Expand Down

0 comments on commit 2e907ab

Please sign in to comment.