Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to GPUTreeShap #6087

Merged
merged 2 commits into from Sep 6, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion gputreeshap
3 changes: 3 additions & 0 deletions include/xgboost/tree_model.h
Expand Up @@ -337,6 +337,9 @@ class RegTree : public Model {
/*! \brief get const reference to nodes */
const std::vector<Node>& GetNodes() const { return nodes_; }

/*! \brief get const reference to stats */
const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }

/*! \brief get node statistics given nid */
RTreeNodeStat& Stat(int nid) {
return stats_[nid];
Expand Down
1 change: 1 addition & 0 deletions src/common/host_device_vector.cu
Expand Up @@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Node>;
template class HostDeviceVector<RTreeNodeStat>;

#if defined(__APPLE__)
/*
Expand Down
156 changes: 105 additions & 51 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -223,6 +223,7 @@ class DeviceModel {
public:
// Need to lazily construct the vectors because GPU id is only known at runtime
HostDeviceVector<RegTree::Node> nodes;
HostDeviceVector<RTreeNodeStat> stats;
HostDeviceVector<size_t> tree_segments;
HostDeviceVector<int> tree_group;
size_t tree_beg_; // NOLINT
Expand All @@ -246,22 +247,116 @@ class DeviceModel {

nodes = std::move(HostDeviceVector<RegTree::Node>(h_tree_segments.back(), RegTree::Node(),
gpu_id));
auto& h_nodes = nodes.HostVector();
stats = std::move(HostDeviceVector<RTreeNodeStat>(h_tree_segments.back(),
RTreeNodeStat(), gpu_id));
auto d_nodes = nodes.DevicePointer();
auto d_stats = stats.DevicePointer();
for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
std::copy(src_nodes.begin(), src_nodes.end(),
h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
auto& src_stats = model.trees.at(tree_idx)->GetStats();
dh::safe_cuda(cudaMemcpyAsync(
d_nodes + h_tree_segments[tree_idx - tree_begin], src_nodes.data(),
sizeof(RegTree::Node) * src_nodes.size(), cudaMemcpyDefault));
dh::safe_cuda(cudaMemcpyAsync(
d_stats + h_tree_segments[tree_idx - tree_begin], src_stats.data(),
sizeof(RTreeNodeStat) * src_stats.size(), cudaMemcpyDefault));
}

tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
auto& h_tree_group = tree_group.HostVector();
std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
auto d_tree_group = tree_group.DevicePointer();
dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(),
sizeof(int) * model.tree_info.size(),
cudaMemcpyDefault));
this->tree_beg_ = tree_begin;
this->tree_end_ = tree_end;
this->num_group = model.learner_model_param->num_output_group;
}
};

struct PathInfo {
int64_t leaf_position; // -1 not a leaf
size_t length;
size_t tree_idx;
};

// Transform model into path element form for GPUTreeShap
void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
const gbm::GBTreeModel& model, size_t tree_limit,
int gpu_id) {
DeviceModel device_model;
device_model.Init(model, 0, tree_limit, gpu_id);
dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
dh::XGBCachingDeviceAllocator<PathInfo> alloc;
auto d_nodes = device_model.nodes.ConstDeviceSpan();
auto d_tree_segments = device_model.tree_segments.ConstDeviceSpan();
auto nodes_transform = dh::MakeTransformIterator<PathInfo>(
thrust::make_counting_iterator(0ull), [=] __device__(size_t idx) {
auto n = d_nodes[idx];
if (!n.IsLeaf() || n.IsDeleted()) {
return PathInfo{-1, 0, 0};
}
size_t tree_idx =
dh::SegmentId(d_tree_segments.begin(), d_tree_segments.end(), idx);
size_t tree_offset = d_tree_segments[tree_idx];
size_t path_length = 1;
while (!n.IsRoot()) {
n = d_nodes[n.Parent() + tree_offset];
path_length++;
}
return PathInfo{int64_t(idx), path_length, tree_idx};
});
auto end = thrust::copy_if(
thrust::cuda::par(alloc), nodes_transform,
nodes_transform + d_nodes.size(), info.begin(),
[=] __device__(const PathInfo& e) { return e.leaf_position != -1; });
info.resize(end - info.begin());
auto length_iterator = dh::MakeTransformIterator<size_t>(
info.begin(),
[=] __device__(const PathInfo& info) { return info.length; });
dh::caching_device_vector<size_t> path_segments(info.size() + 1);
thrust::exclusive_scan(thrust::cuda::par(alloc), length_iterator,
length_iterator + info.size() + 1,
path_segments.begin());

paths->resize(path_segments.back());

auto d_paths = paths->data().get();
auto d_info = info.data().get();
auto d_stats = device_model.stats.ConstDeviceSpan();
auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
auto d_path_segments = path_segments.data().get();
dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
auto path_info = d_info[idx];
size_t tree_offset = d_tree_segments[path_info.tree_idx];
int group = d_tree_group[path_info.tree_idx];
size_t child_idx = path_info.leaf_position;
auto child = d_nodes[child_idx];
float v = child.LeafValue();
const float inf = std::numeric_limits<float>::infinity();
size_t output_position = d_path_segments[idx + 1] - 1;
while (!child.IsRoot()) {
size_t parent_idx = tree_offset + child.Parent();
double child_cover = d_stats[child_idx].sum_hess;
double parent_cover = d_stats[parent_idx].sum_hess;
double zero_fraction = child_cover / parent_cover;
auto parent = d_nodes[parent_idx];
bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
(parent.DefaultLeft() && is_left_path);
float lower_bound = is_left_path ? -inf : parent.SplitCond();
float upper_bound = is_left_path ? parent.SplitCond() : inf;
d_paths[output_position--] = {
idx, parent.SplitIndex(), group, lower_bound,
upper_bound, is_missing_path, zero_fraction, v};
child_idx = parent_idx;
child = parent;
}
// Root node has feature -1
d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
});
}


class GPUPredictor : public xgboost::Predictor {
private:
void PredictInternal(const SparsePage& batch, size_t num_features,
Expand Down Expand Up @@ -495,17 +590,19 @@ class GPUPredictor : public xgboost::Predictor {
margin.empty() ? base_score : margin[idx];
});

const auto& paths = this->ExtractPaths(model, real_ntree_limit);
dh::device_vector<gpu_treeshap::PathElement> device_paths;
ExtractPaths(&device_paths, model, real_ntree_limit,
generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
gpu_treeshap::GPUTreeShap(
X, paths, ngroup,
X, device_paths.begin(), device_paths.end(), ngroup,
phis.data().get() + batch.base_rowid * contributions_columns);
}
dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
sizeof(float) * phis.size(),
cudaMemcpyDefault));
}
Expand Down Expand Up @@ -563,49 +660,6 @@ class GPUPredictor : public xgboost::Predictor {
}
}

std::vector<gpu_treeshap::PathElement> ExtractPaths(
const gbm::GBTreeModel& model, size_t tree_limit) {
std::vector<gpu_treeshap::PathElement> paths;
size_t path_idx = 0;
CHECK_LE(tree_limit, model.trees.size());
for (auto i = 0ull; i < tree_limit; i++) {
const auto& tree = *model.trees.at(i);
size_t group = model.tree_info[i];
const auto& nodes = tree.GetNodes();
for (auto j = 0ull; j < nodes.size(); j++) {
if (nodes[j].IsLeaf() && !nodes[j].IsDeleted()) {
auto child = nodes[j];
float v = child.LeafValue();
size_t child_idx = j;
const float inf = std::numeric_limits<float>::infinity();
while (!child.IsRoot()) {
float child_cover = tree.Stat(child_idx).sum_hess;
float parent_cover = tree.Stat(child.Parent()).sum_hess;
float zero_fraction = child_cover / parent_cover;
CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
auto parent = nodes[child.Parent()];
CHECK(parent.LeftChild() == child_idx ||
parent.RightChild() == child_idx);
bool is_left_path = parent.LeftChild() == child_idx;
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
(parent.DefaultLeft() && is_left_path);
float lower_bound = is_left_path ? -inf : parent.SplitCond();
float upper_bound = is_left_path ? parent.SplitCond() : inf;
paths.emplace_back(path_idx, parent.SplitIndex(), group,
lower_bound, upper_bound, is_missing_path,
zero_fraction, v);
child_idx = child.Parent();
child = parent;
}
// Root node has feature -1
paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
path_idx++;
}
}
}
return paths;
}

std::mutex lock_;
DeviceModel model_;
size_t max_shared_memory_bytes_;
Expand Down