Implement categorical data support for SHAP.
* Add CPU support.
* Update GPUTreeSHAP and define a customized split condition (see the interface sketch below).
trivialfis committed Jun 25, 2021
1 parent da1ad79 commit 40c5747
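
For context: gpu_treeshap::PathElement is now templated on a split-condition type, and each condition supplies EvaluateSplit and Merge. A minimal numeric-only sketch of that contract — inferred from the ShapSplitCondition added in this commit, not from GPUTreeSHAP's documented API:

#include <cmath>  // isnan, fmaxf, fminf

// Sketch of the split-condition contract implied by this commit: member and
// method names mirror the ShapSplitCondition added below, but the exact
// requirements of gpu_treeshap::PathElement<T> are an assumption inferred
// from the diff.
struct NumericSplitCondition {
  float lower_bound;       // feature values >= lower_bound ...
  float upper_bound;       // ... and < upper_bound flow down this path
  bool is_missing_branch;  // do missing values flow down this path?

  // Does feature value x flow down this path?
  __host__ __device__ bool EvaluateSplit(float x) const {
    if (isnan(x)) {
      return is_missing_branch;
    }
    return x >= lower_bound && x < upper_bound;
  }

  // Combine with another condition on the same feature along one path.
  __host__ __device__ void Merge(NumericSplitCondition other) {
    lower_bound = fmaxf(lower_bound, other.lower_bound);
    upper_bound = fminf(upper_bound, other.upper_bound);
    is_missing_branch = is_missing_branch && other.is_missing_branch;
  }
};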
Showing 11 changed files with 275 additions and 40 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
@@ -7,4 +7,4 @@
url = https://github.com/NVlabs/cub
[submodule "gputreeshap"]
path = gputreeshap
-url = https://github.com/rapidsai/gputreeshap.git
+url = https://github.com/trivialfis/gputreeshap.git
2 changes: 1 addition & 1 deletion gputreeshap (submodule pointer bumped to the updated GPUTreeSHAP)
2 changes: 1 addition & 1 deletion include/xgboost/tree_model.h
@@ -567,7 +567,7 @@ class RegTree : public Model {
* \param condition_feature the index of the feature to fix
* \param condition_fraction what fraction of the current weight matches our conditioning feature
*/
-  void TreeShap(const RegTree::FVec& feat, bst_float* phi, unsigned node_index,
+  void TreeShap(const RegTree::FVec& feat, bst_float* phi, bst_node_t node_index,
unsigned unique_depth, PathElement* parent_unique_path,
bst_float parent_zero_fraction, bst_float parent_one_fraction,
int parent_feature_index, int condition,
6 changes: 4 additions & 2 deletions src/common/bitfield.h
@@ -87,9 +87,11 @@ struct BitFieldContainer {
BitFieldContainer() = default;
XGBOOST_DEVICE explicit BitFieldContainer(common::Span<value_type> bits) : bits_{bits} {}
XGBOOST_DEVICE BitFieldContainer(BitFieldContainer const& other) : bits_{other.bits_} {}
+  BitFieldContainer &operator=(BitFieldContainer const &that) = default;
+  BitFieldContainer &operator=(BitFieldContainer &&that) = default;

-  common::Span<value_type> Bits() { return bits_; }
-  common::Span<value_type const> Bits() const { return bits_; }
+  XGBOOST_DEVICE common::Span<value_type> Bits() { return bits_; }
+  XGBOOST_DEVICE common::Span<value_type const> Bits() const { return bits_; }

/*\brief Compute the size of needed memory allocation. The returned value is in terms
* of number of elements with `BitFieldContainer::value_type'.
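
These additions are load-bearing for the SHAP changes below: ShapSplitCondition::Intersect in gpu_predictor.cu swaps two CatBitField values with thrust::swap, which requires the defaulted assignment operators, and it reads and writes Bits() from device code, which requires the XGBOOST_DEVICE annotation. A rough standalone sketch of that usage pattern, assuming common::CatBitField is a BitFieldContainer over 32-bit words:

#include <thrust/swap.h>

// Sketch only: intersect two bitfields word by word inside a device function.
__device__ void IntersectBits(common::CatBitField l, common::CatBitField r) {
  if (l.Size() > r.Size()) {
    thrust::swap(l, r);  // needs BitFieldContainer's assignment operators
  }
  // Iterating Bits() here only compiles because Bits() is XGBOOST_DEVICE.
  for (size_t i = 0; i < l.Bits().size(); ++i) {
    l.Bits()[i] &= r.Bits()[i];
  }
}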
177 changes: 154 additions & 23 deletions src/predictor/gpu_predictor.cu
@@ -415,18 +415,82 @@ class DeviceModel {
}
};

+struct ShapSplitCondition {
+  ShapSplitCondition() = default;
+  __host__ __device__ ShapSplitCondition(float feature_lower_bound,
+                                         float feature_upper_bound,
+                                         bool is_missing_branch,
+                                         common::CatBitField cats)
+      : feature_lower_bound(feature_lower_bound),
+        feature_upper_bound(feature_upper_bound),
+        is_missing_branch(is_missing_branch), categories{std::move(cats)} {
+    assert(feature_lower_bound <= feature_upper_bound);
+  }
+
+  /*! Feature values >= lower and < upper flow down this path. */
+  float feature_lower_bound;
+  float feature_upper_bound;
+  /*! Categorical feature values whose bit is set flow down this path. */
+  common::CatBitField categories;
+  /*! Do missing values flow down this path? */
+  bool is_missing_branch;
+
+  // Does this instance flow down this path?
+  __host__ __device__ bool EvaluateSplit(float x) const {
+    if (isnan(x)) {
+      return is_missing_branch;
+    }
+    if (categories.Size() != 0) {
+      auto cat = static_cast<uint32_t>(x);
+      return categories.Check(cat);
+    } else {
+      return x >= feature_lower_bound && x < feature_upper_bound;
+    }
+  }
+
+  // The &= operator on BitField is applied per CUDA thread; this helper
+  // instead intersects the entire bitfield in a loop.
+  XGBOOST_DEVICE static common::CatBitField Intersect(common::CatBitField l,
+                                                      common::CatBitField r) {
+    if (l.Data() == r.Data()) {
+      return l;
+    }
+    if (l.Size() > r.Size()) {
+      thrust::swap(l, r);
+    }
+    // l is now the shorter bitfield; intersecting over its words avoids
+    // writing past its end when the sizes differ.
+    for (size_t i = 0; i < l.Bits().size(); ++i) {
+      l.Bits()[i] &= r.Bits()[i];
+    }
+    return l;
+  }
+
+  // Combine two split conditions on the same feature.
+  XGBOOST_DEVICE void Merge(ShapSplitCondition other) {
+    // Combine duplicate features.
+    if (categories.Size() != 0 || other.categories.Size() != 0) {
+      categories = Intersect(categories, other.categories);
+    } else {
+      feature_lower_bound = max(feature_lower_bound, other.feature_lower_bound);
+      feature_upper_bound = min(feature_upper_bound, other.feature_upper_bound);
+    }
+    is_missing_branch = is_missing_branch && other.is_missing_branch;
+  }
+};

struct PathInfo {
int64_t leaf_position; // -1 not a leaf
size_t length;
size_t tree_idx;
};

// Transform model into path element form for GPUTreeShap
-void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
-                  const gbm::GBTreeModel& model, size_t tree_limit,
-                  int gpu_id) {
-  DeviceModel device_model;
-  device_model.Init(model, 0, tree_limit, gpu_id);
+void ExtractPaths(
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>> *paths,
+    DeviceModel *model, dh::device_vector<uint32_t> *path_categories,
+    int gpu_id) {
+  auto& device_model = *model;

dh::caching_device_vector<PathInfo> info(device_model.nodes.Size());
dh::XGBCachingDeviceAllocator<PathInfo> alloc;
auto d_nodes = device_model.nodes.ConstDeviceSpan();
@@ -462,14 +526,47 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,

paths->resize(path_segments.back());

-  auto d_paths = paths->data().get();
+  auto d_paths = dh::ToSpan(*paths);
auto d_info = info.data().get();
auto d_stats = device_model.stats.ConstDeviceSpan();
auto d_tree_group = device_model.tree_group.ConstDeviceSpan();
auto d_path_segments = path_segments.data().get();

+  auto d_split_types = device_model.split_types.ConstDeviceSpan();
+  auto d_cat_segments = device_model.categories_tree_segments.ConstDeviceSpan();
+  auto d_cat_node_segments = device_model.categories_node_segments.ConstDeviceSpan();
+
+  size_t max_cat = 0;
+  if (thrust::any_of(dh::tbegin(d_split_types), dh::tend(d_split_types),
+                     [] __device__(FeatureType ft) {
+                       return ft == FeatureType::kCategorical;
+                     })) {
+    dh::PinnedMemory pinned;
+    auto h_max_cat = pinned.GetSpan<RegTree::Segment>(1);
+    auto max_elem_it = dh::MakeTransformIterator<size_t>(
+        dh::tbegin(d_cat_node_segments),
+        [] __device__(RegTree::Segment seg) { return seg.size; });
+    size_t max_cat_it =
+        thrust::max_element(thrust::device, max_elem_it,
+                            max_elem_it + d_cat_node_segments.size()) -
+        max_elem_it;
+    dh::safe_cuda(cudaMemcpy(h_max_cat.data(),
+                             d_cat_node_segments.data() + max_cat_it,
+                             h_max_cat.size_bytes(), cudaMemcpyDeviceToHost));
+    max_cat = h_max_cat[0].size;
+    CHECK_GE(max_cat, 1);
+    path_categories->resize(max_cat * paths->size());
+  }
+
+  auto d_model_categories = device_model.categories.DeviceSpan();
+  common::Span<uint32_t> d_path_categories = dh::ToSpan(*path_categories);

dh::LaunchN(gpu_id, info.size(), [=] __device__(size_t idx) {
auto path_info = d_info[idx];
+    size_t tree_offset = d_tree_segments[path_info.tree_idx];
+    TreeView tree{0, path_info.tree_idx, d_nodes,
+                  d_tree_segments, d_split_types, d_cat_segments,
+                  d_cat_node_segments, d_model_categories};
int group = d_tree_group[path_info.tree_idx];
size_t child_idx = path_info.leaf_position;
auto child = d_nodes[child_idx];
@@ -481,20 +578,44 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,
double child_cover = d_stats[child_idx].sum_hess;
double parent_cover = d_stats[parent_idx].sum_hess;
double zero_fraction = child_cover / parent_cover;
-    auto parent = d_nodes[parent_idx];
+    auto parent = tree.d_tree[child.Parent()];

bool is_left_path = (tree_offset + parent.LeftChild()) == child_idx;
bool is_missing_path = (!parent.DefaultLeft() && !is_left_path) ||
(parent.DefaultLeft() && is_left_path);
-    float lower_bound = is_left_path ? -inf : parent.SplitCond();
-    float upper_bound = is_left_path ? parent.SplitCond() : inf;
-    d_paths[output_position--] = {
-        idx, parent.SplitIndex(), group, lower_bound,
-        upper_bound, is_missing_path, zero_fraction, v};

+    float lower_bound = -inf;
+    float upper_bound = inf;
+    common::CatBitField bits;
+    if (common::IsCat(tree.cats.split_type, child.Parent())) {
+      auto path_cats = d_path_categories.subspan(max_cat * output_position, max_cat);
+      size_t size = tree.cats.node_ptr[child.Parent()].size;
+      auto node_cats = tree.cats.categories.subspan(tree.cats.node_ptr[child.Parent()].beg, size);
+      SPAN_CHECK(path_cats.size() >= node_cats.size());
+      if (is_left_path) {
+        // The stored node categories route to the right child, so the left
+        // path keeps their complement.
+        for (size_t i = 0; i < node_cats.size(); ++i) {
+          path_cats[i] = ~node_cats[i];
+        }
+      } else {
+        for (size_t i = 0; i < node_cats.size(); ++i) {
+          path_cats[i] = node_cats[i];
+        }
+      }
+      bits = common::CatBitField{path_cats};
+    } else {
+      lower_bound = is_left_path ? -inf : parent.SplitCond();
+      upper_bound = is_left_path ? parent.SplitCond() : inf;
+    }
+    d_paths[output_position--] = gpu_treeshap::PathElement<ShapSplitCondition>{
+        idx, parent.SplitIndex(), group,
+        ShapSplitCondition{lower_bound, upper_bound, is_missing_path, bits},
+        zero_fraction, v};
child_idx = parent_idx;
child = parent;
}
// Root node has feature -1
-    d_paths[output_position] = {idx, -1, group, -inf, inf, false, 1.0, v};
+    d_paths[output_position] = {idx, -1, group, ShapSplitCondition{-inf, inf, false, {}}, 1.0, v};
});
}

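A small worked example of the categorical logic above: suppose a path crosses two splits on the same categorical feature, one admitting categories {1, 3} and one admitting {0, 1}. Merge intersects the two bitfields, so only category 1 still flows down the combined path. BitsFor below is a hypothetical helper that builds a CatBitField from a category list; it is not part of this commit:

// Hypothetical illustration of ShapSplitCondition::Merge on a categorical
// feature (BitsFor is an assumed helper, not part of this commit).
ShapSplitCondition a{-inf, inf, /*is_missing_branch=*/false, BitsFor({1, 3})};
ShapSplitCondition b{-inf, inf, /*is_missing_branch=*/false, BitsFor({0, 1})};
a.Merge(b);                      // categories intersect: {1, 3} & {0, 1} == {1}
assert(a.EvaluateSplit(1.0f));   // category 1 still follows this path
assert(!a.EvaluateSplit(3.0f));  // category 3 no longer does
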
@@ -718,16 +839,21 @@ class GPUPredictor : public xgboost::Predictor {
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();

-    dh::device_vector<gpu_treeshap::PathElement> device_paths;
-    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+        device_paths;
+    DeviceModel d_model;
+    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    dh::device_vector<uint32_t> categories;
+    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
-      gpu_treeshap::GPUTreeShap(
-          X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data() + batch.base_rowid * contributions_columns, phis.size());
+      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
+      gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
+          X, device_paths.begin(), device_paths.end(), ngroup, begin,
+          dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
@@ -769,16 +895,21 @@ class GPUPredictor : public xgboost::Predictor {
out_contribs->Fill(0.0f);
auto phis = out_contribs->DeviceSpan();

-    dh::device_vector<gpu_treeshap::PathElement> device_paths;
-    ExtractPaths(&device_paths, model, tree_end, generic_param_->gpu_id);
+    dh::device_vector<gpu_treeshap::PathElement<ShapSplitCondition>>
+        device_paths;
+    DeviceModel d_model;
+    d_model.Init(model, 0, tree_end, generic_param_->gpu_id);
+    dh::device_vector<uint32_t> categories;
+    ExtractPaths(&device_paths, &d_model, &categories, generic_param_->gpu_id);
for (auto& batch : p_fmat->GetBatches<SparsePage>()) {
batch.data.SetDevice(generic_param_->gpu_id);
batch.offset.SetDevice(generic_param_->gpu_id);
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
-      gpu_treeshap::GPUTreeShapInteractions(
-          X, device_paths.begin(), device_paths.end(), ngroup,
-          phis.data() + batch.base_rowid * contributions_columns, phis.size());
+      auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
+      gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
+          X, device_paths.begin(), device_paths.end(), ngroup, begin,
+          dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
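For reference, both call sites share one output layout: per output group, phis is row-major with one row per instance and (presumably) contributions_columns = num_feature + 1 entries per row, the trailing column holding the base margin term added after the SHAP kernels run. A sketch of the indexing implied by the base_rowid * contributions_columns offset — the exact layout is an inference from this file:

// Assumed SHAP output layout for a single group: row-major, one row per
// instance, num_feature feature columns plus one trailing bias column.
size_t PhiIndex(size_t row, size_t feature, size_t num_feature) {
  size_t const contributions_columns = num_feature + 1;
  // Passing feature == num_feature addresses the base-margin (bias) column.
  return row * contributions_columns + feature;
}
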
19 changes: 8 additions & 11 deletions src/tree/tree_model.cc
@@ -1245,7 +1245,7 @@ bst_float UnwoundPathSum(const PathElement *unique_path, unsigned unique_depth,

// recursive computation of SHAP values for a decision tree
void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
-                       unsigned node_index, unsigned unique_depth,
+                       bst_node_t node_index, unsigned unique_depth,
PathElement *parent_unique_path,
bst_float parent_zero_fraction,
bst_float parent_one_fraction, int parent_feature_index,
@@ -1278,16 +1278,13 @@ void RegTree::TreeShap(const RegTree::FVec &feat, bst_float *phi,
// internal node
} else {
// find which branch is "hot" (meaning x would follow it)
-    unsigned hot_index = 0;
-    if (feat.IsMissing(split_index)) {
-      hot_index = node.DefaultChild();
-    } else if (feat.GetFvalue(split_index) < node.SplitCond()) {
-      hot_index = node.LeftChild();
-    } else {
-      hot_index = node.RightChild();
-    }
-    const unsigned cold_index = (static_cast<int>(hot_index) == node.LeftChild() ?
-                                 node.RightChild() : node.LeftChild());
+    auto const &cats = this->GetCategoriesMatrix();
+    bst_node_t hot_index = predictor::GetNextNode<true, true>(
+        node, node_index, feat.GetFvalue(split_index),
+        feat.IsMissing(split_index), cats);
+
+    const auto cold_index =
+        (hot_index == node.LeftChild() ? node.RightChild() : node.LeftChild());
const bst_float w = this->Stat(node_index).sum_hess;
const bst_float hot_zero_fraction = this->Stat(hot_index).sum_hess / w;
const bst_float cold_zero_fraction = this->Stat(cold_index).sum_hess / w;
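The CPU change mirrors the GPU logic: branch selection is delegated to predictor::GetNextNode<true, true>, which also handles categorical splits instead of assuming a numerical threshold. A sketch of the dispatch it performs, inferred from this diff — note the GPU path above stores the complement of the node's category set on left branches, so the stored set routes matching categories to the right child; the bit indexing below is simplified for illustration:

// Sketch of GetNextNode-style dispatch, inferred from this commit; not the
// helper's exact implementation. `cats` is what GetCategoriesMatrix() returns.
bst_node_t NextNode(RegTree::Node const &node, bst_node_t nidx, float fvalue,
                    bool is_missing, RegTree::CategoricalSplitMatrix const &cats) {
  if (is_missing) {
    return node.DefaultChild();
  }
  if (common::IsCat(cats.split_type, nidx)) {
    // Categorical split: the stored category set routes to the right child.
    auto seg = cats.node_ptr[nidx];
    auto node_cats = cats.categories.subspan(seg.beg, seg.size);
    auto cat = static_cast<uint32_t>(fvalue);
    bool go_right = (node_cats[cat / 32] >> (cat % 32)) & 1u;
    return go_right ? node.RightChild() : node.LeftChild();
  }
  // Numerical split.
  return fvalue < node.SplitCond() ? node.LeftChild() : node.RightChild();
}
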
5 changes: 5 additions & 0 deletions tests/cpp/predictor/test_cpu_predictor.cc
@@ -86,6 +86,11 @@ TEST(CpuPredictor, Basic) {
}
}


+TEST(CpuPredictor, IterationRange) {
+  TestIterationRange("cpu_predictor");
+}

TEST(CpuPredictor, ExternalMemory) {
dmlc::TemporaryDirectory tmpdir;
std::string filename = tmpdir.path + "/big.libsvm";
5 changes: 5 additions & 0 deletions tests/cpp/predictor/test_gpu_predictor.cu
@@ -224,6 +224,11 @@ TEST(GPUPredictor, Shap) {
}
}

+TEST(GPUPredictor, IterationRange) {
+  TestIterationRange("gpu_predictor");
+}

TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}