From 798af22ff4fd92a707b1d9e2934b4f9b31a8225d Mon Sep 17 00:00:00 2001 From: Jiaming Yuan Date: Tue, 29 Sep 2020 11:25:34 +0800 Subject: [PATCH] Add categorical data support to GPU predictor. (#6165) --- src/common/host_device_vector.cc | 2 + src/common/host_device_vector.cu | 1 + src/predictor/gpu_predictor.cu | 179 ++++++++++++++++------ tests/cpp/predictor/test_gpu_predictor.cu | 3 + tests/cpp/predictor/test_predictor.cc | 55 ++++++- tests/cpp/predictor/test_predictor.h | 2 + 6 files changed, 198 insertions(+), 44 deletions(-) diff --git a/src/common/host_device_vector.cc b/src/common/host_device_vector.cc index f9974f8ecfaf..a16154966588 100644 --- a/src/common/host_device_vector.cc +++ b/src/common/host_device_vector.cc @@ -10,6 +10,7 @@ #include #include #include +#include "xgboost/tree_model.h" #include "xgboost/host_device_vector.h" namespace xgboost { @@ -176,6 +177,7 @@ template class HostDeviceVector; template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t +template class HostDeviceVector; #if defined(__APPLE__) /* diff --git a/src/common/host_device_vector.cu b/src/common/host_device_vector.cu index 2de1bb652900..02a47bea8e60 100644 --- a/src/common/host_device_vector.cu +++ b/src/common/host_device_vector.cu @@ -404,6 +404,7 @@ template class HostDeviceVector; template class HostDeviceVector; // bst_row_t template class HostDeviceVector; // bst_feature_t template class HostDeviceVector; +template class HostDeviceVector; template class HostDeviceVector; #if defined(__APPLE__) diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 992e2cc65753..39035a3d8b55 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -18,6 +18,8 @@ #include "../data/ellpack_page.cuh" #include "../data/device_adapter.cuh" #include "../common/common.h" +#include "../common/bitfield.h" +#include "../common/categorical.h" #include "../common/device_helpers.cuh" namespace xgboost { @@ -169,33 +171,49 @@ struct DeviceAdapterLoader { template __device__ float GetLeafWeight(bst_uint ridx, const RegTree::Node* tree, + common::Span split_types, + common::Span d_cat_ptrs, + common::Span d_categories, Loader* loader) { - RegTree::Node n = tree[0]; + bst_node_t nidx = 0; + RegTree::Node n = tree[nidx]; while (!n.IsLeaf()) { float fvalue = loader->GetElement(ridx, n.SplitIndex()); // Missing value - if (isnan(fvalue)) { - n = tree[n.DefaultChild()]; + if (common::CheckNAN(fvalue)) { + nidx = n.DefaultChild(); } else { - if (fvalue < n.SplitCond()) { - n = tree[n.LeftChild()]; + bool go_left = true; + if (common::IsCat(split_types, nidx)) { + auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg, + d_cat_ptrs[nidx].size); + go_left = Decision(categories, common::AsCat(fvalue)); } else { - n = tree[n.RightChild()]; + go_left = fvalue < n.SplitCond(); + } + if (go_left) { + nidx = n.LeftChild(); + } else { + nidx = n.RightChild(); } } + n = tree[nidx]; } - return n.LeafValue(); + return tree[nidx].LeafValue(); } template -__global__ void PredictKernel(Data data, - common::Span d_nodes, - common::Span d_out_predictions, - common::Span d_tree_segments, - common::Span d_tree_group, - size_t tree_begin, size_t tree_end, size_t num_features, - size_t num_rows, size_t entry_start, - bool use_shared, int num_group) { +__global__ void +PredictKernel(Data data, common::Span d_nodes, + common::Span d_out_predictions, + common::Span d_tree_segments, + common::Span d_tree_group, + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories, size_t tree_begin, + size_t tree_end, size_t num_features, size_t num_rows, + size_t entry_start, bool use_shared, int num_group) { bst_uint global_idx = blockDim.x * blockIdx.x + threadIdx.x; Loader loader(data, use_shared, num_features, num_rows, entry_start); if (global_idx >= num_rows) return; @@ -204,7 +222,18 @@ __global__ void PredictKernel(Data data, for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; - float leaf = GetLeafWeight(global_idx, d_tree, &loader); + auto tree_cat_ptrs = d_cat_node_segments.subspan( + d_tree_segments[tree_idx - tree_begin], + d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]); + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); + float leaf = GetLeafWeight(global_idx, d_tree, d_tree_split_types, + tree_cat_ptrs, + tree_categories, + &loader); sum += leaf; } d_out_predictions[global_idx] += sum; @@ -214,8 +243,19 @@ __global__ void PredictKernel(Data data, const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; bst_uint out_prediction_idx = global_idx * num_group + tree_group; + auto tree_cat_ptrs = d_cat_node_segments.subspan( + d_tree_segments[tree_idx - tree_begin], + d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]); + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); d_out_predictions[out_prediction_idx] += - GetLeafWeight(global_idx, d_tree, &loader); + GetLeafWeight(global_idx, d_tree, d_tree_split_types, + tree_cat_ptrs, + tree_categories, + &loader); } } } @@ -223,10 +263,18 @@ __global__ void PredictKernel(Data data, class DeviceModel { public: // Need to lazily construct the vectors because GPU id is only known at runtime - HostDeviceVector nodes; HostDeviceVector stats; HostDeviceVector tree_segments; + HostDeviceVector nodes; HostDeviceVector tree_group; + HostDeviceVector split_types; + + // Pointer to each tree, segmenting the node array. + HostDeviceVector categories_tree_segments; + // Pointer to each node, segmenting categories array. + HostDeviceVector categories_node_segments; + HostDeviceVector categories; + size_t tree_beg_; // NOLINT size_t tree_end_; // NOLINT int num_group; @@ -264,10 +312,43 @@ class DeviceModel { } tree_group = std::move(HostDeviceVector(model.tree_info.size(), 0, gpu_id)); - auto d_tree_group = tree_group.DevicePointer(); - dh::safe_cuda(cudaMemcpyAsync(d_tree_group, model.tree_info.data(), - sizeof(int) * model.tree_info.size(), - cudaMemcpyDefault)); + auto& h_tree_group = tree_group.HostVector(); + std::memcpy(h_tree_group.data(), model.tree_info.data(), sizeof(int) * model.tree_info.size()); + + // Initialize categorical splits. + split_types.SetDevice(gpu_id); + std::vector& h_split_types = split_types.HostVector(); + h_split_types.resize(h_tree_segments.back()); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& src_st = model.trees.at(tree_idx)->GetSplitTypes(); + std::copy(src_st.cbegin(), src_st.cend(), + h_split_types.begin() + h_tree_segments[tree_idx - tree_begin]); + } + + categories = HostDeviceVector({}, gpu_id); + categories_tree_segments = HostDeviceVector(1, 0, gpu_id); + std::vector &h_categories = categories.HostVector(); + std::vector &h_split_cat_segments = categories_tree_segments.HostVector(); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const& src_cats = model.trees.at(tree_idx)->GetSplitCategories(); + size_t orig_size = h_categories.size(); + h_categories.resize(orig_size + src_cats.size()); + std::copy(src_cats.cbegin(), src_cats.cend(), + h_categories.begin() + orig_size); + h_split_cat_segments.push_back(h_categories.size()); + } + + categories_node_segments = + HostDeviceVector(h_tree_segments.back(), {}, gpu_id); + std::vector &h_categories_node_segments = + categories_node_segments.HostVector(); + for (auto tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + auto const &src_cats_ptr = model.trees.at(tree_idx)->GetSplitCategoriesPtr(); + std::copy(src_cats_ptr.cbegin(), src_cats_ptr.cend(), + h_categories_node_segments.begin() + + h_tree_segments[tree_idx - tree_begin]); + } + this->tree_beg_ = tree_begin; this->tree_end_ = tree_end; this->num_group = model.learner_model_param->num_output_group; @@ -360,7 +441,8 @@ void ExtractPaths(dh::device_vector* paths, class GPUPredictor : public xgboost::Predictor { private: - void PredictInternal(const SparsePage& batch, size_t num_features, + void PredictInternal(const SparsePage& batch, + size_t num_features, HostDeviceVector* predictions, size_t batch_offset) { batch.offset.SetDevice(generic_param_->gpu_id); @@ -380,14 +462,18 @@ class GPUPredictor : public xgboost::Predictor { SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - data, - model_.nodes.DeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), - model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), - model_.tree_beg_, model_.tree_end_, num_features, num_rows, - entry_start, use_shared, model_.num_group); + PredictKernel, data, + model_.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(), + model_.split_types.ConstDeviceSpan(), + model_.categories_tree_segments.ConstDeviceSpan(), + model_.categories_node_segments.ConstDeviceSpan(), + model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_, + num_features, num_rows, entry_start, use_shared, model_.num_group); } - void PredictInternal(EllpackDeviceAccessor const& batch, HostDeviceVector* out_preds, + void PredictInternal(EllpackDeviceAccessor const& batch, + HostDeviceVector* out_preds, size_t batch_offset) { const uint32_t BLOCK_THREADS = 256; size_t num_rows = batch.n_rows; @@ -396,12 +482,15 @@ class GPUPredictor : public xgboost::Predictor { bool use_shared = false; size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS} ( - PredictKernel, - batch, - model_.nodes.DeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), - model_.tree_segments.DeviceSpan(), model_.tree_group.DeviceSpan(), - model_.tree_beg_, model_.tree_end_, batch.NumFeatures(), num_rows, - entry_start, use_shared, model_.num_group); + PredictKernel, batch, + model_.nodes.ConstDeviceSpan(), out_preds->DeviceSpan().subspan(batch_offset), + model_.tree_segments.ConstDeviceSpan(), model_.tree_group.ConstDeviceSpan(), + model_.split_types.ConstDeviceSpan(), + model_.categories_tree_segments.ConstDeviceSpan(), + model_.categories_node_segments.ConstDeviceSpan(), + model_.categories.ConstDeviceSpan(), model_.tree_beg_, model_.tree_end_, + batch.NumFeatures(), num_rows, entry_start, use_shared, + model_.num_group); } void DevicePredictInternal(DMatrix* dmat, HostDeviceVector* out_preds, @@ -413,6 +502,7 @@ class GPUPredictor : public xgboost::Predictor { } model_.Init(model, tree_begin, tree_end, generic_param_->gpu_id); out_preds->SetDevice(generic_param_->gpu_id); + auto const& info = dmat->Info(); if (dmat->PageExists()) { size_t batch_offset = 0; @@ -425,7 +515,8 @@ class GPUPredictor : public xgboost::Predictor { size_t batch_offset = 0; for (auto const& page : dmat->GetBatches()) { this->PredictInternal( - page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), out_preds, + page.Impl()->GetDeviceAccessor(generic_param_->gpu_id), + out_preds, batch_offset); batch_offset += page.Impl()->n_rows; } @@ -528,12 +619,14 @@ class GPUPredictor : public xgboost::Predictor { size_t entry_start = 0; dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, - m->Value(), - d_model.nodes.DeviceSpan(), out_preds->predictions.DeviceSpan(), - d_model.tree_segments.DeviceSpan(), d_model.tree_group.DeviceSpan(), - tree_begin, tree_end, m->NumColumns(), info.num_row_, - entry_start, use_shared, output_groups); + PredictKernel, m->Value(), + d_model.nodes.ConstDeviceSpan(), out_preds->predictions.DeviceSpan(), + d_model.tree_segments.ConstDeviceSpan(), d_model.tree_group.ConstDeviceSpan(), + d_model.split_types.ConstDeviceSpan(), + d_model.categories_tree_segments.ConstDeviceSpan(), + d_model.categories_node_segments.ConstDeviceSpan(), + d_model.categories.ConstDeviceSpan(), tree_begin, tree_end, m->NumColumns(), + info.num_row_, entry_start, use_shared, output_groups); } void InplacePredict(dmlc::any const &x, const gbm::GBTreeModel &model, diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 596385afdec5..b48e490864b1 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -221,5 +221,8 @@ TEST(GPUPredictor, Shap) { } } +TEST(GPUPredictor, CategoricalPrediction) { + TestCategoricalPrediction("gpu_predictor"); +} } // namespace predictor } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 3005d585f5a6..9206ba2aabf3 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -12,6 +12,8 @@ #include "../helpers.h" #include "../../../src/common/io.h" +#include "../../../src/common/categorical.h" +#include "../../../src/common/bitfield.h" namespace xgboost { TEST(Predictor, PredictionCache) { @@ -27,7 +29,7 @@ TEST(Predictor, PredictionCache) { }; add_cache(); - ASSERT_EQ(container.Container().size(), 0); + ASSERT_EQ(container.Container().size(), 0ul); add_cache(); EXPECT_ANY_THROW(container.Entry(m)); } @@ -174,4 +176,55 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) { } #endif // defined(XGBOOST_USE_CUDA) } + +void TestCategoricalPrediction(std::string name) { + size_t constexpr kCols = 10; + PredictionCacheEntry out_predictions; + + LearnerModelParam param; + param.num_feature = kCols; + param.num_output_group = 1; + param.base_score = 0.5; + + gbm::GBTreeModel model(¶m); + + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + auto& p_tree = trees.front(); + + uint32_t split_ind = 3; + bst_cat_t split_cat = 4; + float left_weight = 1.3f; + float right_weight = 1.7f; + + std::vector split_cats(LBitField32::ComputeStorageSize(split_cat)); + LBitField32 cats_bits(split_cats); + cats_bits.Set(split_cat); + + p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, + left_weight, right_weight, + 3.0f, 2.2f, 7.0f, 9.0f); + model.CommitModel(std::move(trees), 0); + + GenericParameter runtime; + runtime.gpu_id = 0; + std::unique_ptr predictor{ + Predictor::Create(name.c_str(), &runtime)}; + + std::vector row(kCols); + row[split_ind] = split_cat; + auto m = GetDMatrixFromData(row, 1, kCols); + + predictor->PredictBatch(m.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.Size(), 1ul); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], + right_weight + param.base_score); // go to right for matching cat + + row[split_ind] = split_cat + 1; + m = GetDMatrixFromData(row, 1, kCols); + out_predictions.version = 0; + predictor->PredictBatch(m.get(), &out_predictions, model, 0); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], + left_weight + param.base_score); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index b6a3180111f2..68e034e0a581 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -61,6 +61,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor, int32_t device = -1); void TestPredictionWithLesserFeatures(std::string preditor_name); + +void TestCategoricalPrediction(std::string name); } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_