Add categorical data support to GPU predictor. #6165

Merged
merged 2 commits on Sep 29, 2020
Changes from all commits
2 changes: 2 additions & 0 deletions src/common/host_device_vector.cc
@@ -10,6 +10,7 @@
#include <cstdint>
#include <memory>
#include <utility>
#include "xgboost/tree_model.h"
#include "xgboost/host_device_vector.h"

namespace xgboost {
@@ -176,6 +177,7 @@ template class HostDeviceVector<FeatureType>;
template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Segment>;

#if defined(__APPLE__)
/*
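Why this one-line change is needed: `HostDeviceVector`'s member functions are defined in the translation units rather than the header, so every element type used elsewhere must be instantiated explicitly, in both the CPU (`.cc`) and CUDA (`.cu`) builds. A minimal sketch of the pattern, assuming the usual header/source split (the real header's layout may differ):

```cpp
// host_device_vector.h (sketch): the template is declared here, but its
// members are defined out-of-line in host_device_vector.cc / .cu.
template <typename T>
class HostDeviceVector {
 public:
  std::vector<T>& HostVector();  // definition lives in the .cc / .cu file
  // ...
};

// host_device_vector.cc / .cu (sketch): one explicit instantiation per
// element type. RegTree::Segment is added because the GPU predictor below
// stores per-node category segments in a HostDeviceVector.
template class HostDeviceVector<RegTree::Segment>;
```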
1 change: 1 addition & 0 deletions src/common/host_device_vector.cu
@@ -404,6 +404,7 @@ template class HostDeviceVector<Entry>;
template class HostDeviceVector<uint64_t>; // bst_row_t
template class HostDeviceVector<uint32_t>; // bst_feature_t
template class HostDeviceVector<RegTree::Node>;
template class HostDeviceVector<RegTree::Segment>;
template class HostDeviceVector<RTreeNodeStat>;

#if defined(__APPLE__)
179 changes: 136 additions & 43 deletions src/predictor/gpu_predictor.cu
@@ -18,6 +18,8 @@
#include "../data/ellpack_page.cuh"
#include "../data/device_adapter.cuh"
#include "../common/common.h"
#include "../common/bitfield.h"
#include "../common/categorical.h"
#include "../common/device_helpers.cuh"

namespace xgboost {
@@ -169,33 +171,49 @@ struct DeviceAdapterLoader {

template <typename Loader>
__device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree,
common::Span<FeatureType const> split_types,
common::Span<RegTree::Segment const> d_cat_ptrs,
common::Span<uint32_t const> d_categories,
Loader* loader) {
  bst_node_t nidx = 0;
  RegTree::Node n = tree[nidx];
  while (!n.IsLeaf()) {
    float fvalue = loader->GetElement(ridx, n.SplitIndex());
    // Missing value
    if (common::CheckNAN(fvalue)) {
      nidx = n.DefaultChild();
    } else {
      bool go_left = true;
      if (common::IsCat(split_types, nidx)) {
        auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg,
                                               d_cat_ptrs[nidx].size);
        go_left = Decision(categories, common::AsCat(fvalue));
      } else {
        go_left = fvalue < n.SplitCond();
      }
      if (go_left) {
        nidx = n.LeftChild();
      } else {
        nidx = n.RightChild();
      }
    }
    n = tree[nidx];
  }
  return tree[nidx].LeafValue();
}
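`IsCat`, `AsCat`, and `Decision` come from the newly included `../common/categorical.h`. A rough sketch of the check the categorical branch performs, assuming each node's matching categories form a bitset packed into 32-bit words, with a matching category sent to the right child (the direction the test at the bottom of this diff asserts):

```cpp
// Hypothetical stand-in for the Decision helper; the real implementation
// lives in src/common/categorical.h and src/common/bitfield.h.
__device__ bool DecisionSketch(common::Span<uint32_t const> cats, int32_t cat) {
  auto u = static_cast<uint32_t>(cat);
  // Bit u is set iff category u is in this node's matching set.
  bool matches = (u / 32 < cats.size()) && ((cats[u / 32] >> (u % 32)) & 1u);
  // Matching categories go right, so "go left" is the negation; categories
  // never seen at training time have no bit set and also go left.
  return !matches;
}
```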

template <typename Loader, typename Data>
__global__ void
PredictKernel(Data data, common::Span<const RegTree::Node> d_nodes,
              common::Span<float> d_out_predictions,
              common::Span<size_t const> d_tree_segments,
              common::Span<int const> d_tree_group,
              common::Span<FeatureType const> d_tree_split_types,
              common::Span<uint32_t const> d_cat_tree_segments,
              common::Span<RegTree::Segment const> d_cat_node_segments,
              common::Span<uint32_t const> d_categories, size_t tree_begin,
              size_t tree_end, size_t num_features, size_t num_rows,
              size_t entry_start, bool use_shared, int num_group) {
bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x;
Loader loader(data, use_shared, num_features, num_rows, entry_start);
if (global_idx >= num_rows) return;
@@ -204,7 +222,18 @@ __global__ void PredictKernel(Data data,
for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) {
const RegTree::Node* d_tree =
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
auto tree_cat_ptrs = d_cat_node_segments.subspan(
d_tree_segments[tree_idx - tree_begin],
d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin]);
auto tree_categories =
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
d_cat_tree_segments[tree_idx - tree_begin + 1] -
d_cat_tree_segments[tree_idx - tree_begin]);
float leaf = GetLeafWeight(global_idx, d_tree, d_tree_split_types,
tree_cat_ptrs,
tree_categories,
&loader);
sum += leaf;
}
d_out_predictions[global_idx] += sum;
@@ -214,19 +243,38 @@ __global__ void PredictKernel(Data data,
const RegTree::Node* d_tree =
&d_nodes[d_tree_segments[tree_idx - tree_begin]];
bst_uint out_prediction_idx = global_idx * num_group + tree_group;
auto tree_cat_ptrs = d_cat_node_segments.subspan(
d_tree_segments[tree_idx - tree_begin],
d_tree_segments[tree_idx - tree_begin + 1] -
d_tree_segments[tree_idx - tree_begin]);
auto tree_categories =
d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin],
d_cat_tree_segments[tree_idx - tree_begin + 1] -
d_cat_tree_segments[tree_idx - tree_begin]);
d_out_predictions[out_prediction_idx] +=
GetLeafWeight(global_idx, d_tree, d_tree_split_types,
tree_cat_ptrs,
tree_categories,
&loader);
}
}
}
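The kernel slices three flat arrays with the same offset arithmetic: `d_tree_segments` delimits each tree's nodes, `d_cat_tree_segments` delimits each tree's categories, and `d_cat_node_segments` carries one `{beg, size}` entry per node, relative to its tree's category slice. A standalone sketch of that arithmetic (illustrative names, not the real API):

```cpp
#include <cstddef>
#include <vector>

struct Slice { std::size_t offset, length; };

// Given cumulative offsets (one entry per tree plus a trailing total, as
// built by DeviceModel::Init below), return the slice owned by tree t --
// the same subspan computation PredictKernel performs above.
Slice TreeSlice(std::vector<std::size_t> const& segments, std::size_t t) {
  return {segments[t], segments[t + 1] - segments[t]};
}
```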

class DeviceModel {
public:
// Need to lazily construct the vectors because GPU id is only known at runtime
  HostDeviceVector<RTreeNodeStat> stats;
  HostDeviceVector<size_t> tree_segments;
  HostDeviceVector<RegTree::Node> nodes;
  HostDeviceVector<int> tree_group;
  HostDeviceVector<FeatureType> split_types;

  // Pointer to each tree, segmenting the categories array.
HostDeviceVector<uint32_t> categories_tree_segments;
// Pointer to each node, segmenting categories array.
HostDeviceVector<RegTree::Segment> categories_node_segments;
HostDeviceVector<uint32_t> categories;

size_t tree_beg_; // NOLINT
size_t tree_end_; // NOLINT
int num_group;
@@ -264,10 +312,43 @@ class DeviceModel {
}

tree_group = std::move(HostDeviceVector<int>(model.tree_info.size(), 0, gpu_id));
auto& h_tree_group = tree_group.HostVector();
std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size());

// Initialize categorical splits.
split_types.SetDevice(gpu_id);
std::vector<FeatureType>& h_split_types = split_types.HostVector();
h_split_types.resize(h_tree_segments.back());
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
auto const& src_st = model.trees.at(tree_idx)->GetSplitTypes();
std::copy(src_st.cbegin(), src_st.cend(),
h_split_types.begin() + h_tree_segments[tree_idx - tree_begin]);
}

categories = HostDeviceVector<uint32_t>({}, gpu_id);
categories_tree_segments = HostDeviceVector<uint32_t>(1, 0, gpu_id);
std::vector<uint32_t> &h_categories = categories.HostVector();
std::vector<uint32_t> &h_split_cat_segments = categories_tree_segments.HostVector();
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
auto const& src_cats = model.trees.at(tree_idx)->GetSplitCategories();
size_t orig_size = h_categories.size();
h_categories.resize(orig_size + src_cats.size());
std::copy(src_cats.cbegin(), src_cats.cend(),
h_categories.begin() + orig_size);
h_split_cat_segments.push_back(h_categories.size());
}

categories_node_segments =
HostDeviceVector<RegTree::Segment>(h_tree_segments.back(), {}, gpu_id);
std::vector<RegTree::Segment> &h_categories_node_segments =
categories_node_segments.HostVector();
for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) {
auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr();
std::copy(src_cats_ptr.cbegin(), src_cats_ptr.cend(),
h_categories_node_segments.begin() +
h_tree_segments[tree_idx - tree_begin]);
}

this->tree_beg_ = tree_begin;
this->tree_end_ = tree_end;
this->num_group = model.learner_model_param->num_output_group;
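A worked example of what these loops produce, for a hypothetical model where tree 0 has a single categorical node whose matching set is `{1, 3}` and tree 1 has no categorical splits:

```cpp
#include <cstdint>
#include <vector>

struct Segment { std::uint32_t beg, size; };  // stand-in for RegTree::Segment

// Bits 1 and 3 set in one 32-bit word.
std::vector<std::uint32_t> h_categories{0b1010u};
// Tree 0 owns words [0, 1) of h_categories; tree 1 owns the empty range [1, 1).
std::vector<std::uint32_t> h_split_cat_segments{0, 1, 1};
// One entry per node; only the categorical node references a non-empty slice
// (offsets are relative to its tree's portion of h_categories). Numerical
// nodes keep the default-constructed, empty segment.
std::vector<Segment> h_categories_node_segments{{0, 1}};
```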
@@ -360,7 +441,8 @@ void ExtractPaths(dh::device_vector<gpu_treeshap::PathElement>* paths,

class GPUPredictor : public xgboost::Predictor {
private:
void PredictInternal(const SparsePage& batch,
size_t num_features,
HostDeviceVector<bst_float>* predictions,
size_t batch_offset) {
batch.offset.SetDevice(generic_param_->gpu_id);
@@ -380,14 +462,18 @@ class GPUPredictor : public xgboost::Predictor {
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<SparsePageLoader, SparsePageView>, data,
model_.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
model_.split_types.ConstDeviceSpan(),
model_.categories_tree_segments.ConstDeviceSpan(),
model_.categories_node_segments.ConstDeviceSpan(),
model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
num_features, num_rows, entry_start, use_shared, model_.num_group);
}
void PredictInternal(EllpackDeviceAccessor const& batch,
HostDeviceVector<bst_float>* out_preds,
size_t batch_offset) {
const uint32_t BLOCK_THREADS = 256;
size_t num_rows = batch.n_rows;
@@ -396,12 +482,15 @@ class GPUPredictor : public xgboost::Predictor {
bool use_shared = false;
size_t entry_start = 0;
dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} (
PredictKernel<EllpackLoader, EllpackDeviceAccessor>, batch,
model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset),
model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(),
model_.split_types.ConstDeviceSpan(),
model_.categories_tree_segments.ConstDeviceSpan(),
model_.categories_node_segments.ConstDeviceSpan(),
model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_,
batch.NumFeatures(), num_rows, entry_start, use_shared,
model_.num_group);
}

void DevicePredictInternal(DMatrix* dmat, HostDeviceVector<float>* out_preds,
@@ -413,6 +502,7 @@ class GPUPredictor : public xgboost::Predictor {
}
model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id);
out_preds->SetDevice(generic_param_->gpu_id);
auto const& info = dmat->Info();

if (dmat->PageExists<SparsePage>()) {
size_t batch_offset = 0;
@@ -425,7 +515,8 @@ class GPUPredictor : public xgboost::Predictor {
size_t batch_offset = 0;
for (auto const& page : dmat->GetBatches<EllpackPage>()) {
this->PredictInternal(
page.Impl()->GetDeviceAccessor(generic_param_->gpu_id),
out_preds,
batch_offset);
batch_offset += page.Impl()->n_rows;
}
@@ -528,12 +619,14 @@ class GPUPredictor : public xgboost::Predictor {
size_t entry_start = 0;

dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} (
PredictKernel<Loader, typename Loader::BatchT>, m->Value(),
d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(),
d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(),
d_model.split_types.ConstDeviceSpan(),
d_model.categories_tree_segments.ConstDeviceSpan(),
d_model.categories_node_segments.ConstDeviceSpan(),
d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(),
info.num_row_, entry_start, use_shared, output_groups);
}

void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model,
3 changes: 3 additions & 0 deletions tests/cpp/predictor/test_gpu_predictor.cu
@@ -219,5 +219,8 @@ TEST(GPUPredictor, Shap) {
}
}

TEST(GPUPredictor, CategoricalPrediction) {
TestCategoricalPrediction("gpu_predictor");
}
} // namespace predictor
} // namespace xgboost
55 changes: 54 additions & 1 deletion tests/cpp/predictor/test_predictor.cc
@@ -12,6 +12,8 @@

#include "../helpers.h"
#include "../../../src/common/io.h"
#include "../../../src/common/categorical.h"
#include "../../../src/common/bitfield.h"

namespace xgboost {
TEST(Predictor, PredictionCache) {
@@ -27,7 +29,7 @@ TEST(Predictor, PredictionCache) {
};

add_cache();
ASSERT_EQ(container.Container().size(), 0ul);
add_cache();
EXPECT_ANY_THROW(container.Entry(m));
}
@@ -174,4 +176,55 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) {
}
#endif // defined(XGBOOST_USE_CUDA)
}

void TestCategoricalPrediction(std::string name) {
size_t constexpr kCols = 10;
PredictionCacheEntry out_predictions;

LearnerModelParam param;
param.num_feature = kCols;
param.num_output_group = 1;
param.base_score = 0.5;

gbm::GBTreeModel model(&param);

std::vector<std::unique_ptr<RegTree>> trees;
trees.push_back(std::unique_ptr<RegTree>(new RegTree));
auto& p_tree = trees.front();

uint32_t split_ind = 3;
bst_cat_t split_cat = 4;
float left_weight = 1.3f;
float right_weight = 1.7f;

std::vector<uint32_t> split_cats(LBitField32::ComputeStorageSize(split_cat));
LBitField32 cats_bits(split_cats);
cats_bits.Set(split_cat);

p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f,
left_weight, right_weight,
3.0f, 2.2f, 7.0f, 9.0f);
model.CommitModel(std::move(trees), 0);

GenericParameter runtime;
runtime.gpu_id = 0;
std::unique_ptr<Predictor> predictor{
Predictor::Create(name.c_str(), &runtime)};

std::vector<float> row(kCols);
row[split_ind] = split_cat;
auto m = GetDMatrixFromData(row, 1, kCols);

predictor->PredictBatch(m.get(), &out_predictions, model, 0);
ASSERT_EQ(out_predictions.predictions.Size(), 1ul);
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
right_weight + param.base_score); // go to right for matching cat

row[split_ind] = split_cat + 1;
m = GetDMatrixFromData(row, 1, kCols);
out_predictions.version = 0;
predictor->PredictBatch(m.get(), &out_predictions, model, 0);
ASSERT_EQ(out_predictions.predictions.HostVector()[0],
left_weight + param.base_score);
}
} // namespace xgboost
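For reference, a rough sketch of the two `LBitField32` operations the test relies on, assuming one bit per category packed into 32-bit words (hypothetical helper names; the real implementation is in `src/common/bitfield.h`):

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Roughly what LBitField32::ComputeStorageSize does: count the 32-bit words
// needed to hold n_bits bits.
std::size_t ComputeStorageSizeSketch(std::size_t n_bits) {
  return (n_bits + 31) / 32;
}

// Roughly what LBitField32::Set does: mark category `pos` as matching.
void SetSketch(std::vector<std::uint32_t>* storage, std::uint32_t pos) {
  (*storage)[pos / 32] |= 1u << (pos % 32);
}
```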
2 changes: 2 additions & 0 deletions tests/cpp/predictor/test_predictor.h
@@ -61,6 +61,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor,
int32_t device = -1);

void TestPredictionWithLesserFeatures(std::string predictor_name);

void TestCategoricalPrediction(std::string name);
} // namespace xgboost

#endif // XGBOOST_TEST_PREDICTOR_H_