diff --git a/include/xgboost/tree_model.h b/include/xgboost/tree_model.h index fd2d5cc1b2e0..8b15a81217bb 100644 --- a/include/xgboost/tree_model.h +++ b/include/xgboost/tree_model.h @@ -445,6 +445,10 @@ class RegTree : public Model { bst_float right_leaf_weight, bst_float loss_change, float sum_hess, float left_sum, float right_sum); + bool HasCategoricalSplit() const { + return !split_categories_.empty(); + } + /*! * \brief get current depth * \param nid node id @@ -537,13 +541,6 @@ class RegTree : public Model { std::vector data_; bool has_missing_; }; - /*! - * \brief get the leaf index - * \param feat dense feature vector, if the feature is missing the field is set to NaN - * \return the leaf index of the given feature - */ - template - int GetLeafIndex(const FVec& feat) const; /*! * \brief calculate the feature contributions (https://arxiv.org/abs/1706.06060) for the tree @@ -582,14 +579,6 @@ class RegTree : public Model { */ void CalculateContributionsApprox(const RegTree::FVec& feat, bst_float* out_contribs) const; - /*! - * \brief get next position of the tree given current pid - * \param pid Current node id. - * \param fvalue feature value if not missing. - * \param is_unknown Whether current required feature is missing. - */ - template - inline int GetNext(int pid, bst_float fvalue, bool is_unknown) const; /*! * \brief dump the model in the requested format as a text string * \param fmap feature map that may help give interpretations of feature @@ -627,6 +616,20 @@ class RegTree : public Model { size_t size {0}; }; + struct CategoricalSplitMatrix { + common::Span split_type; + common::Span categories; + common::Span node_ptr; + }; + + CategoricalSplitMatrix GetCategoriesMatrix() const { + CategoricalSplitMatrix view; + view.split_type = common::Span(this->GetSplitTypes()); + view.categories = this->GetSplitCategories(); + view.node_ptr = common::Span(split_categories_segments_); + return view; + } + private: void LoadCategoricalSplit(Json const& in); void SaveCategoricalSplit(Json* p_out) const; @@ -724,38 +727,5 @@ inline bool RegTree::FVec::IsMissing(size_t i) const { inline bool RegTree::FVec::HasMissing() const { return has_missing_; } - -template -inline int RegTree::GetLeafIndex(const RegTree::FVec& feat) const { - bst_node_t nid = 0; - while (!(*this)[nid].IsLeaf()) { - unsigned split_index = (*this)[nid].SplitIndex(); - nid = this->GetNext(nid, feat.GetFvalue(split_index), - has_missing && feat.IsMissing(split_index)); - } - return nid; -} - -/*! \brief get next position of the tree given current pid */ -template -inline int RegTree::GetNext(int pid, bst_float fvalue, bool is_unknown) const { - if (has_missing) { - if (is_unknown) { - return (*this)[pid].DefaultChild(); - } else { - if (fvalue < (*this)[pid].SplitCond()) { - return (*this)[pid].LeftChild(); - } else { - return (*this)[pid].RightChild(); - } - } - } else { - // 35% speed up due to reduced miss branch predictions - // The following expression returns the left child if (fvalue < split_cond); - // the right child otherwise. - return (*this)[pid].LeftChild() + !(fvalue < (*this)[pid].SplitCond()); - } -} - } // namespace xgboost #endif // XGBOOST_TREE_MODEL_H_ diff --git a/src/predictor/cpu_predictor.cc b/src/predictor/cpu_predictor.cc index 41584a50e595..58d1633b8c63 100644 --- a/src/predictor/cpu_predictor.cc +++ b/src/predictor/cpu_predictor.cc @@ -16,9 +16,11 @@ #include "xgboost/logging.h" #include "xgboost/host_device_vector.h" +#include "predict_fn.h" #include "../data/adapter.h" #include "../common/math.h" #include "../common/threading_utils.h" +#include "../common/categorical.h" #include "../gbm/gbtree_model.h" namespace xgboost { @@ -26,6 +28,19 @@ namespace predictor { DMLC_REGISTRY_FILE_TAG(cpu_predictor); +template +bst_node_t GetLeafIndex(RegTree const &tree, const RegTree::FVec &feat, + RegTree::CategoricalSplitMatrix const& cats) { + bst_node_t nid = 0; + while (!tree[nid].IsLeaf()) { + unsigned split_index = tree[nid].SplitIndex(); + auto fvalue = feat.GetFvalue(split_index); + nid = GetNextNode( + tree[nid], nid, fvalue, has_missing && feat.IsMissing(split_index), cats); + } + return nid; +} + bst_float PredValue(const SparsePage::Inst &inst, const std::vector> &trees, const std::vector &tree_info, int bst_group, @@ -35,32 +50,59 @@ bst_float PredValue(const SparsePage::Inst &inst, p_feats->Fill(inst); for (size_t i = tree_begin; i < tree_end; ++i) { if (tree_info[i] == bst_group) { - int tid = trees[i]->GetLeafIndex(*p_feats); - psum += (*trees[i])[tid].LeafValue(); + auto const &tree = *trees[i]; + bool has_categorical = tree.HasCategoricalSplit(); + + auto categories = common::Span{tree.GetSplitCategories()}; + auto split_types = tree.GetSplitTypes(); + auto categories_ptr = + common::Span{tree.GetSplitCategoriesPtr()}; + auto cats = tree.GetCategoriesMatrix(); + bst_node_t nidx = -1; + if (has_categorical) { + nidx = GetLeafIndex(tree, *p_feats, cats); + } else { + nidx = GetLeafIndex(tree, *p_feats, cats); + } + psum += (*trees[i])[nidx].LeafValue(); } } p_feats->Drop(inst); return psum; } -inline bst_float PredValueByOneTree(const RegTree::FVec& p_feats, - const std::unique_ptr& tree) { - const int lid = p_feats.HasMissing() ? tree->GetLeafIndex(p_feats) : - tree->GetLeafIndex(p_feats); // 35% speed up - return (*tree)[lid].LeafValue(); +template +bst_float +PredValueByOneTree(const RegTree::FVec &p_feats, RegTree const &tree, + RegTree::CategoricalSplitMatrix const& cats) { + const bst_node_t leaf = p_feats.HasMissing() ? + GetLeafIndex(tree, p_feats, cats) : + GetLeafIndex(tree, p_feats, cats); + return tree[leaf].LeafValue(); } -inline void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin, - const size_t tree_end, std::vector* out_preds, - const size_t predict_offset, const size_t num_group, - const std::vector &thread_temp, - const size_t offset, const size_t block_size) { +void PredictByAllTrees(gbm::GBTreeModel const &model, const size_t tree_begin, + const size_t tree_end, std::vector *out_preds, + const size_t predict_offset, const size_t num_group, + const std::vector &thread_temp, + const size_t offset, const size_t block_size) { std::vector &preds = *out_preds; for (size_t tree_id = tree_begin; tree_id < tree_end; ++tree_id) { const size_t gid = model.tree_info[tree_id]; - for (size_t i = 0; i < block_size; ++i) { - preds[(predict_offset + i) * num_group + gid] += PredValueByOneTree(thread_temp[offset + i], - model.trees[tree_id]); + auto const &tree = *model.trees[tree_id]; + auto const& cats = tree.GetCategoriesMatrix(); + auto has_categorical = tree.HasCategoricalSplit(); + + if (has_categorical) { + for (size_t i = 0; i < block_size; ++i) { + preds[(predict_offset + i) * num_group + gid] += + PredValueByOneTree(thread_temp[offset + i], tree, cats); + } + } else { + for (size_t i = 0; i < block_size; ++i) { + preds[(predict_offset + i) * num_group + gid] += + PredValueByOneTree(thread_temp[offset + i], tree, cats); + } } } } @@ -77,6 +119,7 @@ void FVecFill(const size_t block_size, const size_t batch_offset, const int num_ feats.Fill(inst); } } + template void FVecDrop(const size_t block_size, const size_t batch_offset, DataView* batch, const size_t fvec_offset, std::vector* p_feats) { @@ -145,11 +188,11 @@ class AdapterView { }; template -void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector *out_preds, - gbm::GBTreeModel const &model, int32_t tree_begin, - int32_t tree_end, - std::vector *p_thread_temp) { - auto& thread_temp = *p_thread_temp; +void PredictBatchByBlockOfRowsKernel( + DataView batch, std::vector *out_preds, + gbm::GBTreeModel const &model, int32_t tree_begin, int32_t tree_end, + std::vector *p_thread_temp) { + auto &thread_temp = *p_thread_temp; int32_t const num_group = model.learner_model_param->num_output_group; CHECK_EQ(model.param.size_leaf_vector, 0) @@ -157,16 +200,20 @@ void PredictBatchByBlockOfRowsKernel(DataView batch, std::vector *out // parallel over local batch const auto nsize = static_cast(batch.Size()); const int num_feature = model.learner_model_param->num_feature; - const bst_omp_uint n_row_blocks = (nsize) / block_of_rows_size + !!((nsize) % block_of_rows_size); - common::ParallelFor(n_row_blocks, [&](bst_omp_uint block_id) { + omp_ulong n_blocks = common::DivRoundUp(nsize, block_of_rows_size); + + common::ParallelFor(n_blocks, [&](bst_omp_uint block_id) { const size_t batch_offset = block_id * block_of_rows_size; - const size_t block_size = std::min(nsize - batch_offset, block_of_rows_size); + const size_t block_size = + std::min(nsize - batch_offset, block_of_rows_size); const size_t fvec_offset = omp_get_thread_num() * block_of_rows_size; - FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, p_thread_temp); + FVecFill(block_size, batch_offset, num_feature, &batch, fvec_offset, + p_thread_temp); // process block of rows through all trees to keep cache locality - PredictByAllTrees(model, tree_begin, tree_end, out_preds, batch_offset + batch.base_rowid, - num_group, thread_temp, fvec_offset, block_size); + PredictByAllTrees(model, tree_begin, tree_end, out_preds, + batch_offset + batch.base_rowid, num_group, thread_temp, + fvec_offset, block_size); FVecDrop(block_size, batch_offset, &batch, fvec_offset, p_thread_temp); }); } @@ -344,7 +391,9 @@ class CPUPredictor : public Predictor { } feats.Fill(page[i]); for (unsigned j = 0; j < ntree_limit; ++j) { - int tid = model.trees[j]->GetLeafIndex(feats); + auto const& tree = *model.trees[j]; + auto const& cats = tree.GetCategoriesMatrix(); + bst_node_t tid = GetLeafIndex(tree, feats, cats); preds[ridx * ntree_limit + j] = static_cast(tid); } feats.Drop(page[i]); diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index 8a576a4cefbc..002ee4c3dd90 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -14,6 +14,7 @@ #include "xgboost/tree_updater.h" #include "xgboost/host_device_vector.h" +#include "predict_fn.h" #include "../gbm/gbtree_model.h" #include "../data/ellpack_page.cuh" #include "../data/device_adapter.cuh" @@ -27,6 +28,42 @@ namespace predictor { DMLC_REGISTRY_FILE_TAG(gpu_predictor); +struct TreeView { + RegTree::CategoricalSplitMatrix cats; + common::Span d_tree; + + XGBOOST_DEVICE + TreeView(size_t tree_begin, size_t tree_idx, + common::Span d_nodes, + common::Span d_tree_segments, + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories) { + auto begin = d_tree_segments[tree_idx - tree_begin]; + auto n_nodes = d_tree_segments[tree_idx - tree_begin + 1] - + d_tree_segments[tree_idx - tree_begin]; + + d_tree = d_nodes.subspan(begin, n_nodes); + + auto tree_cat_ptrs = d_cat_node_segments.subspan(begin, n_nodes); + auto tree_split_types = d_tree_split_types.subspan(begin, n_nodes); + + auto tree_categories = + d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], + d_cat_tree_segments[tree_idx - tree_begin + 1] - + d_cat_tree_segments[tree_idx - tree_begin]); + + cats.split_type = tree_split_types; + cats.categories = tree_categories; + cats.node_ptr = tree_cat_ptrs; + } + + __device__ bool HasCategoricalSplit() const { + return !cats.categories.empty(); + } +}; + struct SparsePageView { common::Span d_data; common::Span d_row_ptr; @@ -178,84 +215,69 @@ struct DeviceAdapterLoader { } }; -template -__device__ float GetLeafWeight(bst_row_t ridx, const RegTree::Node* tree, - common::Span split_types, - common::Span d_cat_ptrs, - common::Span d_categories, - Loader* loader) { +template +__device__ bst_node_t GetLeafIndex(bst_row_t ridx, TreeView const &tree, + Loader *loader) { bst_node_t nidx = 0; - RegTree::Node n = tree[nidx]; + RegTree::Node n = tree.d_tree[nidx]; while (!n.IsLeaf()) { float fvalue = loader->GetElement(ridx, n.SplitIndex()); - // Missing value - if (common::CheckNAN(fvalue)) { - nidx = n.DefaultChild(); - } else { - bool go_left = true; - if (common::IsCat(split_types, nidx)) { - auto categories = d_categories.subspan(d_cat_ptrs[nidx].beg, - d_cat_ptrs[nidx].size); - go_left = Decision(categories, common::AsCat(fvalue)); - } else { - go_left = fvalue < n.SplitCond(); - } - if (go_left) { - nidx = n.LeftChild(); - } else { - nidx = n.RightChild(); - } - } - n = tree[nidx]; + bool is_missing = common::CheckNAN(fvalue); + nidx = GetNextNode(n, nidx, fvalue, + is_missing, tree.cats); + n = tree.d_tree[nidx]; } - return tree[nidx].LeafValue(); + return nidx; } -template -__device__ bst_node_t GetLeafIndex(bst_row_t ridx, const RegTree::Node* tree, - Loader const& loader) { - bst_node_t nidx = 0; - RegTree::Node n = tree[nidx]; - while (!n.IsLeaf()) { - float fvalue = loader.GetElement(ridx, n.SplitIndex()); - // Missing value - if (common::CheckNAN(fvalue)) { - nidx = n.DefaultChild(); - n = tree[nidx]; - } else { - if (fvalue < n.SplitCond()) { - nidx = n.LeftChild(); - n = tree[nidx]; - } else { - nidx = n.RightChild(); - n = tree[nidx]; - } - } +template +__device__ float GetLeafWeight(bst_row_t ridx, TreeView const &tree, + Loader *loader) { + bst_node_t nidx = -1; + if (tree.HasCategoricalSplit()) { + nidx = GetLeafIndex(ridx, tree, loader); + } else { + nidx = GetLeafIndex(ridx, tree, loader); } - return nidx; + return tree.d_tree[nidx].LeafValue(); } template -__global__ void PredictLeafKernel(Data data, - common::Span d_nodes, - common::Span d_out_predictions, - common::Span d_tree_segments, - size_t tree_begin, size_t tree_end, size_t num_features, - size_t num_rows, size_t entry_start, bool use_shared, - float missing) { +__global__ void +PredictLeafKernel(Data data, common::Span d_nodes, + common::Span d_out_predictions, + common::Span d_tree_segments, + + common::Span d_tree_split_types, + common::Span d_cat_tree_segments, + common::Span d_cat_node_segments, + common::Span d_categories, + + size_t tree_begin, size_t tree_end, size_t num_features, + size_t num_rows, size_t entry_start, bool use_shared, + float missing) { bst_row_t ridx = blockDim.x * blockIdx.x + threadIdx.x; if (ridx >= num_rows) { return; } Loader loader(data, use_shared, num_features, num_rows, entry_start, missing); - for (int tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { - const RegTree::Node* d_tree = &d_nodes[d_tree_segments[tree_idx - tree_begin]]; - auto leaf = GetLeafIndex(ridx, d_tree, loader); + for (size_t tree_idx = tree_begin; tree_idx < tree_end; ++tree_idx) { + TreeView d_tree{ + tree_begin, tree_idx, d_nodes, + d_tree_segments, d_tree_split_types, d_cat_tree_segments, + d_cat_node_segments, d_categories}; + + bst_node_t leaf = -1; + if (d_tree.HasCategoricalSplit()) { + leaf = GetLeafIndex(ridx, d_tree, &loader); + } else { + leaf = GetLeafIndex(ridx, d_tree, &loader); + } d_out_predictions[ridx * (tree_end - tree_begin) + tree_idx] = leaf; } } -template +template __global__ void PredictKernel(Data data, common::Span d_nodes, common::Span d_out_predictions, @@ -272,47 +294,25 @@ PredictKernel(Data data, common::Span d_nodes, if (global_idx >= num_rows) return; if (num_group == 1) { float sum = 0; - for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { - const RegTree::Node* d_tree = - &d_nodes[d_tree_segments[tree_idx - tree_begin]]; - auto tree_cat_ptrs = d_cat_node_segments.subspan( - d_tree_segments[tree_idx - tree_begin], - d_tree_segments[tree_idx - tree_begin + 1] - - d_tree_segments[tree_idx - tree_begin]); - auto tree_categories = - d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], - d_cat_tree_segments[tree_idx - tree_begin + 1] - - d_cat_tree_segments[tree_idx - tree_begin]); - auto tree_split_types = - d_tree_split_types.subspan(d_tree_segments[tree_idx - tree_begin], - d_tree_segments[tree_idx - tree_begin + 1] - - d_tree_segments[tree_idx - tree_begin]); - float leaf = GetLeafWeight(global_idx, d_tree, tree_split_types, - tree_cat_ptrs, - tree_categories, - &loader); + for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + TreeView d_tree{ + tree_begin, tree_idx, d_nodes, + d_tree_segments, d_tree_split_types, d_cat_tree_segments, + d_cat_node_segments, d_categories}; + float leaf = GetLeafWeight(global_idx, d_tree, &loader); sum += leaf; } d_out_predictions[global_idx] += sum; } else { - for (int tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { + for (size_t tree_idx = tree_begin; tree_idx < tree_end; tree_idx++) { int tree_group = d_tree_group[tree_idx]; - const RegTree::Node* d_tree = - &d_nodes[d_tree_segments[tree_idx - tree_begin]]; + TreeView d_tree{ + tree_begin, tree_idx, d_nodes, + d_tree_segments, d_tree_split_types, d_cat_tree_segments, + d_cat_node_segments, d_categories}; bst_uint out_prediction_idx = global_idx * num_group + tree_group; - auto tree_cat_ptrs = d_cat_node_segments.subspan( - d_tree_segments[tree_idx - tree_begin], - d_tree_segments[tree_idx - tree_begin + 1] - - d_tree_segments[tree_idx - tree_begin]); - auto tree_categories = - d_categories.subspan(d_cat_tree_segments[tree_idx - tree_begin], - d_cat_tree_segments[tree_idx - tree_begin + 1] - - d_cat_tree_segments[tree_idx - tree_begin]); d_out_predictions[out_prediction_idx] += - GetLeafWeight(global_idx, d_tree, d_tree_split_types, - tree_cat_ptrs, - tree_categories, - &loader); + GetLeafWeight(global_idx, d_tree, &loader); } } } @@ -515,7 +515,7 @@ class GPUPredictor : public xgboost::Predictor { DeviceModel const& model, size_t num_features, HostDeviceVector* predictions, - size_t batch_offset) const { + size_t batch_offset, bool is_dense) const { batch.offset.SetDevice(generic_param_->gpu_id); batch.data.SetDevice(generic_param_->gpu_id); const uint32_t BLOCK_THREADS = 128; @@ -529,16 +529,24 @@ class GPUPredictor : public xgboost::Predictor { size_t entry_start = 0; SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); - dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( - PredictKernel, data, - model.nodes.ConstDeviceSpan(), - predictions->DeviceSpan().subspan(batch_offset), - model.tree_segments.ConstDeviceSpan(), model.tree_group.ConstDeviceSpan(), - model.split_types.ConstDeviceSpan(), - model.categories_tree_segments.ConstDeviceSpan(), - model.categories_node_segments.ConstDeviceSpan(), - model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_, - num_features, num_rows, entry_start, use_shared, model.num_group, nan("")); + auto const kernel = [&](auto predict_fn) { + dh::LaunchKernel {GRID_SIZE, BLOCK_THREADS, shared_memory_bytes} ( + predict_fn, data, model.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model.tree_segments.ConstDeviceSpan(), + model.tree_group.ConstDeviceSpan(), + model.split_types.ConstDeviceSpan(), + model.categories_tree_segments.ConstDeviceSpan(), + model.categories_node_segments.ConstDeviceSpan(), + model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_, + num_features, num_rows, entry_start, use_shared, model.num_group, + nan("")); + }; + if (is_dense) { + kernel(PredictKernel); + } else { + kernel(PredictKernel); + } } void PredictInternal(EllpackDeviceAccessor const& batch, DeviceModel const& model, @@ -578,7 +586,7 @@ class GPUPredictor : public xgboost::Predictor { size_t batch_offset = 0; for (auto &batch : dmat->GetBatches()) { this->PredictInternal(batch, d_model, model.learner_model_param->num_feature, - out_preds, batch_offset); + out_preds, batch_offset, dmat->IsDense()); batch_offset += batch.Size() * model.learner_model_param->num_output_group; } } else { @@ -846,6 +854,12 @@ class GPUPredictor : public xgboost::Predictor { d_model.nodes.ConstDeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(), + + d_model.split_types.ConstDeviceSpan(), + d_model.categories_tree_segments.ConstDeviceSpan(), + d_model.categories_node_segments.ConstDeviceSpan(), + d_model.categories.ConstDeviceSpan(), + d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, entry_start, use_shared, nan("")); batch_offset += batch.Size(); @@ -862,6 +876,12 @@ class GPUPredictor : public xgboost::Predictor { d_model.nodes.ConstDeviceSpan(), predictions->DeviceSpan().subspan(batch_offset), d_model.tree_segments.ConstDeviceSpan(), + + d_model.split_types.ConstDeviceSpan(), + d_model.categories_tree_segments.ConstDeviceSpan(), + d_model.categories_node_segments.ConstDeviceSpan(), + d_model.categories.ConstDeviceSpan(), + d_model.tree_beg_, d_model.tree_end_, num_features, num_rows, entry_start, use_shared, nan("")); batch_offset += batch.Size(); diff --git a/src/predictor/predict_fn.h b/src/predictor/predict_fn.h new file mode 100644 index 000000000000..1547d6e774ae --- /dev/null +++ b/src/predictor/predict_fn.h @@ -0,0 +1,31 @@ +/*! + * Copyright 2021 by XGBoost Contributors + */ +#ifndef XGBOOST_PREDICTOR_PREDICT_FN_H_ +#define XGBOOST_PREDICTOR_PREDICT_FN_H_ +#include "../common/categorical.h" +#include "xgboost/tree_model.h" + +namespace xgboost { +namespace predictor { +template +inline XGBOOST_DEVICE bst_node_t +GetNextNode(const RegTree::Node &node, const bst_node_t nid, float fvalue, + bool is_missing, RegTree::CategoricalSplitMatrix const &cats) { + if (has_missing && is_missing) { + return node.DefaultChild(); + } else { + if (has_categorical && common::IsCat(cats.split_type, nid)) { + auto node_categories = cats.categories.subspan(cats.node_ptr[nid].beg, + cats.node_ptr[nid].size); + return Decision(node_categories, common::AsCat(fvalue)) + ? node.LeftChild() + : node.RightChild(); + } else { + return node.LeftChild() + !(fvalue < node.SplitCond()); + } + } +} +} // namespace predictor +} // namespace xgboost +#endif // XGBOOST_PREDICTOR_PREDICT_FN_H_ diff --git a/src/tree/tree_model.cc b/src/tree/tree_model.cc index 06354bbf5e1e..1847015bd3a7 100644 --- a/src/tree/tree_model.cc +++ b/src/tree/tree_model.cc @@ -19,6 +19,7 @@ #include "param.h" #include "../common/common.h" #include "../common/categorical.h" +#include "../predictor/predict_fn.h" namespace xgboost { // register tree parameter @@ -1052,10 +1053,15 @@ void RegTree::CalculateContributionsApprox(const RegTree::FVec &feat, // nothing to do anymore return; } + bst_node_t nid = 0; + auto cats = this->GetCategoriesMatrix(); + while (!(*this)[nid].IsLeaf()) { split_index = (*this)[nid].SplitIndex(); - nid = this->GetNext(nid, feat.GetFvalue(split_index), feat.IsMissing(split_index)); + nid = predictor::GetNextNode((*this)[nid], nid, + feat.GetFvalue(split_index), + feat.IsMissing(split_index), cats); bst_float new_value = this->node_mean_values_[nid]; // update feature weight out_contribs[split_index] += new_value - node_value; diff --git a/src/tree/updater_refresh.cc b/src/tree/updater_refresh.cc index 968cf0320994..1d54ad9e3d9b 100644 --- a/src/tree/updater_refresh.cc +++ b/src/tree/updater_refresh.cc @@ -14,6 +14,7 @@ #include "./param.h" #include "../common/io.h" #include "../common/threading_utils.h" +#include "../predictor/predict_fn.h" namespace xgboost { namespace tree { @@ -123,10 +124,13 @@ class TreeRefresher: public TreeUpdater { // start from groups that belongs to current data auto pid = 0; gstats[pid].Add(gpair[ridx]); + auto const& cats = tree.GetCategoriesMatrix(); // traverse tree while (!tree[pid].IsLeaf()) { unsigned split_index = tree[pid].SplitIndex(); - pid = tree.GetNext(pid, feat.GetFvalue(split_index), feat.IsMissing(split_index)); + pid = predictor::GetNextNode( + tree[pid], pid, feat.GetFvalue(split_index), feat.IsMissing(split_index), + cats); gstats[pid].Add(gpair[ridx]); } } diff --git a/tests/cpp/predictor/test_cpu_predictor.cc b/tests/cpp/predictor/test_cpu_predictor.cc index ab1454bae5bd..13808e5c9a54 100644 --- a/tests/cpp/predictor/test_cpu_predictor.cc +++ b/tests/cpp/predictor/test_cpu_predictor.cc @@ -229,9 +229,17 @@ void TestUpdatePredictionCache(bool use_subsampling) { } } +TEST(CPUPredictor, CategoricalPrediction) { + TestCategoricalPrediction("cpu_predictor"); +} + +TEST(CPUPredictor, CategoricalPredictLeaf) { + TestCategoricalPredictLeaf(StringView{"cpu_predictor"}); +} + TEST(CpuPredictor, UpdatePredictionCache) { - TestUpdatePredictionCache(false); - TestUpdatePredictionCache(true); + TestUpdatePredictionCache(false); + TestUpdatePredictionCache(true); } TEST(CpuPredictor, LesserFeatures) { diff --git a/tests/cpp/predictor/test_gpu_predictor.cu b/tests/cpp/predictor/test_gpu_predictor.cu index 79ea0c8cf9b0..d4a5dce63e29 100644 --- a/tests/cpp/predictor/test_gpu_predictor.cu +++ b/tests/cpp/predictor/test_gpu_predictor.cu @@ -228,6 +228,10 @@ TEST(GPUPredictor, CategoricalPrediction) { TestCategoricalPrediction("gpu_predictor"); } +TEST(GPUPredictor, CategoricalPredictLeaf) { + TestCategoricalPredictLeaf(StringView{"gpu_predictor"}); +} + TEST(GPUPredictor, PredictLeafBasic) { size_t constexpr kRows = 5, kCols = 5; auto dmat = RandomDataGenerator(kRows, kCols, 0).Device(0).GenerateDMatrix(); diff --git a/tests/cpp/predictor/test_predictor.cc b/tests/cpp/predictor/test_predictor.cc index 388a59cb8157..0705dac40e3f 100644 --- a/tests/cpp/predictor/test_predictor.cc +++ b/tests/cpp/predictor/test_predictor.cc @@ -180,6 +180,25 @@ void TestPredictionWithLesserFeatures(std::string predictor_name) { #endif // defined(XGBOOST_USE_CUDA) } +void GBTreeModelForTest(gbm::GBTreeModel *model, uint32_t split_ind, + bst_cat_t split_cat, float left_weight, + float right_weight) { + PredictionCacheEntry out_predictions; + + std::vector> trees; + trees.push_back(std::unique_ptr(new RegTree)); + auto& p_tree = trees.front(); + + std::vector split_cats(LBitField32::ComputeStorageSize(split_cat)); + LBitField32 cats_bits(split_cats); + cats_bits.Set(split_cat); + + p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, + left_weight, right_weight, + 3.0f, 2.2f, 7.0f, 9.0f); + model->CommitModel(std::move(trees), 0); +} + void TestCategoricalPrediction(std::string name) { size_t constexpr kCols = 10; PredictionCacheEntry out_predictions; @@ -189,25 +208,13 @@ void TestCategoricalPrediction(std::string name) { param.num_output_group = 1; param.base_score = 0.5; - gbm::GBTreeModel model(¶m); - - std::vector> trees; - trees.push_back(std::unique_ptr(new RegTree)); - auto& p_tree = trees.front(); - uint32_t split_ind = 3; bst_cat_t split_cat = 4; float left_weight = 1.3f; float right_weight = 1.7f; - std::vector split_cats(LBitField32::ComputeStorageSize(split_cat)); - LBitField32 cats_bits(split_cats); - cats_bits.Set(split_cat); - - p_tree->ExpandCategorical(0, split_ind, split_cats, true, 1.5f, - left_weight, right_weight, - 3.0f, 2.2f, 7.0f, 9.0f); - model.CommitModel(std::move(trees), 0); + gbm::GBTreeModel model(¶m); + GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight); GenericParameter runtime; runtime.gpu_id = 0; @@ -232,4 +239,43 @@ void TestCategoricalPrediction(std::string name) { ASSERT_EQ(out_predictions.predictions.HostVector()[0], left_weight + param.base_score); } + +void TestCategoricalPredictLeaf(StringView name) { + size_t constexpr kCols = 10; + PredictionCacheEntry out_predictions; + + LearnerModelParam param; + param.num_feature = kCols; + param.num_output_group = 1; + param.base_score = 0.5; + + uint32_t split_ind = 3; + bst_cat_t split_cat = 4; + float left_weight = 1.3f; + float right_weight = 1.7f; + + gbm::GBTreeModel model(¶m); + GBTreeModelForTest(&model, split_ind, split_cat, left_weight, right_weight); + + GenericParameter runtime; + runtime.gpu_id = 0; + std::unique_ptr predictor{ + Predictor::Create(name.c_str(), &runtime)}; + + std::vector row(kCols); + row[split_ind] = split_cat; + auto m = GetDMatrixFromData(row, 1, kCols); + + predictor->PredictLeaf(m.get(), &out_predictions.predictions, model); + CHECK_EQ(out_predictions.predictions.Size(), 1); + // go to left if it doesn't match the category, otherwise right. + ASSERT_EQ(out_predictions.predictions.HostVector()[0], 2); + + row[split_ind] = split_cat + 1; + m = GetDMatrixFromData(row, 1, kCols); + out_predictions.version = 0; + predictor->InitOutPredictions(m->Info(), &out_predictions.predictions, model); + predictor->PredictLeaf(m.get(), &out_predictions.predictions, model); + ASSERT_EQ(out_predictions.predictions.HostVector()[0], 1); +} } // namespace xgboost diff --git a/tests/cpp/predictor/test_predictor.h b/tests/cpp/predictor/test_predictor.h index 296d532d6f8c..d5eccf6a0668 100644 --- a/tests/cpp/predictor/test_predictor.h +++ b/tests/cpp/predictor/test_predictor.h @@ -66,6 +66,8 @@ void TestInplacePrediction(dmlc::any x, std::string predictor, void TestPredictionWithLesserFeatures(std::string preditor_name); void TestCategoricalPrediction(std::string name); + +void TestCategoricalPredictLeaf(StringView name); } // namespace xgboost #endif // XGBOOST_TEST_PREDICTOR_H_