diff --git a/gputreeshap b/gputreeshap
index 04410099299e..1de23c95ff07 160000
--- a/gputreeshap
+++ b/gputreeshap
@@ -1 +1 @@
-Subproject commit 04410099299ec918c75d00e385da35cf2e5aa87c
+Subproject commit 1de23c95ff07d086db02837fb4a746b6924abbd5
diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h
index 60aa9ce16e8c..faec6b118a3a 100644
--- a/include/xgboost/tree_model.h
+++ b/include/xgboost/tree_model.h
@@ -337,6 +337,9 @@ class RegTree : public Model {
   /*! \brief get const reference to nodes */
   const std::vector<Node>& GetNodes() const { return nodes_; }
 
+  /*! \brief get const reference to stats */
+  const std::vector<RTreeNodeStat>& GetStats() const { return stats_; }
+
   /*! \brief get node statistics given nid */
   RTreeNodeStat& Stat(int nid) {
     return stats_[nid];
diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu
index 39a0fbe9efb0..2de1bb652900 100644
--- a/src/common/host_device_vector.cu
+++ b/src/common/host_device_vector.cu
@@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
 template class HostDeviceVector<uint64_t>;  // bst_row_t
 template class HostDeviceVector<uint32_t>;  // bst_feature_t
 template class HostDeviceVector<RegTree::Node>;
+template class HostDeviceVector<RTreeNodeStat>;
 
 #if defined(__APPLE__)
 /*
diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu
index 9380df399ba7..65d90982b2e5 100644
--- a/src/predictor/gpu_predictor.cu
+++ b/src/predictor/gpu_predictor.cu
@@ -223,6 +223,7 @@ class DeviceModel {
  public:
   // Need to lazily construct the vectors because GPU id is only known at runtime
   HostDeviceVector<RegTree::Node> nodes;
+  HostDeviceVector<RTreeNodeStat> stats;
   HostDeviceVector<size_t> tree_segments;
   HostDeviceVector<int> tree_group;
   size_t tree_beg_;  // NOLINT
@@ -246,22 +247,116 @@ class DeviceModel {
     nodes = std::move(HostDeviceVector<RegTree::Node>(h_tree_segments.back(),
                                                       RegTree::Node(), gpu_id));
-    auto& h_nodes = nodes.HostVector();
+    stats = std::move(HostDeviceVector<RTreeNodeStat>(h_tree_segments.back(),
+                                                      RTreeNodeStat(), gpu_id));
+    auto d_nodes = nodes.DevicePointer();
+    auto d_stats = stats.DevicePointer();
     for (auto tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
       auto& src_nodes = model.trees.at(tree_idx)->GetNodes();
-      std::copy(src_nodes.begin(), src_nodes.end(),
-                h_nodes.begin() + h_tree_segments[tree_idx - tree_begin]);
+      auto& src_stats = model.trees.at(tree_idx)->GetStats();
+      dh::safe_cuda(cudaMemcpyAsync(
+          d_nodes + h_tree_segments[tree_idx - tree_begin], src_nodes.data(),
+          sizeof(RegTree::Node) * src_nodes.size(), cudaMemcpyDefault));
+      dh::safe_cuda(cudaMemcpyAsync(
+          d_stats + h_tree_segments[tree_idx - tree_begin], src_stats.data(),
+          sizeof(RTreeNodeStat) * src_stats.size(), cudaMemcpyDefault));
     }
 
     tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
-    auto& h_tree_group = tree_group.HostVector();
-    std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());
+    auto d_tree_group = tree_group.DevicePointer();
+    dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(),
+                                  sizeof(int) * model.tree_info.size(),
+                                  cudaMemcpyDefault));
     this->tree_beg_ = tree_begin;
     this->tree_end_ = tree_end;
     this->num_group = model.learner_model_param->num_output_group;
   }
 };
 
+struct PathInfo {
+  int64_t leaf_position;  // -1 not a leaf
+  size_t length;
+  size_t tree_idx;
+};
+
+// Transform model into path element form for GPUTreeShap
+void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
+                  const gbm::GBTreeModel& model, size_t tree_limit,
+                  int gpu_id) {
+  DeviceModel device_model;
+  device_model.Init(model, 0, tree_limit, gpu_id);
+
+  dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
+  dh::XGBCachingDeviceAllocator<PathInfo> alloc;
+  auto d_nodes = device_model.nodes.ConstDeviceSpan();
+  auto d_tree_segments = device_model.tree_segments.ConstDeviceSpan();
+  auto nodes_transform = dh::MakeTransformIterator<PathInfo>(
+      thrust::make_counting_iterator(0ull), [=] __device__(size_t idx) {
+        auto n = d_nodes[idx];
+        if (!n.IsLeaf() || n.IsDeleted()) {
+          return PathInfo{-1, 0, 0};
+        }
+        size_t tree_idx =
+            dh::SegmentId(d_tree_segments.begin(), d_tree_segments.end(), idx);
+        size_t tree_offset = d_tree_segments[tree_idx];
+        size_t path_length = 1;
+        while (!n.IsRoot()) {
+          n = d_nodes[n.Parent() + tree_offset];
+          path_length++;
+        }
+        return PathInfo{int64_t(idx), path_length, tree_idx};
+      });
+  auto end = thrust::copy_if(
+      thrust::cuda::par(alloc), nodes_transform,
+      nodes_transform + d_nodes.size(), info.begin(),
+      [=] __device__(const PathInfo& e) { return e.leaf_position != -1; });
+  info.resize(end - info.begin());
+  auto length_iterator = dh::MakeTransformIterator<size_t>(
+      info.begin(),
+      [=] __device__(const PathInfo& info) { return info.length; });
+  dh::caching_device_vector<size_t> path_segments(info.size() + 1);
+  thrust::exclusive_scan(thrust::cuda::par(alloc), length_iterator,
+                         length_iterator + info.size() + 1,
+                         path_segments.begin());
+
+  paths->resize(path_segments.back());
+
+  auto d_paths = paths->data().get();
+  auto d_info = info.data().get();
+  auto d_stats = device_model.stats.ConstDeviceSpan();
+  auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
+  auto d_path_segments = path_segments.data().get();
+  dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
+    auto path_info = d_info[idx];
+    size_t tree_offset = d_tree_segments[path_info.tree_idx];
+    int group = d_tree_group[path_info.tree_idx];
+    size_t child_idx = path_info.leaf_position;
+    auto child = d_nodes[child_idx];
+    float v = child.LeafValue();
+    const float inf = std::numeric_limits<float>::infinity();
+    size_t output_position = d_path_segments[idx + 1] - 1;
+    while (!child.IsRoot()) {
+      size_t parent_idx = tree_offset + child.Parent();
+      double child_cover = d_stats[child_idx].sum_hess;
+      double parent_cover = d_stats[parent_idx].sum_hess;
+      double zero_fraction = child_cover / parent_cover;
+      auto parent = d_nodes[parent_idx];
+      bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
+      bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
+                             (parent.DefaultLeft() && is_left_path);
+      float lower_bound = is_left_path ? -inf : parent.SplitCond();
+      float upper_bound = is_left_path ? parent.SplitCond() : inf;
+      d_paths[output_position--] = {
+          idx,         parent.SplitIndex(), group,         lower_bound,
+          upper_bound, is_missing_path,     zero_fraction, v};
+      child_idx = parent_idx;
+      child = parent;
+    }
+    // Root node has feature -1
+    d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
+  });
+}
+
+
 class GPUPredictor : public xgboost::Predictor {
  private:
  void PredictInternal(const SparsePage& batch, size_t num_features,
@@ -495,17 +590,19 @@ class GPUPredictor : public xgboost::Predictor {
               margin.empty() ? base_score : margin[idx];
         });
-    const auto& paths = this->ExtractPaths(model, real_ntree_limit);
+    dh::device_vector<gpu_treeshap::PathElement> device_paths;
+    ExtractPaths(&device_paths, model, real_ntree_limit,
+                 generic_param_->gpu_id);
     for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
       batch.data.SetDevice(generic_param_->gpu_id);
       batch.offset.SetDevice(generic_param_->gpu_id);
       SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
                        model.learner_model_param->num_feature);
       gpu_treeshap::GPUTreeShap(
-          X, paths, ngroup,
+          X, device_paths.begin(), device_paths.end(), ngroup,
           phis.data().get() + batch.base_rowid * contributions_columns);
     }
-    dh::safe_cuda(cudaMemcpyAsync(contribs.data(), phis.data().get(),
+    dh::safe_cuda(cudaMemcpy(contribs.data(), phis.data().get(),
                               sizeof(float) * phis.size(),
                               cudaMemcpyDefault));
   }
 
@@ -563,49 +660,6 @@ class GPUPredictor : public xgboost::Predictor {
     }
   }
 
-  std::vector<gpu_treeshap::PathElement> ExtractPaths(
-      const gbm::GBTreeModel& model, size_t tree_limit) {
-    std::vector<gpu_treeshap::PathElement> paths;
-    size_t path_idx = 0;
-    CHECK_LE(tree_limit, model.trees.size());
-    for (auto i = 0ull; i < tree_limit; i++) {
-      const auto& tree = *model.trees.at(i);
-      size_t group = model.tree_info[i];
-      const auto& nodes = tree.GetNodes();
-      for (auto j = 0ull; j < nodes.size(); j++) {
-        if (nodes[j].IsLeaf() && !nodes[j].IsDeleted()) {
-          auto child = nodes[j];
-          float v = child.LeafValue();
-          size_t child_idx = j;
-          const float inf = std::numeric_limits<float>::infinity();
-          while (!child.IsRoot()) {
-            float child_cover = tree.Stat(child_idx).sum_hess;
-            float parent_cover = tree.Stat(child.Parent()).sum_hess;
-            float zero_fraction = child_cover / parent_cover;
-            CHECK(zero_fraction >= 0.0 && zero_fraction <= 1.0);
-            auto parent = nodes[child.Parent()];
-            CHECK(parent.LeftChild() == child_idx ||
-                  parent.RightChild() == child_idx);
-            bool is_left_path = parent.LeftChild() == child_idx;
-            bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
-                                   (parent.DefaultLeft() && is_left_path);
-            float lower_bound = is_left_path ? -inf : parent.SplitCond();
-            float upper_bound = is_left_path ? parent.SplitCond() : inf;
-            paths.emplace_back(path_idx, parent.SplitIndex(), group,
-                               lower_bound, upper_bound, is_missing_path,
-                               zero_fraction, v);
-            child_idx = child.Parent();
-            child = parent;
-          }
-          // Root node has feature -1
-          paths.emplace_back(path_idx, -1, group, -inf, inf, false, 1.0, v);
-          path_idx++;
-        }
-      }
-    }
-    return paths;
-  }
-
   std::mutex lock_;
   DeviceModel model_;
   size_t max_shared_memory_bytes_;
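
The hunks above move the path extraction onto the device, but the leaf-to-root walk itself is unchanged between the removed CPU ExtractPaths and the new GPU kernel: every leaf produces one path whose elements record the split feature, the feature-value interval that keeps a row on that branch, whether missing values take the branch, and the child/parent cover ratio (zero_fraction), capped by a feature -1 element for the root. Below is a minimal, CPU-only sketch of that walk for illustration; ToyNode and ToyPathElement are simplified stand-ins invented here (not the real RegTree::Node or gpu_treeshap::PathElement types), and the toy tree data is made up.

#include <cstdio>
#include <limits>
#include <vector>

// Hypothetical stand-in for a tree node, for illustration only.
struct ToyNode {
  int parent;         // -1 for the root
  int left, right;    // -1 for leaves
  int split_index;    // feature used at an internal node
  float split_cond;   // rows with x[split_index] < split_cond go left
  bool default_left;  // branch taken when the feature is missing
  float leaf_value;   // only meaningful for leaves
  double sum_hess;    // cover statistic (sum of hessians reaching the node)
  bool IsLeaf() const { return left == -1; }
};

// Hypothetical stand-in for one path element, for illustration only.
struct ToyPathElement {
  size_t path_idx;
  int feature_idx;  // -1 marks the root element
  float lower_bound, upper_bound;
  bool is_missing_path;
  double zero_fraction;
  float leaf_value;
};

int main() {
  const float inf = std::numeric_limits<float>::infinity();
  // Toy tree: node 0 splits on feature 0 at 0.5; nodes 1 and 2 are leaves.
  std::vector<ToyNode> tree = {
      {-1, 1, 2, 0, 0.5f, true, 0.0f, 10.0},    // root
      {0, -1, -1, -1, 0.0f, true, -1.0f, 6.0},  // left leaf
      {0, -1, -1, -1, 0.0f, true, 2.0f, 4.0},   // right leaf
  };

  std::vector<ToyPathElement> paths;
  size_t path_idx = 0;
  for (int nid = 0; nid < static_cast<int>(tree.size()); ++nid) {
    if (!tree[nid].IsLeaf()) continue;
    float v = tree[nid].leaf_value;
    int child = nid;
    while (tree[child].parent != -1) {
      int parent = tree[child].parent;
      bool is_left = tree[parent].left == child;
      // Fraction of training weight that keeps flowing down this branch.
      double zero_fraction = tree[child].sum_hess / tree[parent].sum_hess;
      // Missing values take this branch iff the default direction matches it.
      bool is_missing = tree[parent].default_left == is_left;
      float lo = is_left ? -inf : tree[parent].split_cond;
      float hi = is_left ? tree[parent].split_cond : inf;
      paths.push_back({path_idx, tree[parent].split_index, lo, hi, is_missing,
                       zero_fraction, v});
      child = parent;
    }
    // Root element: feature -1, unbounded interval, fraction 1.
    paths.push_back({path_idx, -1, -inf, inf, false, 1.0, v});
    ++path_idx;
  }

  for (const auto& e : paths) {
    std::printf("path=%zu feature=%d bounds=[%g, %g) missing=%d frac=%g v=%g\n",
                e.path_idx, e.feature_idx, e.lower_bound, e.upper_bound,
                static_cast<int>(e.is_missing_path), e.zero_fraction,
                e.leaf_value);
  }
  return 0;
}

One ordering difference from the diff: the removed CPU ExtractPaths (and this sketch) append the root element last within each path, while the GPU kernel fills each path segment back to front (output_position--), so the root element lands at the start of its segment.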