diff --git a/demo/guide-python/cat_in_the_dat.py b/demo/guide-python/cat_in_the_dat.py index 35840b44fb0b..27741f37abaf 100644 --- a/demo/guide-python/cat_in_the_dat.py +++ b/demo/guide-python/cat_in_the_dat.py @@ -61,7 +61,12 @@ def load_cat_in_the_dat() -> tuple[pd.DataFrame, pd.Series]: return X, y -params = {"tree_method": "gpu_hist", "use_label_encoder": False, "n_estimators": 32} +params = { + "tree_method": "gpu_hist", + "use_label_encoder": False, + "n_estimators": 32, + "colsample_bylevel": 0.7, +} def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None: @@ -70,13 +75,13 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None: X, y, random_state=1994, test_size=0.2 ) # Specify `enable_categorical`. - clf = xgb.XGBClassifier(**params, enable_categorical=True) - clf.fit( - X_train, - y_train, - eval_set=[(X_test, y_test), (X_train, y_train)], + clf = xgb.XGBClassifier( + **params, eval_metric="auc", + enable_categorical=True, + max_cat_to_onehot=1, # We use optimal partitioning exclusively ) + clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)]) clf.save_model(os.path.join(output_dir, "categorical.json")) y_score = clf.predict_proba(X_test)[:, 1] # proba of positive samples diff --git a/demo/guide-python/categorical.py b/demo/guide-python/categorical.py index eed823ae8bb3..7af8b9e213fc 100644 --- a/demo/guide-python/categorical.py +++ b/demo/guide-python/categorical.py @@ -3,15 +3,15 @@ ===================================== Experimental support for categorical data. After 1.5 XGBoost `gpu_hist` tree method has -experimental support for one-hot encoding based tree split, and in 1.6 `approx` supported +experimental support for one-hot encoding based tree split, and in 1.6 `approx` support was added. In before, users need to run an encoder themselves before passing the data into XGBoost, -which creates a sparse matrix and potentially increase memory usage. This demo showcases -the experimental categorical data support, more advanced features are planned. - -Also, see :doc:`the tutorial ` for using XGBoost with categorical data +which creates a sparse matrix and potentially increases memory usage. This demo +showcases the experimental categorical data support; more advanced features are planned. +Also, see :doc:`the tutorial ` for using XGBoost with +categorical data. .. versionadded:: 1.5.0 @@ -55,8 +55,11 @@ def main() -> None: # For scikit-learn interface, the input data must be pandas DataFrame or cudf # DataFrame with categorical features X, y = make_categorical(100, 10, 4, False) - # Specify `enable_categorical` to True. - reg = xgb.XGBRegressor(tree_method="gpu_hist", enable_categorical=True) + # Specify `enable_categorical` to True. Here we use one-hot encoding based split + # for demonstration; for details see the documentation of `max_cat_to_onehot`. + reg = xgb.XGBRegressor( + tree_method="gpu_hist", enable_categorical=True, max_cat_to_onehot=5 + ) reg.fit(X, y, eval_set=[(X, y)]) # Pass in already encoded data diff --git a/doc/parameter.rst b/doc/parameter.rst index 0ec58026a15c..2189cf65d25d 100644 --- a/doc/parameter.rst +++ b/doc/parameter.rst @@ -245,8 +245,8 @@ Additional parameters for ``hist``, ``gpu_hist`` and ``approx`` tree method - Use single precision to build histograms instead of double precision.
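The two demo patches above exercise the same scikit-learn entry points. For reference, here is a minimal standalone sketch of that usage; it is not part of the patch, and the toy data, the choice of ``approx`` (which also gained categorical support in 1.6), and the threshold value are illustrative assumptions only.

.. code-block:: python

    import numpy as np
    import pandas as pd
    import xgboost as xgb

    rng = np.random.default_rng(1994)
    n = 256
    # The scikit-learn interface requires a DataFrame with pandas category dtype
    # (or the cudf equivalent) for `enable_categorical` to take effect.
    X = pd.DataFrame(
        {
            "cat": pd.Categorical(rng.integers(0, 8, size=n)),
            "num": rng.normal(size=n),
        }
    )
    y = rng.normal(size=n)

    reg = xgb.XGBRegressor(
        tree_method="approx",  # `gpu_hist` works as well
        enable_categorical=True,
        # Features with fewer than 16 categories use one-hot splits; the rest
        # use optimal partitioning.
        max_cat_to_onehot=16,
    )
    reg.fit(X, y, eval_set=[(X, y)])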
-Additional parameters for ``approx`` tree method -================================================ +Additional parameters for ``approx`` and ``gpu_hist`` tree method +================================================================= * ``max_cat_to_onehot`` @@ -257,7 +257,8 @@ Additional parameters for ``approx`` tree method - A threshold for deciding whether XGBoost should use one-hot encoding based split for categorical data. When number of categories is lesser than the threshold then one-hot encoding is chosen, otherwise the categories will be partitioned into children nodes. - Only relevant for regression and binary classification with `approx` tree method. + Only relevant for regression and binary classification. Also, `approx` or `gpu_hist` + tree method is required. Additional parameters for Dart Booster (``booster=dart``) ========================================================= diff --git a/doc/tutorials/categorical.rst b/doc/tutorials/categorical.rst index c1d93fb45df4..65081be57030 100644 --- a/doc/tutorials/categorical.rst +++ b/doc/tutorials/categorical.rst @@ -2,6 +2,10 @@ Categorical Data ################ +.. note:: + + As of XGBoost 1.6, the feature is highly experimental and has limited functionality. + Starting from version 1.5, XGBoost has experimental support for categorical data available for public testing. At the moment, the support is implemented as one-hot encoding based categorical tree splits. For numerical data, the split condition is defined as @@ -107,6 +111,28 @@ For numerical data, the feature type can be ``"q"`` or ``"float"``, while for ca feature it's specified as ``"c"``. The Dask module in XGBoost has the same interface so :class:`dask.Array ` can also be used as categorical data. +******************** +Optimal Partitioning +******************** + +.. versionadded:: 1.6 + +Optimal partitioning is a technique for partitioning the categorical predictors for each +node split. The proof of optimality for numerical objectives like ``RMSE`` was first +introduced by `[1] <#references>`__. The algorithm is used in decision trees for handling +regression and binary classification tasks `[2] <#references>`__; later, LightGBM `[3] +<#references>`__ brought it to the context of gradient boosting trees, and it is now also +adopted in XGBoost as an optional feature for handling categorical splits. More +specifically, the proof by Fisher `[1] <#references>`__ states that, when trying to +partition a set of discrete values into groups based on the distances between a measure of +these values, one only needs to look at sorted partitions instead of enumerating all +possible permutations. In the context of decision trees, the discrete values are +categories, and the measure is the output leaf value. Intuitively, we want to group the +categories that output similar leaf values. During split finding, we first sort the +gradient histogram to prepare the contiguous partitions, then enumerate the splits +according to these sorted values. One of the related parameters for XGBoost is +``max_cat_to_onehot``, which controls whether one-hot encoding or partitioning should be +used for each feature; see :doc:`/parameter` for details. ************* Miscellaneous ************* @@ -120,10 +146,20 @@ actual number of unique categories. During training this is validated but for p it's treated as the same as missing value for performance reasons. Lastly, missing values are treated as the same as numerical features (using the learned split direction).
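To make the sorted enumeration described in the new tutorial section concrete, the following NumPy sketch scans one categorical feature's histogram in the way the text describes. It is only an illustration, not code from the patch: the gradient/hessian values and the regularization term are made up, and details such as missing-value direction and minimum child weight are omitted.

.. code-block:: python

    import numpy as np

    # Per-category sums of gradients and hessians from one feature's histogram
    # (illustrative values only).
    grad = np.array([4.0, -3.0, 1.5, -0.5, 2.0])
    hess = np.array([2.0, 2.5, 1.0, 1.5, 2.0])
    lam = 1.0  # L2 regularization

    def gain(g, h):
        return g * g / (h + lam)

    # Sort categories by their would-be leaf value; Fisher's result says an optimal
    # two-way partition splits this sorted order.
    order = np.argsort(-grad / (hess + lam))

    parent = gain(grad.sum(), hess.sum())
    best_chg, best_k = -np.inf, 0
    # Enumerate len(order) - 1 sorted partitions instead of 2**len(order) subsets.
    for k in range(1, len(order)):
        gl, hl = grad[order[:k]].sum(), hess[order[:k]].sum()
        gr, hr = grad.sum() - gl, hess.sum() - hl
        chg = gain(gl, hl) + gain(gr, hr) - parent
        if chg > best_chg:
            best_chg, best_k = chg, k

    print("categories sent left:", order[:best_k], "loss change:", best_chg)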
+ ********** -Next Steps +References ********** -As of XGBoost 1.5, the feature is highly experimental and have limited features like CPU -training is not yet supported. Please see `this issue -`_ for progress. +[1] Walter D. Fisher. "`On Grouping for Maximum Homogeneity`_." Journal of the American Statistical Association. Vol. 53, No. 284 (Dec., 1958), pp. 789-798. + +[2] Trevor Hastie, Robert Tibshirani, Jerome Friedman. "`The Elements of Statistical Learning`_". Springer Series in Statistics Springer New York Inc. (2001). + +[3] Guolin Ke, Qi Meng, Thomas Finley, Taifeng Wang, Wei Chen, Weidong Ma, Qiwei Ye, Tie-Yan Liu. "`LightGBM\: A Highly Efficient Gradient Boosting Decision Tree`_." Advances in Neural Information Processing Systems 30 (NIPS 2017), pp. 3149-3157. + + +.. _On Grouping for Maximum Homogeneity: https://www.tandfonline.com/doi/abs/10.1080/01621459.1958.10501479 + +.. _The Elements of Statistical Learning: https://link.springer.com/book/10.1007/978-0-387-84858-7 + +.. _LightGBM\: A Highly Efficient Gradient Boosting Decision Tree: https://papers.nips.cc/paper/6907-lightgbm-a-highly-efficient-gradient-boosting-decision-tree.pdf diff --git a/include/xgboost/task.h b/include/xgboost/task.h index 69952d62c40d..0f702f63cbe8 100644 --- a/include/xgboost/task.h +++ b/include/xgboost/task.h @@ -1,5 +1,5 @@ /*! - * Copyright 2021 by XGBoost Contributors + * Copyright 2021-2022 by XGBoost Contributors */ #ifndef XGBOOST_TASK_H_ #define XGBOOST_TASK_H_ @@ -34,6 +34,10 @@ struct ObjInfo { explicit ObjInfo(Task t) : task{t} {} ObjInfo(Task t, bool khess) : task{t}, const_hess{khess} {} + + constexpr bool UseOneHot() const { + return (task != ObjInfo::kRegression && task != ObjInfo::kBinary); + } }; } // namespace xgboost #endif // XGBOOST_TASK_H_ diff --git a/python-package/xgboost/core.py b/python-package/xgboost/core.py index 22564db80267..3678f68362db 100644 --- a/python-package/xgboost/core.py +++ b/python-package/xgboost/core.py @@ -581,10 +581,10 @@ def __init__( .. versionadded:: 1.3.0 - Experimental support of specializing for categorical features. Do not set to - True unless you are interested in development. Currently it's only available - for `gpu_hist` tree method with 1 vs rest (one hot) categorical split. Also, - JSON serialization format is required. + Experimental support of specializing for categorical features. Do not set + to True unless you are interested in development. Currently it's only + available for `gpu_hist` and `approx` tree methods. Also, JSON/UBJSON + serialization format is required. (XGBoost 1.6 for approx) """ if group is not None and qid is not None: diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py index 697c769e2235..7d12a657f3a1 100644 --- a/python-package/xgboost/sklearn.py +++ b/python-package/xgboost/sklearn.py @@ -207,7 +207,9 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: .. versionadded:: 1.5.0 Experimental support for categorical data. Do not set to true unless you are - interested in development. Only valid when `gpu_hist` and dataframe are used. + interested in development. Only valid when `gpu_hist` or `approx` is used along + with dataframe as input. Also, JSON/UBJSON serialization format is + required. (XGBoost 1.6 for approx) max_cat_to_onehot : Optional[int] @@ -216,10 +218,11 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]: .. 
note:: This parameter is experimental A threshold for deciding whether XGBoost should use one-hot encoding based split - for categorical data. When number of categories is lesser than the threshold then - one-hot encoding is chosen, otherwise the categories will be partitioned into - children nodes. Only relevant for regression and binary classification and - `approx` tree method. + for categorical data. When number of categories is lesser than the threshold + then one-hot encoding is chosen, otherwise the categories will be partitioned + into children nodes. Only relevant for regression and binary + classification. Also, ``approx`` or ``gpu_hist`` tree method is required. See + :doc:`Categorical Data ` for details. eval_metric : Optional[Union[str, List[str], Callable]] diff --git a/src/common/categorical.h b/src/common/categorical.h index ba6313225025..5eff62264cf2 100644 --- a/src/common/categorical.h +++ b/src/common/categorical.h @@ -16,6 +16,10 @@ namespace xgboost { namespace common { + +using CatBitField = LBitField32; +using KCatBitField = CLBitField32; + // Cast the categorical type. template XGBOOST_DEVICE bst_cat_t AsCat(T const& v) { @@ -57,6 +61,11 @@ inline XGBOOST_DEVICE bool Decision(common::Span cats, float cat if (XGBOOST_EXPECT(validate && (InvalidCat(cat) || cat >= s_cats.Size()), false)) { return dft_left; } + + auto pos = KCatBitField::ToBitPos(cat); + if (pos.int_pos >= cats.size()) { + return true; + } return !s_cats.Check(AsCat(cat)); } @@ -73,18 +82,14 @@ inline void InvalidCategory() { /*! * \brief Whether should we use onehot encoding for categorical data. */ -inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) { - bool use_one_hot = n_cats < max_cat_to_onehot || - (task.task != ObjInfo::kRegression && task.task != ObjInfo::kBinary); +XGBOOST_DEVICE inline bool UseOneHot(uint32_t n_cats, uint32_t max_cat_to_onehot, ObjInfo task) { + bool use_one_hot = n_cats < max_cat_to_onehot || task.UseOneHot(); return use_one_hot; } struct IsCatOp { XGBOOST_DEVICE bool operator()(FeatureType ft) { return ft == FeatureType::kCategorical; } }; - -using CatBitField = LBitField32; -using KCatBitField = CLBitField32; } // namespace common } // namespace xgboost diff --git a/src/common/device_helpers.cuh b/src/common/device_helpers.cuh index c74718554bed..9adf866fece9 100644 --- a/src/common/device_helpers.cuh +++ b/src/common/device_helpers.cuh @@ -952,22 +952,22 @@ thrust::device_ptr tcend(xgboost::HostDeviceVector const& vector) { } template -thrust::device_ptr tbegin(xgboost::common::Span& span) { // NOLINT +XGBOOST_DEVICE thrust::device_ptr tbegin(xgboost::common::Span& span) { // NOLINT return thrust::device_ptr(span.data()); } template -thrust::device_ptr tbegin(xgboost::common::Span const& span) { // NOLINT +XGBOOST_DEVICE thrust::device_ptr tbegin(xgboost::common::Span const& span) { // NOLINT return thrust::device_ptr(span.data()); } template -thrust::device_ptr tend(xgboost::common::Span& span) { // NOLINT +XGBOOST_DEVICE thrust::device_ptr tend(xgboost::common::Span& span) { // NOLINT return tbegin(span) + span.size(); } template -thrust::device_ptr tend(xgboost::common::Span const& span) { // NOLINT +XGBOOST_DEVICE thrust::device_ptr tend(xgboost::common::Span const& span) { // NOLINT return tbegin(span) + span.size(); } @@ -982,12 +982,12 @@ XGBOOST_DEVICE auto trend(xgboost::common::Span &span) { // NOLINT } template -thrust::device_ptr tcbegin(xgboost::common::Span const& span) { // NOLINT +XGBOOST_DEVICE thrust::device_ptr 
tcbegin(xgboost::common::Span const& span) { // NOLINT return thrust::device_ptr(span.data()); } template -thrust::device_ptr tcend(xgboost::common::Span const& span) { // NOLINT +XGBOOST_DEVICE thrust::device_ptr tcend(xgboost::common::Span const& span) { // NOLINT return tcbegin(span) + span.size(); } @@ -1536,4 +1536,69 @@ void SegmentedArgSort(xgboost::common::Span values, safe_cuda(cudaMemcpyAsync(sorted_idx.data(), sorted_idx_out.data().get(), sorted_idx.size_bytes(), cudaMemcpyDeviceToDevice)); } + +class CUDAStreamView; + +class CUDAEvent { + cudaEvent_t event_{nullptr}; + + public: + CUDAEvent() { dh::safe_cuda(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming)); } + ~CUDAEvent() { + if (event_) { + dh::safe_cuda(cudaEventDestroy(event_)); + } + } + + CUDAEvent(CUDAEvent const &that) = delete; + CUDAEvent &operator=(CUDAEvent const &that) = delete; + + inline void Record(CUDAStreamView stream); // NOLINT + + operator cudaEvent_t() const { return event_; } // NOLINT +}; + +class CUDAStreamView { + cudaStream_t stream_{nullptr}; + + public: + explicit CUDAStreamView(cudaStream_t s) : stream_{s} {} + void Wait(CUDAEvent const &e) { +#if defined(__CUDACC_VER_MAJOR__) +#if __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0 + // CUDA == 11.0 + dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, 0)); +#else + // CUDA > 11.0 + dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault)); +#endif // __CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ == 0: +#else // clang + dh::safe_cuda(cudaStreamWaitEvent(stream_, cudaEvent_t{e}, cudaEventWaitDefault)); +#endif // defined(__CUDACC_VER_MAJOR__) + } + operator cudaStream_t() const { // NOLINT + return stream_; + } + void Sync() { dh::safe_cuda(cudaStreamSynchronize(stream_)); } +}; + +inline void CUDAEvent::Record(CUDAStreamView stream) { // NOLINT + dh::safe_cuda(cudaEventRecord(event_, cudaStream_t{stream})); +} + +inline CUDAStreamView DefaultStream() { return CUDAStreamView{cudaStreamLegacy}; } + +class CUDAStream { + cudaStream_t stream_; + + public: + CUDAStream() { + dh::safe_cuda(cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking)); + } + ~CUDAStream() { + dh::safe_cuda(cudaStreamDestroy(stream_)); + } + + CUDAStreamView View() const { return CUDAStreamView{stream_}; } +}; } // namespace dh diff --git a/src/common/hist_util.h b/src/common/hist_util.h index 8cb233605e3c..d138d102dfd7 100644 --- a/src/common/hist_util.h +++ b/src/common/hist_util.h @@ -33,66 +33,84 @@ namespace common { */ using GHistIndexRow = Span; -// A CSC matrix representing histogram cuts, used in CPU quantile hist. +// A CSC matrix representing histogram cuts. // The cut values represent upper bounds of bins containing approximately equal numbers of elements class HistogramCuts { + bool has_categorical_{false}; + float max_cat_{-1.0f}; + protected: using BinIdx = uint32_t; - public: - HostDeviceVector cut_values_; // NOLINT - HostDeviceVector cut_ptrs_; // NOLINT - // storing minimum value in a sketch set. 
- HostDeviceVector min_vals_; // NOLINT + void Swap(HistogramCuts&& that) noexcept(true) { + std::swap(cut_values_, that.cut_values_); + std::swap(cut_ptrs_, that.cut_ptrs_); + std::swap(min_vals_, that.min_vals_); - HistogramCuts(); - HistogramCuts(HistogramCuts const& that) { + std::swap(has_categorical_, that.has_categorical_); + std::swap(max_cat_, that.max_cat_); + } + + void Copy(HistogramCuts const& that) { cut_values_.Resize(that.cut_values_.Size()); cut_ptrs_.Resize(that.cut_ptrs_.Size()); min_vals_.Resize(that.min_vals_.Size()); cut_values_.Copy(that.cut_values_); cut_ptrs_.Copy(that.cut_ptrs_); min_vals_.Copy(that.min_vals_); + has_categorical_ = that.has_categorical_; + max_cat_ = that.max_cat_; } + public: + HostDeviceVector cut_values_; // NOLINT + HostDeviceVector cut_ptrs_; // NOLINT + // storing minimum value in a sketch set. + HostDeviceVector min_vals_; // NOLINT + + HistogramCuts(); + HistogramCuts(HistogramCuts const& that) { this->Copy(that); } + HistogramCuts(HistogramCuts&& that) noexcept(true) { - *this = std::forward(that); + this->Swap(std::forward(that)); } HistogramCuts& operator=(HistogramCuts const& that) { - cut_values_.Resize(that.cut_values_.Size()); - cut_ptrs_.Resize(that.cut_ptrs_.Size()); - min_vals_.Resize(that.min_vals_.Size()); - cut_values_.Copy(that.cut_values_); - cut_ptrs_.Copy(that.cut_ptrs_); - min_vals_.Copy(that.min_vals_); + this->Copy(that); return *this; } HistogramCuts& operator=(HistogramCuts&& that) noexcept(true) { - cut_ptrs_ = std::move(that.cut_ptrs_); - cut_values_ = std::move(that.cut_values_); - min_vals_ = std::move(that.min_vals_); + this->Swap(std::forward(that)); return *this; } - uint32_t FeatureBins(uint32_t feature) const { - return cut_ptrs_.ConstHostVector().at(feature + 1) - - cut_ptrs_.ConstHostVector()[feature]; + uint32_t FeatureBins(bst_feature_t feature) const { + return cut_ptrs_.ConstHostVector().at(feature + 1) - cut_ptrs_.ConstHostVector()[feature]; } - // Getters. Cuts should be of no use after building histogram indices, but currently - // they are deeply linked with quantile_hist, gpu sketcher and gpu_hist, so we preserve - // these for now. std::vector const& Ptrs() const { return cut_ptrs_.ConstHostVector(); } std::vector const& Values() const { return cut_values_.ConstHostVector(); } std::vector const& MinValues() const { return min_vals_.ConstHostVector(); } + bool HasCategorical() const { return has_categorical_; } + float MaxCategory() const { return max_cat_; } + /** + * \brief Set meta info about categorical features. + * + * \param has_cat Do we have categorical feature in the data? + * \param max_cat The maximum categorical value in all features. 
+ */ + void SetCategorical(bool has_cat, float max_cat) { + has_categorical_ = has_cat; + max_cat_ = max_cat; + } + size_t TotalBins() const { return cut_ptrs_.ConstHostVector().back(); } // Return the index of a cut point that is strictly greater than the input // value, or the last available index if none exists - BinIdx SearchBin(float value, uint32_t column_id, std::vector const& ptrs, + BinIdx SearchBin(float value, bst_feature_t column_id, std::vector const& ptrs, std::vector const& values) const { auto end = ptrs[column_id + 1]; auto beg = ptrs[column_id]; @@ -102,7 +120,7 @@ class HistogramCuts { return idx; } - BinIdx SearchBin(float value, uint32_t column_id) const { + BinIdx SearchBin(float value, bst_feature_t column_id) const { return this->SearchBin(value, column_id, Ptrs(), Values()); } diff --git a/src/common/quantile.cc b/src/common/quantile.cc index 3d6fb1fc03fe..44e4178ef55f 100644 --- a/src/common/quantile.cc +++ b/src/common/quantile.cc @@ -272,7 +272,7 @@ void AllreduceCategories(Span feature_types, int32_t n_thread // move all categories into a flatten vector to prepare for allreduce size_t total = feature_ptr.back(); - std::vector flatten(total, 0); + std::vector flatten(total, 0); auto cursor{flatten.begin()}; for (auto const &feat : categories) { cursor = std::copy(feat.cbegin(), feat.cend(), cursor); @@ -287,15 +287,15 @@ void AllreduceCategories(Span feature_types, int32_t n_thread auto gtotal = global_worker_ptr.back(); // categories in all workers with all features. - std::vector global_categories(gtotal, 0); + std::vector global_categories(gtotal, 0); auto rank_begin = global_worker_ptr[rank]; auto rank_size = global_worker_ptr[rank + 1] - rank_begin; CHECK_EQ(rank_size, total); std::copy(flatten.cbegin(), flatten.cend(), global_categories.begin() + rank_begin); // gather values from all workers. rabit::Allreduce(global_categories.data(), global_categories.size()); - QuantileAllreduce allreduce_result{global_categories, global_worker_ptr, - global_feat_ptrs, categories.size()}; + QuantileAllreduce allreduce_result{global_categories, global_worker_ptr, global_feat_ptrs, + categories.size()}; ParallelFor(categories.size(), n_threads, [&](auto fidx) { if (!IsCat(feature_types, fidx)) { return; @@ -531,6 +531,22 @@ void SketchContainerImpl::MakeCuts(HistogramCuts* cuts) { InvalidCategory(); } } + auto const &ptrs = cuts->Ptrs(); + auto const &vals = cuts->Values(); + + float max_cat{-std::numeric_limits::infinity()}; + for (size_t i = 1; i < ptrs.size(); ++i) { + if (IsCat(feature_types_, i - 1)) { + auto beg = ptrs[i - 1]; + auto end = ptrs[i]; + auto feat = Span{vals}.subspan(beg, end - beg); + auto max_elem = *std::max_element(feat.cbegin(), feat.cend()); + if (max_elem > max_cat) { + max_cat = max_elem; + } + } + } + cuts->SetCategorical(true, max_cat); } monitor_.Stop(__func__); diff --git a/src/common/quantile.cu b/src/common/quantile.cu index d15d310c0516..1be6ea23bd30 100644 --- a/src/common/quantile.cu +++ b/src/common/quantile.cu @@ -1,22 +1,23 @@ /*! 
* Copyright 2020 by XGBoost Contributors */ -#include -#include #include -#include #include +#include +#include +#include +#include // std::numeric_limits #include #include -#include "xgboost/span.h" -#include "quantile.h" -#include "quantile.cuh" -#include "hist_util.h" -#include "device_helpers.cuh" #include "categorical.h" #include "common.h" +#include "device_helpers.cuh" +#include "hist_util.h" +#include "quantile.cuh" +#include "quantile.h" +#include "xgboost/span.h" namespace xgboost { namespace common { @@ -586,7 +587,7 @@ struct InvalidCatOp { Span ptrs; Span ft; - XGBOOST_DEVICE bool operator()(size_t i) { + XGBOOST_DEVICE bool operator()(size_t i) const { auto fidx = dh::SegmentId(ptrs, i); return IsCat(ft, fidx) && InvalidCat(values[i]); } @@ -683,18 +684,36 @@ void SketchContainer::MakeCuts(HistogramCuts* p_cuts) { out_column[idx] = in_column[idx+1].value; }); + float max_cat{-1.0f}; if (has_categorical_) { - dh::XGBCachingDeviceAllocator alloc; - auto ptrs = p_cuts->cut_ptrs_.ConstDeviceSpan(); - auto it = thrust::make_counting_iterator(0ul); + auto invalid_op = InvalidCatOp{out_cut_values, d_out_columns_ptr, d_ft}; + auto it = dh::MakeTransformIterator>( + thrust::make_counting_iterator(0ul), [=] XGBOOST_DEVICE(size_t i) { + auto fidx = dh::SegmentId(d_out_columns_ptr, i); + if (IsCat(d_ft, fidx)) { + auto invalid = invalid_op(i); + auto v = out_cut_values[i]; + return thrust::make_pair(invalid, v); + } + return thrust::make_pair(false, std::numeric_limits::min()); + }); - CHECK_EQ(p_cuts->Ptrs().back(), out_cut_values.size()); - auto invalid = thrust::any_of(thrust::cuda::par(alloc), it, it + out_cut_values.size(), - InvalidCatOp{out_cut_values, ptrs, d_ft}); + bool invalid{false}; + dh::XGBCachingDeviceAllocator alloc; + thrust::tie(invalid, max_cat) = + thrust::reduce(thrust::cuda::par(alloc), it, it + out_cut_values.size(), + thrust::make_pair(false, std::numeric_limits::min()), + [=] XGBOOST_DEVICE(thrust::pair const &l, + thrust::pair const &r) { + return thrust::make_pair(l.first || r.first, std::max(l.second, r.second)); + }); if (invalid) { InvalidCategory(); } } + + p_cuts->SetCategorical(this->has_categorical_, max_cat); + timer_.Stop(__func__); } } // namespace common diff --git a/src/tree/gpu_hist/evaluate_splits.cu b/src/tree/gpu_hist/evaluate_splits.cu index 90ea5a66db13..ce8b13d0def2 100644 --- a/src/tree/gpu_hist/evaluate_splits.cu +++ b/src/tree/gpu_hist/evaluate_splits.cu @@ -1,9 +1,14 @@ /*! 
- * Copyright 2020-2021 by XGBoost Contributors + * Copyright 2020-2022 by XGBoost Contributors */ +#include // std::max #include -#include "evaluate_splits.cuh" + #include "../../common/categorical.h" +#include "../../common/device_helpers.cuh" +#include "../../data/ellpack_page.cuh" +#include "evaluate_splits.cuh" +#include "expand_entry.cuh" namespace xgboost { namespace tree { @@ -23,7 +28,7 @@ XGBOOST_DEVICE float LossChangeMissing(const GradientPairPrecise &scan, float missing_right_gain = evaluator.CalcSplitGain( param, nidx, fidx, GradStats(scan), GradStats(parent_sum - scan)); - if (missing_left_gain >= missing_right_gain) { + if (missing_left_gain > missing_right_gain) { missing_left_out = true; return missing_left_gain - parent_gain; } else { @@ -69,108 +74,61 @@ ReduceFeature(common::Span feature_histogram, return shared_sum; } -template struct OneHotBin { - GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin, - SumCallbackOp *, - GradientPairPrecise const &missing, - EvaluateSplitInputs const &inputs, - TempStorageT *) { - GradientSumT bin = thread_active - ? inputs.gradient_histogram[scan_begin + threadIdx.x] - : GradientSumT(); - auto rest = inputs.parent_sum - GradientPairPrecise(bin) - missing; - return GradientSumT{rest}; - } -}; - -template -struct UpdateOneHot { - void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain, - bst_feature_t fidx, GradientPairPrecise const &missing, - GradientSumT const &bin, - EvaluateSplitInputs const &inputs, - DeviceSplitCandidate *best_split) { - int split_gidx = (scan_begin + threadIdx.x); - float fvalue = inputs.feature_values[split_gidx]; - GradientPairPrecise left = - missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin}; - GradientPairPrecise right = inputs.parent_sum - left; - best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, true, - inputs.param); - } -}; - -template -struct NumericBin { - GradientSumT __device__ operator()(bool thread_active, uint32_t scan_begin, - SumCallbackOp *prefix_callback, - GradientPairPrecise const &missing, - EvaluateSplitInputs inputs, - TempStorageT *temp_storage) { - GradientSumT bin = thread_active - ? inputs.gradient_histogram[scan_begin + threadIdx.x] - : GradientSumT(); - ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), *prefix_callback); - return bin; - } -}; - -template -struct UpdateNumeric { - void __device__ operator()(bool missing_left, uint32_t scan_begin, float gain, - bst_feature_t fidx, GradientPairPrecise const &missing, - GradientSumT const &bin, - EvaluateSplitInputs const &inputs, - DeviceSplitCandidate *best_split) { - // Use pointer from cut to indicate begin and end of bins for each feature. - uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin - int split_gidx = (scan_begin + threadIdx.x) - 1; - float fvalue; - if (split_gidx < static_cast(gidx_begin)) { - fvalue = inputs.min_fvalue[fidx]; - } else { - fvalue = inputs.feature_values[split_gidx]; - } - GradientPairPrecise left = - missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin}; - GradientPairPrecise right = inputs.parent_sum - left; - best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, false, - inputs.param); - } -}; - /*! \brief Find the thread with best gain. 
*/ -template +template __device__ void EvaluateFeature( int fidx, EvaluateSplitInputs inputs, TreeEvaluator::SplitEvaluator evaluator, - DeviceSplitCandidate* best_split, // shared memory storing best split - TempStorageT* temp_storage // temp memory for cub operations + common::Span sorted_idx, size_t offset, + DeviceSplitCandidate *best_split, // shared memory storing best split + TempStorageT *temp_storage // temp memory for cub operations ) { // Use pointer from cut to indicate begin and end of bins for each feature. uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin uint32_t gidx_end = inputs.feature_segments[fidx + 1]; // end bin for i^th feature auto feature_hist = inputs.gradient_histogram.subspan(gidx_begin, gidx_end - gidx_begin); - auto bin_fn = BinFn(); - auto update_fn = UpdateFn(); // Sum histogram bins for current feature GradientSumT const feature_sum = - ReduceFeature( - feature_hist, temp_storage); + ReduceFeature(feature_hist, temp_storage); GradientPairPrecise const missing = inputs.parent_sum - GradientPairPrecise{feature_sum}; float const null_gain = -std::numeric_limits::infinity(); SumCallbackOp prefix_op = SumCallbackOp(); - for (int scan_begin = gidx_begin; scan_begin < gidx_end; - scan_begin += BLOCK_THREADS) { + for (int scan_begin = gidx_begin; scan_begin < gidx_end; scan_begin += BLOCK_THREADS) { bool thread_active = (scan_begin + threadIdx.x) < gidx_end; - auto bin = bin_fn(thread_active, scan_begin, &prefix_op, missing, inputs, temp_storage); + auto calc_bin_value = [&]() { + GradientSumT bin; + switch (type) { + case kOneHot: { + auto rest = + thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT(); + bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing}; // NOLINT + break; + } + case kNum: { + bin = + thread_active ? inputs.gradient_histogram[scan_begin + threadIdx.x] : GradientSumT(); + ScanT(temp_storage->scan).ExclusiveScan(bin, bin, cub::Sum(), prefix_op); + break; + } + case kPart: { + auto rest = thread_active + ? inputs.gradient_histogram[sorted_idx[scan_begin + threadIdx.x] - offset] + : GradientSumT(); + // No min value for cat feature, use inclusive scan. + ScanT(temp_storage->scan).InclusiveScan(rest, rest, cub::Sum(), prefix_op); + bin = GradientSumT{inputs.parent_sum - GradientPairPrecise{rest} - missing}; // NOLINT + break; + } + } + return bin; + }; + auto bin = calc_bin_value(); // Whether the gradient of missing values is put to the left side. bool missing_left = true; float gain = null_gain; @@ -193,10 +151,48 @@ __device__ void EvaluateFeature( cub::CTA_SYNC(); - // Best thread updates split + // Best thread updates the split if (threadIdx.x == block_max.key) { - update_fn(missing_left, scan_begin, gain, fidx, missing, bin, inputs, - best_split); + switch (type) { + case kNum: { + // Use pointer from cut to indicate begin and end of bins for each feature. + uint32_t gidx_begin = inputs.feature_segments[fidx]; // beginning bin + int split_gidx = (scan_begin + threadIdx.x) - 1; + float fvalue; + if (split_gidx < static_cast(gidx_begin)) { + fvalue = inputs.min_fvalue[fidx]; + } else { + fvalue = inputs.feature_values[split_gidx]; + } + GradientPairPrecise left = + missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin}; + GradientPairPrecise right = inputs.parent_sum - left; + best_split->Update(gain, missing_left ? 
kLeftDir : kRightDir, fvalue, fidx, left, right, + false, inputs.param); + break; + } + case kOneHot: { + int32_t split_gidx = (scan_begin + threadIdx.x); + float fvalue = inputs.feature_values[split_gidx]; + GradientPairPrecise left = + missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin}; + GradientPairPrecise right = inputs.parent_sum - left; + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, fvalue, fidx, left, right, + true, inputs.param); + break; + } + case kPart: { + int32_t split_gidx = (scan_begin + threadIdx.x); + float fvalue = inputs.feature_values[split_gidx]; + GradientPairPrecise left = + missing_left ? GradientPairPrecise{bin} + missing : GradientPairPrecise{bin}; + GradientPairPrecise right = inputs.parent_sum - left; + auto best_thresh = block_max.key; // index of best threshold inside a feature. + best_split->Update(gain, missing_left ? kLeftDir : kRightDir, best_thresh, fidx, left, + right, true, inputs.param); + break; + } + } } cub::CTA_SYNC(); } @@ -206,6 +202,8 @@ template __global__ void EvaluateSplitsKernel( EvaluateSplitInputs left, EvaluateSplitInputs right, + ObjInfo task, + common::Span sorted_idx, TreeEvaluator::SplitEvaluator evaluator, common::Span out_candidates) { // KeyValuePair here used as threadIdx.x -> gain_value @@ -240,22 +238,26 @@ __global__ void EvaluateSplitsKernel( // One block for each feature. Features are sampled, so fidx != blockIdx.x int fidx = inputs.feature_set[is_left ? blockIdx.x : blockIdx.x - left.feature_set.size()]; + if (common::IsCat(inputs.feature_types, fidx)) { - EvaluateFeature, - UpdateOneHot>(fidx, inputs, evaluator, &best_split, - &temp_storage); + auto n_bins_in_feat = inputs.feature_segments[fidx + 1] - inputs.feature_segments[fidx]; + if (common::UseOneHot(n_bins_in_feat, inputs.param.max_cat_to_onehot, task)) { + EvaluateFeature(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage); + } else { + auto node_sorted_idx = is_left ? sorted_idx.first(inputs.feature_values.size()) + : sorted_idx.last(inputs.feature_values.size()); + size_t offset = is_left ? 0 : inputs.feature_values.size(); + EvaluateFeature(fidx, inputs, evaluator, node_sorted_idx, offset, &best_split, + &temp_storage); + } } else { - EvaluateFeature, - UpdateNumeric>(fidx, inputs, evaluator, &best_split, - &temp_storage); + EvaluateFeature(fidx, inputs, evaluator, sorted_idx, 0, &best_split, &temp_storage); } cub::CTA_SYNC(); - if (threadIdx.x == 0) { // Record best loss for each feature out_candidates[blockIdx.x] = best_split; @@ -267,71 +269,175 @@ __device__ DeviceSplitCandidate operator+(const DeviceSplitCandidate& a, return b.loss_chg > a.loss_chg ? b : a; } +/** + * \brief Set the bits for categorical splits based on the split threshold. + */ +template +__device__ void SortBasedSplit(EvaluateSplitInputs const &input, + common::Span d_sorted_idx, bst_feature_t fidx, + bool is_left, common::Span out, + DeviceSplitCandidate *p_out_split) { + auto &out_split = *p_out_split; + out_split.split_cats = common::CatBitField{out}; + auto node_sorted_idx = + is_left ? d_sorted_idx.subspan(0, input.feature_values.size()) + : d_sorted_idx.subspan(input.feature_values.size(), input.feature_values.size()); + size_t node_offset = is_left ? 
0 : input.feature_values.size(); + auto best_thresh = out_split.PopBestThresh(); + auto f_sorted_idx = + node_sorted_idx.subspan(input.feature_segments[fidx], input.FeatureBins(fidx)); + if (out_split.dir != kLeftDir) { + // forward, missing on right + auto beg = dh::tcbegin(f_sorted_idx); + // Don't put all the categories into one side + auto boundary = std::min(static_cast((best_thresh + 1)), (f_sorted_idx.size() - 1)); + boundary = std::max(boundary, static_cast(1ul)); + auto end = beg + boundary; + thrust::for_each(thrust::seq, beg, end, [&](auto c) { + auto cat = input.feature_values[c - node_offset]; + assert(!out_split.split_cats.Check(cat) && "already set"); + out_split.SetCat(cat); + }); + } else { + assert((f_sorted_idx.size() - best_thresh + 1) != 0 && " == 0"); + thrust::for_each(thrust::seq, dh::tcrbegin(f_sorted_idx), + dh::tcrbegin(f_sorted_idx) + (f_sorted_idx.size() - best_thresh), [&](auto c) { + auto cat = input.feature_values[c - node_offset]; + out_split.SetCat(cat); + }); + } +} + template -void EvaluateSplits(common::Span out_splits, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs left, - EvaluateSplitInputs right) { - size_t combined_num_features = - left.feature_set.size() + right.feature_set.size(); - dh::TemporaryArray feature_best_splits( - combined_num_features); +void GPUHistEvaluator::EvaluateSplits( + EvaluateSplitInputs left, EvaluateSplitInputs right, ObjInfo task, + TreeEvaluator::SplitEvaluator evaluator, + common::Span out_splits) { + if (!split_cats_.empty()) { + this->SortHistogram(left, right, evaluator); + } + + size_t combined_num_features = left.feature_set.size() + right.feature_set.size(); + dh::TemporaryArray feature_best_splits(combined_num_features); + // One block for each feature uint32_t constexpr kBlockThreads = 256; - dh::LaunchKernel {uint32_t(combined_num_features), kBlockThreads, 0}( - EvaluateSplitsKernel, left, right, evaluator, - dh::ToSpan(feature_best_splits)); + dh::LaunchKernel {static_cast(combined_num_features), kBlockThreads, 0}( + EvaluateSplitsKernel, left, right, task, this->SortedIdx(left), + evaluator, dh::ToSpan(feature_best_splits)); // Reduce to get best candidate for left and right child over all features - auto reduce_offset = - dh::MakeTransformIterator(thrust::make_counting_iterator(0llu), - [=] __device__(size_t idx) -> size_t { - if (idx == 0) { - return 0; - } - if (idx == 1) { - return left.feature_set.size(); - } - if (idx == 2) { - return combined_num_features; - } - return 0; - }); + auto reduce_offset = dh::MakeTransformIterator(thrust::make_counting_iterator(0llu), + [=] __device__(size_t idx) -> size_t { + if (idx == 0) { + return 0; + } + if (idx == 1) { + return left.feature_set.size(); + } + if (idx == 2) { + return combined_num_features; + } + return 0; + }); size_t temp_storage_bytes = 0; auto num_segments = out_splits.size(); - cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes, - feature_best_splits.data(), out_splits.data(), - num_segments, reduce_offset, reduce_offset + 1); + cub::DeviceSegmentedReduce::Sum(nullptr, temp_storage_bytes, feature_best_splits.data(), + out_splits.data(), num_segments, reduce_offset, + reduce_offset + 1); dh::TemporaryArray temp(temp_storage_bytes); - cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes, - feature_best_splits.data(), out_splits.data(), - num_segments, reduce_offset, reduce_offset + 1); + cub::DeviceSegmentedReduce::Sum(temp.data().get(), temp_storage_bytes, feature_best_splits.data(), + out_splits.data(), 
num_segments, reduce_offset, + reduce_offset + 1); } template -void EvaluateSingleSplit(common::Span out_split, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs input) { - EvaluateSplits(out_split, evaluator, input, {}); +void GPUHistEvaluator::CopyToHost(EvaluateSplitInputs const &input, + common::Span cats_out) { + if (has_sort_) { + dh::CUDAEvent event; + event.Record(dh::DefaultStream()); + auto h_cats = this->HostCatStorage(input.nidx); + copy_stream_.View().Wait(event); + dh::safe_cuda(cudaMemcpyAsync(h_cats.data(), cats_out.data(), cats_out.size_bytes(), + cudaMemcpyDeviceToHost, copy_stream_.View())); + } } -template void EvaluateSplits( - common::Span out_splits, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs left, - EvaluateSplitInputs right); -template void EvaluateSplits( - common::Span out_splits, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs left, - EvaluateSplitInputs right); -template void EvaluateSingleSplit( - common::Span out_split, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs input); -template void EvaluateSingleSplit( - common::Span out_split, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs input); +template +void GPUHistEvaluator::EvaluateSplits(GPUExpandEntry candidate, ObjInfo task, + EvaluateSplitInputs left, + EvaluateSplitInputs right, + common::Span out_entries) { + auto evaluator = this->tree_evaluator_.template GetEvaluator(); + + dh::TemporaryArray splits_out_storage(2); + auto out_splits = dh::ToSpan(splits_out_storage); + this->EvaluateSplits(left, right, task, evaluator, out_splits); + + auto d_sorted_idx = this->SortedIdx(left); + auto d_entries = out_entries; + auto cats_out = this->DeviceCatStorage(left.nidx); + // turn candidate into entry, along with hanlding sort based split. + dh::LaunchN(right.feature_set.empty() ? 1 : 2, [=] __device__(size_t i) { + auto const &input = i == 0 ? left : right; + auto &split = out_splits[i]; + auto fidx = out_splits[i].findex; + + if (split.is_cat && + !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + bool is_left = i == 0; + auto out = is_left ? 
cats_out.first(cats_out.size() / 2) : cats_out.last(cats_out.size() / 2); + SortBasedSplit(input, d_sorted_idx, fidx, is_left, out, &out_splits[i]); + } + + float base_weight = + evaluator.CalcWeight(input.nidx, input.param, GradStats{split.left_sum + split.right_sum}); + float left_weight = evaluator.CalcWeight(input.nidx, input.param, GradStats{split.left_sum}); + float right_weight = evaluator.CalcWeight(input.nidx, input.param, GradStats{split.right_sum}); + + d_entries[i] = GPUExpandEntry{input.nidx, candidate.depth + 1, out_splits[i], + base_weight, left_weight, right_weight}; + }); + + this->CopyToHost(left, cats_out); +} + +template +GPUExpandEntry GPUHistEvaluator::EvaluateSingleSplit( + EvaluateSplitInputs input, float weight, ObjInfo task) { + dh::TemporaryArray splits_out(1); + auto out_split = dh::ToSpan(splits_out); + auto evaluator = tree_evaluator_.GetEvaluator(); + this->EvaluateSplits(input, {}, task, evaluator, out_split); + + auto cats_out = this->DeviceCatStorage(input.nidx); + auto d_sorted_idx = this->SortedIdx(input); + + dh::TemporaryArray entries(1); + auto d_entries = entries.data().get(); + dh::LaunchN(1, [=] __device__(size_t i) { + auto &split = out_split[i]; + auto fidx = out_split[i].findex; + + if (split.is_cat && + !common::UseOneHot(input.FeatureBins(fidx), input.param.max_cat_to_onehot, task)) { + SortBasedSplit(input, d_sorted_idx, fidx, true, cats_out, &out_split[i]); + } + + float left_weight = evaluator.CalcWeight(0, input.param, GradStats{split.left_sum}); + float right_weight = evaluator.CalcWeight(0, input.param, GradStats{split.right_sum}); + d_entries[0] = GPUExpandEntry(0, 0, split, weight, left_weight, right_weight); + }); + this->CopyToHost(input, cats_out); + + GPUExpandEntry root_entry; + dh::safe_cuda(cudaMemcpyAsync(&root_entry, entries.data().get(), + sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); + return root_entry; +} + +template class GPUHistEvaluator; +template class GPUHistEvaluator; } // namespace tree } // namespace xgboost diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh index fd4abe7865a6..c22834e83860 100644 --- a/src/tree/gpu_hist/evaluate_splits.cuh +++ b/src/tree/gpu_hist/evaluate_splits.cuh @@ -3,15 +3,20 @@ */ #ifndef EVALUATE_SPLITS_CUH_ #define EVALUATE_SPLITS_CUH_ +#include #include -#include "../../data/ellpack_page.cuh" + +#include "../../common/categorical.h" #include "../split_evaluator.h" -#include "../constraints.cuh" #include "../updater_gpu_common.cuh" +#include "expand_entry.cuh" namespace xgboost { -namespace tree { +namespace common { +class HistogramCuts; +} +namespace tree { template struct EvaluateSplitInputs { int nidx; @@ -23,16 +28,131 @@ struct EvaluateSplitInputs { common::Span feature_values; common::Span min_fvalue; common::Span gradient_histogram; + + XGBOOST_DEVICE auto Features() const { return feature_segments.size() - 1; } + __device__ auto FeatureBins(bst_feature_t fidx) const { + return feature_segments[fidx + 1] - feature_segments[fidx]; + } }; + template -void EvaluateSplits(common::Span out_splits, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs left, - EvaluateSplitInputs right); -template -void EvaluateSingleSplit(common::Span out_split, - TreeEvaluator::SplitEvaluator evaluator, - EvaluateSplitInputs input); +class GPUHistEvaluator { + using CatST = common::CatBitField::value_type; // categorical storage type + // use pinned memory to stage the categories, used for sort based splits. 
+ using Alloc = thrust::system::cuda::experimental::pinned_allocator; + + private: + TreeEvaluator tree_evaluator_; + // storage for categories for each node, used for sort based splits. + dh::device_vector split_cats_; + // host storage for categories for each node, used for sort based splits. + std::vector h_split_cats_; + // stream for copying categories from device back to host for expanding the decision tree. + dh::CUDAStream copy_stream_; + // storage for sorted index of feature histogram, used for sort based splits. + dh::device_vector cat_sorted_idx_; + TrainParam param_; + // whether the input data requires sort based split, which is more complicated so we try + // to avoid it if possible. + bool has_sort_{false}; + + // Copy the categories from device to host asynchronously. + void CopyToHost(EvaluateSplitInputs const &input, common::Span cats_out); + + /** + * \brief Get host category storage of nidx for internal calculation. + */ + auto HostCatStorage(bst_node_t nidx) { + auto cat_bits = h_split_cats_.size() / param_.MaxNodes(); + if (nidx == RegTree::kRoot) { + auto cats_out = common::Span{h_split_cats_}.subspan(nidx * cat_bits, cat_bits); + return cats_out; + } + auto cats_out = common::Span{h_split_cats_}.subspan(nidx * cat_bits, cat_bits * 2); + return cats_out; + } + + /** + * \brief Get device category storage of nidx for internal calculation. + */ + auto DeviceCatStorage(bst_node_t nidx) { + auto cat_bits = split_cats_.size() / param_.MaxNodes(); + if (nidx == RegTree::kRoot) { + auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits); + return cats_out; + } + auto cats_out = dh::ToSpan(split_cats_).subspan(nidx * cat_bits, cat_bits * 2); + return cats_out; + } + + /** + * \brief Get sorted index storage based on the left node of inputs . + */ + auto SortedIdx(EvaluateSplitInputs left) { + if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) { + return dh::ToSpan(cat_sorted_idx_).first(left.feature_values.size()); + } + return dh::ToSpan(cat_sorted_idx_); + } + + public: + GPUHistEvaluator(TrainParam const ¶m, bst_feature_t n_features, int32_t device) + : tree_evaluator_{param, n_features, device}, param_{param} {} + /** + * \brief Reset the evaluator, should be called before any use. + */ + void Reset(common::HistogramCuts const &cuts, common::Span ft, ObjInfo task, + bst_feature_t n_features, TrainParam const ¶m, int32_t device); + + /** + * \brief Get host category storage for nidx. Different from the internal version, this + * returns strictly 1 node. + */ + common::Span GetHostNodeCats(bst_node_t nidx) const { + copy_stream_.View().Sync(); + auto cat_bits = h_split_cats_.size() / param_.MaxNodes(); + auto cats_out = common::Span{h_split_cats_}.subspan(nidx * cat_bits, cat_bits); + return cats_out; + } + /** + * \brief Add a split to the internal tree evaluator. + */ + void ApplyTreeSplit(GPUExpandEntry const &candidate, RegTree *p_tree) { + auto &tree = *p_tree; + // Set up child constraints + auto left_child = tree[candidate.nid].LeftChild(); + auto right_child = tree[candidate.nid].RightChild(); + tree_evaluator_.AddSplit(candidate.nid, left_child, right_child, + tree[candidate.nid].SplitIndex(), candidate.left_weight, + candidate.right_weight); + } + + auto GetEvaluator() { return tree_evaluator_.GetEvaluator(); } + /** + * \brief Sort the histogram based on output to obtain contiguous partitions. 
+ */ + common::Span SortHistogram( + EvaluateSplitInputs const &left, EvaluateSplitInputs const &right, + TreeEvaluator::SplitEvaluator evaluator); + + // impl of evaluate splits, contains CUDA kernels so it's public + void EvaluateSplits(EvaluateSplitInputs left, + EvaluateSplitInputs right, ObjInfo task, + TreeEvaluator::SplitEvaluator evaluator, + common::Span out_splits); + /** + * \brief Evaluate splits for left and right nodes. + */ + void EvaluateSplits(GPUExpandEntry candidate, ObjInfo task, + EvaluateSplitInputs left, + EvaluateSplitInputs right, + common::Span out_splits); + /** + * \brief Evaluate splits for root node. + */ + GPUExpandEntry EvaluateSingleSplit(EvaluateSplitInputs input, float weight, + ObjInfo task); +}; } // namespace tree } // namespace xgboost diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu new file mode 100644 index 000000000000..7566b2847505 --- /dev/null +++ b/src/tree/gpu_hist/evaluator.cu @@ -0,0 +1,100 @@ +/*! + * Copyright 2022 by XGBoost Contributors + * + * \brief Some components of GPU Hist evaluator, this file only exist to reduce nvcc + * compilation time. + */ +#include // thrust::any_of +#include // thrust::stable_sort + +#include "../../common/device_helpers.cuh" +#include "../../common/hist_util.h" // common::HistogramCuts +#include "evaluate_splits.cuh" +#include "xgboost/data.h" + +namespace xgboost { +namespace tree { +template +void GPUHistEvaluator::Reset(common::HistogramCuts const &cuts, + common::Span ft, ObjInfo task, + bst_feature_t n_features, TrainParam const ¶m, + int32_t device) { + param_ = param; + tree_evaluator_ = TreeEvaluator{param, n_features, device}; + if (cuts.HasCategorical() && !task.UseOneHot()) { + dh::XGBCachingDeviceAllocator alloc; + auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan(); + auto beg = thrust::make_counting_iterator(1ul); + auto end = thrust::make_counting_iterator(ptrs.size()); + auto to_onehot = param.max_cat_to_onehot; + // This condition avoids sort-based split function calls if the users want + // onehot-encoding-based splits. + // For some reason, any_of adds 1.5 minutes to compilation time for CUDA 11.x. + has_sort_ = thrust::any_of(thrust::cuda::par(alloc), beg, end, [=] XGBOOST_DEVICE(size_t i) { + auto idx = i - 1; + if (common::IsCat(ft, idx)) { + auto n_bins = ptrs[i] - ptrs[idx]; + bool use_sort = !common::UseOneHot(n_bins, to_onehot, task); + return use_sort; + } + return false; + }); + + if (has_sort_) { + auto bit_storage_size = common::CatBitField::ComputeStorageSize(cuts.MaxCategory() + 1); + CHECK_NE(bit_storage_size, 0); + // We need to allocate for all nodes since the updater can grow the tree layer by + // layer, all nodes in the same layer must be preserved until that layer is + // finished. We can allocate one layer at a time, but the best case is reducing the + // size of the bitset by about a half, at the cost of invoking CUDA malloc many more + // times than necessary. + split_cats_.resize(param.MaxNodes() * bit_storage_size); + h_split_cats_.resize(split_cats_.size()); + dh::safe_cuda( + cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST))); + + cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time. 
+ } + } +} + +template +common::Span GPUHistEvaluator::SortHistogram( + EvaluateSplitInputs const &left, EvaluateSplitInputs const &right, + TreeEvaluator::SplitEvaluator evaluator) { + dh::XGBDeviceAllocator alloc; + auto sorted_idx = this->SortedIdx(left); + dh::Iota(sorted_idx); + // sort 2 nodes and all the features at the same time, disregarding colmun sampling. + thrust::stable_sort( + thrust::cuda::par(alloc), dh::tbegin(sorted_idx), dh::tend(sorted_idx), + [evaluator, left, right] XGBOOST_DEVICE(size_t l, size_t r) { + auto l_is_left = l < left.feature_values.size(); + auto r_is_left = r < left.feature_values.size(); + if (l_is_left != r_is_left) { + return l_is_left; // not the same node + } + + auto const &input = l_is_left ? left : right; + l -= (l_is_left ? 0 : input.feature_values.size()); + r -= (r_is_left ? 0 : input.feature_values.size()); + + auto lfidx = dh::SegmentId(input.feature_segments, l); + auto rfidx = dh::SegmentId(input.feature_segments, r); + if (lfidx != rfidx) { + return lfidx < rfidx; // not the same feature + } + if (common::IsCat(input.feature_types, lfidx)) { + auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[l]); + auto rw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[r]); + return lw < rw; + } + return l < r; + }); + return dh::ToSpan(cat_sorted_idx_); +} + +template class GPUHistEvaluator; +template class GPUHistEvaluator; +} // namespace tree +} // namespace xgboost diff --git a/src/tree/gpu_hist/expand_entry.cuh b/src/tree/gpu_hist/expand_entry.cuh index ac22b652c785..44762e699ece 100644 --- a/src/tree/gpu_hist/expand_entry.cuh +++ b/src/tree/gpu_hist/expand_entry.cuh @@ -4,8 +4,9 @@ #ifndef EXPAND_ENTRY_CUH_ #define EXPAND_ENTRY_CUH_ #include + #include "../param.h" -#include "evaluate_splits.cuh" +#include "../updater_gpu_common.cuh" namespace xgboost { namespace tree { diff --git a/src/tree/hist/evaluate_splits.h b/src/tree/hist/evaluate_splits.h index 9fde7ee3818f..9fd3ce5ca0b9 100644 --- a/src/tree/hist/evaluate_splits.h +++ b/src/tree/hist/evaluate_splits.h @@ -53,7 +53,6 @@ template class HistEvaluator { return true; } } - enum SplitType { kNum = 0, kOneHot = 1, kPart = 2 }; // Enumerate/Scan the split values of specific feature // Returns the sum of gradients corresponding to the data points that contains @@ -137,7 +136,7 @@ template class HistEvaluator { static_cast(evaluator.CalcSplitGain(param_, nidx, fidx, GradStats{left_sum}, GradStats{right_sum}) - parent.root_gain); - split_pt = cut_val[i]; + split_pt = cut_val[i]; // not used for partition based improved = best.Update(loss_chg, fidx, split_pt, d_step == -1, split_type != kNum, left_sum, right_sum); } else { @@ -180,10 +179,10 @@ template class HistEvaluator { if (d_step == 1) { std::for_each(sorted_idx.begin(), sorted_idx.begin() + (best_thresh - ibegin + 1), - [&cat_bits](size_t c) { cat_bits.Set(c); }); + [&](size_t c) { cat_bits.Set(cut_val[c + ibegin]); }); } else { std::for_each(sorted_idx.rbegin(), sorted_idx.rbegin() + (ibegin - best_thresh), - [&cat_bits](size_t c) { cat_bits.Set(c); }); + [&](size_t c) { cat_bits.Set(cut_val[c + cut_ptr[fidx]]); }); } } p_best->Update(best); @@ -231,6 +230,7 @@ template class HistEvaluator { } } auto evaluator = tree_evaluator_.GetEvaluator(); + auto const& cut_ptrs = cut.Ptrs(); common::ParallelFor2d(space, n_threads_, [&](size_t nidx_in_set, common::Range1d r) { auto tidx = omp_get_thread_num(); @@ -246,26 +246,22 @@ template class HistEvaluator { continue; } if (is_cat) { - auto n_bins = 
cut.Ptrs().at(fidx + 1) - cut.Ptrs()[fidx]; + auto n_bins = cut_ptrs.at(fidx + 1) - cut_ptrs[fidx]; if (common::UseOneHot(n_bins, param_.max_cat_to_onehot, task_)) { EnumerateSplit<+1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); EnumerateSplit<-1, kOneHot>(cut, {}, histogram, fidx, nidx, evaluator, best); } else { - auto const &cut_ptr = cut.Ptrs(); std::vector sorted_idx(n_bins); std::iota(sorted_idx.begin(), sorted_idx.end(), 0); - auto feat_hist = histogram.subspan(cut_ptr[fidx], n_bins); + auto feat_hist = histogram.subspan(cut_ptrs[fidx], n_bins); + // Sort the histogram to get contiguous partitions. std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](size_t l, size_t r) { auto ret = evaluator.CalcWeightCat(param_, feat_hist[l]) < evaluator.CalcWeightCat(param_, feat_hist[r]); - static_assert(std::is_same::value, ""); return ret; }); - auto grad_stats = - EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); - if (SplitContainsMissingValues(grad_stats, snode_[nidx])) { - EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); - } + EnumerateSplit<+1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); + EnumerateSplit<-1, kPart>(cut, sorted_idx, histogram, fidx, nidx, evaluator, best); } } else { auto grad_stats = @@ -313,6 +309,7 @@ template class HistEvaluator { cat_bits.Set(cat); } else { split_cats = candidate.split.cat_bits; + common::CatBitField cat_bits{split_cats}; } tree.ExpandCategorical( diff --git a/src/tree/split_evaluator.h b/src/tree/split_evaluator.h index 5030fcb6db92..8cdf88834559 100644 --- a/src/tree/split_evaluator.h +++ b/src/tree/split_evaluator.h @@ -110,6 +110,9 @@ class TreeEvaluator { template XGBOOST_DEVICE double CalcWeightCat(ParamT const& param, GradientSumT const& stats) const { + // FIXME(jiamingy): This is a temporary solution until we have categorical feature + // specific regularization parameters. During sorting we should try to avoid any + // regularization. 
return ::xgboost::tree::CalcWeight(param, stats); } @@ -180,6 +183,15 @@ class TreeEvaluator { .Eval(&lower_bounds_, &upper_bounds_, &monotone_); } }; + +enum SplitType { + // numerical split + kNum = 0, + // onehot encoding based categorical split + kOneHot = 1, + // partition-based categorical split + kPart = 2 +}; } // namespace tree } // namespace xgboost diff --git a/src/tree/updater_gpu_common.cuh b/src/tree/updater_gpu_common.cuh index fc4fe6f89fa4..c7c81e964848 100644 --- a/src/tree/updater_gpu_common.cuh +++ b/src/tree/updater_gpu_common.cuh @@ -8,6 +8,7 @@ #include #include #include +#include "../common/categorical.h" #include "../common/device_helpers.cuh" #include "../common/random.h" #include "param.h" @@ -27,6 +28,7 @@ struct GPUTrainingParam { // default=0 means no constraint on weight delta float max_delta_step; float learning_rate; + uint32_t max_cat_to_onehot; GPUTrainingParam() = default; @@ -35,14 +37,10 @@ struct GPUTrainingParam { reg_lambda(param.reg_lambda), reg_alpha(param.reg_alpha), max_delta_step(param.max_delta_step), - learning_rate{param.learning_rate} {} + learning_rate{param.learning_rate}, + max_cat_to_onehot{param.max_cat_to_onehot} {} }; -using NodeIdT = int32_t; - -/** used to assign default id to a Node */ -static const bst_node_t kUnusedNode = -1; - /** * @enum DefaultDirection node.cuh * @brief Default direction to be followed in case of missing values @@ -59,6 +57,8 @@ struct DeviceSplitCandidate { DefaultDirection dir {kLeftDir}; int findex {-1}; float fvalue {0}; + + common::CatBitField split_cats; bool is_cat { false }; GradientPairPrecise left_sum; @@ -75,6 +75,28 @@ struct DeviceSplitCandidate { *this = other; } } + /** + * \brief The largest encoded category in the split bitset + */ + bst_cat_t MaxCat() const { + // Reuse the fvalue for categorical values. + return static_cast(fvalue); + } + /** + * \brief Return the best threshold for cat split, reset the value after return. 
+ */ + XGBOOST_DEVICE size_t PopBestThresh() { + // fvalue is also being used for storing the threshold for categorical split + auto best_thresh = static_cast(this->fvalue); + this->fvalue = 0; + return best_thresh; + } + + template + XGBOOST_DEVICE void SetCat(T c) { + this->split_cats.Set(common::AsCat(c)); + fvalue = std::max(this->fvalue, static_cast(c)); + } XGBOOST_DEVICE void Update(float loss_chg_in, DefaultDirection dir_in, float fvalue_in, int findex_in, @@ -108,18 +130,6 @@ struct DeviceSplitCandidate { } }; -struct DeviceSplitCandidateReduceOp { - GPUTrainingParam param; - explicit DeviceSplitCandidateReduceOp(GPUTrainingParam param) : param(std::move(param)) {} - XGBOOST_DEVICE DeviceSplitCandidate operator()( - const DeviceSplitCandidate& a, const DeviceSplitCandidate& b) const { - DeviceSplitCandidate best; - best.Update(a, param); - best.Update(b, param); - return best; - } -}; - template struct SumCallbackOp { // Running prefix diff --git a/src/tree/updater_gpu_hist.cu b/src/tree/updater_gpu_hist.cu index 9587c3b839e3..25a953ea2062 100644 --- a/src/tree/updater_gpu_hist.cu +++ b/src/tree/updater_gpu_hist.cu @@ -159,6 +159,10 @@ class DeviceHistogram { // Manage memory for a single GPU template struct GPUHistMakerDevice { + private: + GPUHistEvaluator evaluator_; + + public: int device_id; EllpackPageImpl const* page; common::Span feature_types; @@ -182,7 +186,6 @@ struct GPUHistMakerDevice { dh::PinnedMemory pinned; common::Monitor monitor; - TreeEvaluator tree_evaluator; common::ColumnSampler column_sampler; FeatureInteractionConstraintDevice interaction_constraints; @@ -192,24 +195,20 @@ struct GPUHistMakerDevice { // Storing split categories for last node. dh::caching_device_vector node_categories; - GPUHistMakerDevice(int _device_id, - EllpackPageImpl const* _page, - common::Span _feature_types, - bst_uint _n_rows, - TrainParam _param, - uint32_t column_sampler_seed, - uint32_t n_features, + GPUHistMakerDevice(int _device_id, EllpackPageImpl const* _page, + common::Span _feature_types, bst_uint _n_rows, + TrainParam _param, uint32_t column_sampler_seed, uint32_t n_features, BatchParam _batch_param) - : device_id(_device_id), + : evaluator_{_param, n_features, _device_id}, + device_id(_device_id), page(_page), feature_types{_feature_types}, param(std::move(_param)), - tree_evaluator(param, n_features, _device_id), column_sampler(column_sampler_seed), interaction_constraints(param, n_features), batch_param(std::move(_batch_param)) { - sampler.reset(new GradientBasedSampler( - page, _n_rows, batch_param, param.subsample, param.sampling_method)); + sampler.reset(new GradientBasedSampler(page, _n_rows, batch_param, param.subsample, + param.sampling_method)); if (!param.monotone_constraints.empty()) { // Copy assigning an empty vector causes an exception in MSVC debug builds monotone_constraints = param.monotone_constraints; @@ -219,9 +218,8 @@ struct GPUHistMakerDevice { // Init histogram hist.Init(device_id, page->Cuts().TotalBins()); monitor.Init(std::string("GPUHistMakerDevice") + std::to_string(device_id)); - feature_groups.reset(new FeatureGroups(page->Cuts(), page->is_dense, - dh::MaxSharedMemoryOptin(device_id), - sizeof(GradientSumT))); + feature_groups.reset(new FeatureGroups( + page->Cuts(), page->is_dense, dh::MaxSharedMemoryOptin(device_id), sizeof(GradientSumT))); } ~GPUHistMakerDevice() { // NOLINT @@ -231,13 +229,17 @@ struct GPUHistMakerDevice { // Reset values for each update iteration // Note that the column sampler must be passed by value because it is 
not // thread safe - void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns) { + void Reset(HostDeviceVector* dh_gpair, DMatrix* dmat, int64_t num_columns, + ObjInfo task) { auto const& info = dmat->Info(); this->column_sampler.Init(num_columns, info.feature_weights.HostVector(), param.colsample_bynode, param.colsample_bylevel, param.colsample_bytree); dh::safe_cuda(cudaSetDevice(device_id)); - tree_evaluator = TreeEvaluator(param, dmat->Info().num_col_, device_id); + + this->evaluator_.Reset(page->Cuts(), feature_types, task, dmat->Info().num_col_, param, + device_id); + this->interaction_constraints.Reset(); std::fill(node_sum_gradients.begin(), node_sum_gradients.end(), GradientPairPrecise{}); @@ -258,10 +260,8 @@ struct GPUHistMakerDevice { hist.Reset(); } - - DeviceSplitCandidate EvaluateRootSplit(GradientPairPrecise root_sum) { + GPUExpandEntry EvaluateRootSplit(GradientPairPrecise root_sum, float weight, ObjInfo task) { int nidx = RegTree::kRoot; - dh::TemporaryArray splits_out(1); GPUTrainingParam gpu_param(param); auto sampled_features = column_sampler.GetFeatureSet(0); sampled_features->SetDevice(device_id); @@ -277,32 +277,23 @@ struct GPUHistMakerDevice { matrix.gidx_fvalue_map, matrix.min_fvalue, hist.GetNodeHistogram(nidx)}; - auto gain_calc = tree_evaluator.GetEvaluator(); - EvaluateSingleSplit(dh::ToSpan(splits_out), gain_calc, inputs); - std::vector result(1); - dh::safe_cuda(cudaMemcpy(result.data(), splits_out.data().get(), - sizeof(DeviceSplitCandidate) * splits_out.size(), - cudaMemcpyDeviceToHost)); - return result.front(); - } - - void EvaluateLeftRightSplits( - GPUExpandEntry candidate, int left_nidx, int right_nidx, const RegTree& tree, - common::Span pinned_candidates_out) { + auto split = this->evaluator_.EvaluateSingleSplit(inputs, weight, task); + return split; + } + + void EvaluateLeftRightSplits(GPUExpandEntry candidate, ObjInfo task, int left_nidx, + int right_nidx, const RegTree& tree, + common::Span pinned_candidates_out) { dh::TemporaryArray splits_out(2); GPUTrainingParam gpu_param(param); - auto left_sampled_features = - column_sampler.GetFeatureSet(tree.GetDepth(left_nidx)); + auto left_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(left_nidx)); left_sampled_features->SetDevice(device_id); common::Span left_feature_set = - interaction_constraints.Query(left_sampled_features->DeviceSpan(), - left_nidx); - auto right_sampled_features = - column_sampler.GetFeatureSet(tree.GetDepth(right_nidx)); + interaction_constraints.Query(left_sampled_features->DeviceSpan(), left_nidx); + auto right_sampled_features = column_sampler.GetFeatureSet(tree.GetDepth(right_nidx)); right_sampled_features->SetDevice(device_id); common::Span right_feature_set = - interaction_constraints.Query(right_sampled_features->DeviceSpan(), - left_nidx); + interaction_constraints.Query(right_sampled_features->DeviceSpan(), left_nidx); auto matrix = page->GetDeviceAccessor(device_id); EvaluateSplitInputs left{left_nidx, @@ -323,29 +314,11 @@ struct GPUHistMakerDevice { matrix.gidx_fvalue_map, matrix.min_fvalue, hist.GetNodeHistogram(right_nidx)}; - auto d_splits_out = dh::ToSpan(splits_out); - EvaluateSplits(d_splits_out, tree_evaluator.GetEvaluator(), left, right); + dh::TemporaryArray entries(2); - auto evaluator = tree_evaluator.GetEvaluator(); - auto d_entries = entries.data().get(); - dh::LaunchN(2, [=] __device__(size_t idx) { - auto split = d_splits_out[idx]; - auto nidx = idx == 0 ? 
left_nidx : right_nidx; - - float base_weight = evaluator.CalcWeight( - nidx, gpu_param, GradStats{split.left_sum + split.right_sum}); - float left_weight = - evaluator.CalcWeight(nidx, gpu_param, GradStats{split.left_sum}); - float right_weight = evaluator.CalcWeight( - nidx, gpu_param, GradStats{split.right_sum}); - - d_entries[idx] = - GPUExpandEntry{nidx, candidate.depth + 1, d_splits_out[idx], - base_weight, left_weight, right_weight}; - }); - dh::safe_cuda(cudaMemcpyAsync( - pinned_candidates_out.data(), entries.data().get(), - sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); + this->evaluator_.EvaluateSplits(candidate, task, left, right, dh::ToSpan(entries)); + dh::safe_cuda(cudaMemcpyAsync(pinned_candidates_out.data(), entries.data().get(), + sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); } void BuildHist(int nidx) { @@ -369,12 +342,10 @@ struct GPUHistMakerDevice { }); } - bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, - int nidx_subtraction) { + bool CanDoSubtractionTrick(int nidx_parent, int nidx_histogram, int nidx_subtraction) { // Make sure histograms are already allocated hist.AllocateHistogram(nidx_subtraction); - return hist.HistogramExists(nidx_histogram) && - hist.HistogramExists(nidx_parent); + return hist.HistogramExists(nidx_histogram) && hist.HistogramExists(nidx_parent); } void UpdatePosition(int nidx, RegTree* p_tree) { @@ -503,13 +474,12 @@ struct GPUHistMakerDevice { cudaMemcpyHostToDevice)); auto d_position = row_partitioner->GetPosition(); auto d_node_sum_gradients = device_node_sum_gradients.data().get(); - auto evaluator = tree_evaluator.GetEvaluator(); + auto tree_evaluator = evaluator_.GetEvaluator(); - dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__( - int local_idx) mutable { + dh::LaunchN(d_ridx.size(), [=, out_preds_d = out_preds_d] __device__(int local_idx) mutable { int pos = d_position[local_idx]; - bst_float weight = evaluator.CalcWeight( - pos, param_d, GradStats{d_node_sum_gradients[pos]}); + bst_float weight = + tree_evaluator.CalcWeight(pos, param_d, GradStats{d_node_sum_gradients[pos]}); static_assert(!std::is_const::value, ""); out_preds_d(d_ridx[local_idx]) += weight * param_d.learning_rate; }); @@ -562,7 +532,6 @@ struct GPUHistMakerDevice { void ApplySplit(const GPUExpandEntry& candidate, RegTree* p_tree) { RegTree& tree = *p_tree; - auto evaluator = tree_evaluator.GetEvaluator(); auto parent_sum = candidate.split.left_sum + candidate.split.right_sum; auto base_weight = candidate.base_weight; auto left_weight = candidate.left_weight * param.learning_rate; @@ -572,48 +541,50 @@ struct GPUHistMakerDevice { if (is_cat) { CHECK_LT(candidate.split.fvalue, std::numeric_limits::max()) << "Categorical feature value too large."; - if (common::InvalidCat(candidate.split.fvalue)) { - common::InvalidCategory(); + std::vector split_cats; + if (candidate.split.split_cats.Bits().empty()) { + if (common::InvalidCat(candidate.split.fvalue)) { + common::InvalidCategory(); + } + auto cat = common::AsCat(candidate.split.fvalue); + split_cats.resize(LBitField32::ComputeStorageSize(cat + 1), 0); + common::CatBitField cats_bits(split_cats); + cats_bits.Set(cat); + dh::CopyToD(split_cats, &node_categories); + } else { + auto h_cats = this->evaluator_.GetHostNodeCats(candidate.nid); + auto max_cat = candidate.split.MaxCat(); + split_cats.resize(common::CatBitField::ComputeStorageSize(max_cat + 1), 0); + CHECK_LE(split_cats.size(), h_cats.size()); + std::copy(h_cats.data(), h_cats.data() + 
split_cats.size(), split_cats.data()); + + node_categories.resize(candidate.split.split_cats.Bits().size()); + dh::safe_cuda(cudaMemcpyAsync( + node_categories.data().get(), candidate.split.split_cats.Data(), + candidate.split.split_cats.Bits().size_bytes(), cudaMemcpyDeviceToDevice)); } - auto cat = common::AsCat(candidate.split.fvalue); - std::vector split_cats(LBitField32::ComputeStorageSize(std::max(cat + 1, 1)), 0); - LBitField32 cats_bits(split_cats); - cats_bits.Set(cat); - dh::CopyToD(split_cats, &node_categories); + tree.ExpandCategorical( - candidate.nid, candidate.split.findex, split_cats, - candidate.split.dir == kLeftDir, base_weight, left_weight, - right_weight, candidate.split.loss_chg, parent_sum.GetHess(), - candidate.split.left_sum.GetHess(), - candidate.split.right_sum.GetHess()); + candidate.nid, candidate.split.findex, split_cats, candidate.split.dir == kLeftDir, + base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(), + candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); } else { - tree.ExpandNode(candidate.nid, candidate.split.findex, - candidate.split.fvalue, candidate.split.dir == kLeftDir, - base_weight, left_weight, right_weight, + tree.ExpandNode(candidate.nid, candidate.split.findex, candidate.split.fvalue, + candidate.split.dir == kLeftDir, base_weight, left_weight, right_weight, candidate.split.loss_chg, parent_sum.GetHess(), - candidate.split.left_sum.GetHess(), - candidate.split.right_sum.GetHess()); + candidate.split.left_sum.GetHess(), candidate.split.right_sum.GetHess()); } + evaluator_.ApplyTreeSplit(candidate, p_tree); - // Set up child constraints - auto left_child = tree[candidate.nid].LeftChild(); - auto right_child = tree[candidate.nid].RightChild(); - - tree_evaluator.AddSplit(candidate.nid, left_child, right_child, - tree[candidate.nid].SplitIndex(), candidate.left_weight, - candidate.right_weight); - node_sum_gradients[tree[candidate.nid].LeftChild()] = - candidate.split.left_sum; - node_sum_gradients[tree[candidate.nid].RightChild()] = - candidate.split.right_sum; - - interaction_constraints.Split( - candidate.nid, tree[candidate.nid].SplitIndex(), - tree[candidate.nid].LeftChild(), + node_sum_gradients[tree[candidate.nid].LeftChild()] = candidate.split.left_sum; + node_sum_gradients[tree[candidate.nid].RightChild()] = candidate.split.right_sum; + + interaction_constraints.Split(candidate.nid, tree[candidate.nid].SplitIndex(), + tree[candidate.nid].LeftChild(), tree[candidate.nid].RightChild()); } - GPUExpandEntry InitRoot(RegTree* p_tree, dh::AllReducer* reducer) { + GPUExpandEntry InitRoot(RegTree* p_tree, ObjInfo task, dh::AllReducer* reducer) { constexpr bst_node_t kRootNIdx = 0; dh::XGBCachingDeviceAllocator alloc; auto gpair_it = dh::MakeTransformIterator( @@ -634,39 +605,21 @@ struct GPUHistMakerDevice { (*p_tree)[kRootNIdx].SetLeaf(param.learning_rate * weight); // Generate first split - auto split = this->EvaluateRootSplit(root_sum); - dh::TemporaryArray entries(1); - auto d_entries = entries.data().get(); - auto evaluator = tree_evaluator.GetEvaluator(); - GPUTrainingParam gpu_param(param); - auto depth = p_tree->GetDepth(kRootNIdx); - dh::LaunchN(1, [=] __device__(size_t idx) { - float left_weight = evaluator.CalcWeight(kRootNIdx, gpu_param, - GradStats{split.left_sum}); - float right_weight = evaluator.CalcWeight( - kRootNIdx, gpu_param, GradStats{split.right_sum}); - d_entries[0] = - GPUExpandEntry(kRootNIdx, depth, split, - weight, left_weight, right_weight); - }); - 
GPUExpandEntry root_entry; - dh::safe_cuda(cudaMemcpyAsync( - &root_entry, entries.data().get(), - sizeof(GPUExpandEntry) * entries.size(), cudaMemcpyDeviceToHost)); + auto root_entry = this->EvaluateRootSplit(root_sum, weight, task); return root_entry; } - void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, + void UpdateTree(HostDeviceVector* gpair_all, DMatrix* p_fmat, ObjInfo task, RegTree* p_tree, dh::AllReducer* reducer) { auto& tree = *p_tree; Driver driver(static_cast(param.grow_policy)); monitor.Start("Reset"); - this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_); + this->Reset(gpair_all, p_fmat, p_fmat->Info().num_col_, task); monitor.Stop("Reset"); monitor.Start("InitRoot"); - driver.Push({ this->InitRoot(p_tree, reducer) }); + driver.Push({ this->InitRoot(p_tree, task, reducer) }); monitor.Stop("InitRoot"); auto num_leaves = 1; @@ -700,8 +653,7 @@ struct GPUHistMakerDevice { monitor.Stop("BuildHist"); monitor.Start("EvaluateSplits"); - this->EvaluateLeftRightSplits(candidate, left_child_nidx, - right_child_nidx, *p_tree, + this->EvaluateLeftRightSplits(candidate, task, left_child_nidx, right_child_nidx, *p_tree, new_candidates.subspan(i * 2, 2)); monitor.Stop("EvaluateSplits"); } else { @@ -816,14 +768,13 @@ class GPUHistMakerSpecialised { CHECK(*local_tree == reference_tree); } - void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, - RegTree* p_tree) { + void UpdateTree(HostDeviceVector* gpair, DMatrix* p_fmat, RegTree* p_tree) { monitor_.Start("InitData"); this->InitData(p_fmat); monitor_.Stop("InitData"); gpair->SetDevice(device_); - maker->UpdateTree(gpair, p_fmat, p_tree, &reducer_); + maker->UpdateTree(gpair, p_fmat, task_, p_tree, &reducer_); } bool UpdatePredictionCache(const DMatrix *data, diff --git a/tests/cpp/common/test_hist_util.h b/tests/cpp/common/test_hist_util.h index 49952d202706..6a370c59700b 100644 --- a/tests/cpp/common/test_hist_util.h +++ b/tests/cpp/common/test_hist_util.h @@ -1,5 +1,5 @@ /*! - * Copyright 2019-2021 by XGBoost Contributors + * Copyright 2019-2022 by XGBoost Contributors */ #pragma once #include @@ -235,6 +235,7 @@ void TestCategoricalSketch(size_t n, size_t num_categories, int32_t num_bins, ASSERT_EQ(dmat->Info().feature_types.Size(), 1); auto cuts = sketch(dmat.get(), num_bins); + ASSERT_EQ(cuts.MaxCategory(), num_categories - 1); std::sort(x.begin(), x.end()); auto n_uniques = std::unique(x.begin(), x.end()) - x.begin(); ASSERT_NE(n_uniques, x.size()); diff --git a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu index 0916d1181519..0cbfc9f2a6cf 100644 --- a/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu +++ b/tests/cpp/tree/gpu_hist/test_evaluate_splits.cu @@ -1,7 +1,11 @@ +/*! 
+ * Copyright 2020-2022 by XGBoost contributors + */ #include #include "../../../../src/tree/gpu_hist/evaluate_splits.cuh" #include "../../helpers.h" #include "../../histogram_helpers.h" +#include "../test_evaluate_splits.h" // TestPartitionBasedSplit namespace xgboost { namespace tree { @@ -16,7 +20,6 @@ auto ZeroParam() { } // anonymous namespace void TestEvaluateSingleSplit(bool is_categorical) { - thrust::device_vector out_splits(1); GradientPairPrecise parent_sum(0.0, 1.0); TrainParam tparam = ZeroParam(); GPUTrainingParam param{tparam}; @@ -50,11 +53,13 @@ void TestEvaluateSingleSplit(bool is_categorical) { dh::ToSpan(feature_values), dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram)}; - TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); - auto evaluator = tree_evaluator.GetEvaluator(); - EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); - DeviceSplitCandidate result = out_splits[0]; + GPUHistEvaluator evaluator{ + tparam, static_cast(feature_min_values.size()), 0}; + dh::device_vector out_cats; + DeviceSplitCandidate result = + evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.fvalue, 11.0); EXPECT_FLOAT_EQ(result.left_sum.GetGrad() + result.right_sum.GetGrad(), @@ -72,7 +77,6 @@ TEST(GpuHist, EvaluateCategoricalSplit) { } TEST(GpuHist, EvaluateSingleSplitMissing) { - thrust::device_vector out_splits(1); GradientPairPrecise parent_sum(1.0, 1.5); TrainParam tparam = ZeroParam(); GPUTrainingParam param{tparam}; @@ -96,11 +100,10 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram)}; - TreeEvaluator tree_evaluator(tparam, feature_set.size(), 0); - auto evaluator = tree_evaluator.GetEvaluator(); - EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); + GPUHistEvaluator evaluator(tparam, feature_set.size(), 0); + DeviceSplitCandidate result = + evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; - DeviceSplitCandidate result = out_splits[0]; EXPECT_EQ(result.findex, 0); EXPECT_EQ(result.fvalue, 1.0); EXPECT_EQ(result.dir, kRightDir); @@ -109,27 +112,18 @@ TEST(GpuHist, EvaluateSingleSplitMissing) { } TEST(GpuHist, EvaluateSingleSplitEmpty) { - DeviceSplitCandidate nonzeroed; - nonzeroed.findex = 1; - nonzeroed.loss_chg = 1.0; - - thrust::device_vector out_split(1); - out_split[0] = nonzeroed; - TrainParam tparam = ZeroParam(); - TreeEvaluator tree_evaluator(tparam, 1, 0); - auto evaluator = tree_evaluator.GetEvaluator(); - EvaluateSingleSplit(dh::ToSpan(out_split), evaluator, - EvaluateSplitInputs{}); - - DeviceSplitCandidate result = out_split[0]; + GPUHistEvaluator evaluator(tparam, 1, 0); + DeviceSplitCandidate result = evaluator + .EvaluateSingleSplit(EvaluateSplitInputs{}, 0, + ObjInfo{ObjInfo::kRegression}) + .split; EXPECT_EQ(result.findex, -1); EXPECT_LT(result.loss_chg, 0.0f); } // Feature 0 has a better split, but the algorithm must select feature 1 TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { - thrust::device_vector out_splits(1); GradientPairPrecise parent_sum(0.0, 1.0); TrainParam tparam = ZeroParam(); tparam.UpdateAllowUnknown(Args{}); @@ -157,11 +151,10 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram)}; - TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); - auto evaluator = tree_evaluator.GetEvaluator(); - EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); + 
GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0); + DeviceSplitCandidate result = + evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; - DeviceSplitCandidate result = out_splits[0]; EXPECT_EQ(result.findex, 1); EXPECT_EQ(result.fvalue, 11.0); EXPECT_EQ(result.left_sum, GradientPairPrecise(-0.5, 0.5)); @@ -170,7 +163,6 @@ TEST(GpuHist, EvaluateSingleSplitFeatureSampling) { // Features 0 and 1 have identical gain, the algorithm must select 0 TEST(GpuHist, EvaluateSingleSplitBreakTies) { - thrust::device_vector out_splits(1); GradientPairPrecise parent_sum(0.0, 1.0); TrainParam tparam = ZeroParam(); tparam.UpdateAllowUnknown(Args{}); @@ -198,11 +190,10 @@ TEST(GpuHist, EvaluateSingleSplitBreakTies) { dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram)}; - TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); - auto evaluator = tree_evaluator.GetEvaluator(); - EvaluateSingleSplit(dh::ToSpan(out_splits), evaluator, input); + GPUHistEvaluator evaluator(tparam, feature_min_values.size(), 0); + DeviceSplitCandidate result = + evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; - DeviceSplitCandidate result = out_splits[0]; EXPECT_EQ(result.findex, 0); EXPECT_EQ(result.fvalue, 1.0); } @@ -250,9 +241,10 @@ TEST(GpuHist, EvaluateSplits) { dh::ToSpan(feature_min_values), dh::ToSpan(feature_histogram_right)}; - TreeEvaluator tree_evaluator(tparam, feature_min_values.size(), 0); - auto evaluator = tree_evaluator.GetEvaluator(); - EvaluateSplits(dh::ToSpan(out_splits), evaluator, input_left, input_right); + GPUHistEvaluator evaluator{ + tparam, static_cast(feature_min_values.size()), 0}; + evaluator.EvaluateSplits(input_left, input_right, ObjInfo{ObjInfo::kRegression}, + evaluator.GetEvaluator(), dh::ToSpan(out_splits)); DeviceSplitCandidate result_left = out_splits[0]; EXPECT_EQ(result_left.findex, 1); @@ -262,5 +254,36 @@ TEST(GpuHist, EvaluateSplits) { EXPECT_EQ(result_right.findex, 0); EXPECT_EQ(result_right.fvalue, 1.0); } + +TEST_F(TestPartitionBasedSplit, GpuHist) { + dh::device_vector ft{std::vector{FeatureType::kCategorical}}; + GPUHistEvaluator evaluator{param_, + static_cast(info_.num_col_), 0}; + + cuts_.cut_ptrs_.SetDevice(0); + cuts_.cut_values_.SetDevice(0); + cuts_.min_vals_.SetDevice(0); + + ObjInfo task{ObjInfo::kRegression}; + evaluator.Reset(cuts_, dh::ToSpan(ft), task, info_.num_col_, param_, 0); + + dh::device_vector d_hist(hist_[0].size()); + auto node_hist = hist_[0]; + dh::safe_cuda(cudaMemcpy(d_hist.data().get(), node_hist.data(), node_hist.size_bytes(), + cudaMemcpyHostToDevice)); + dh::device_vector feature_set{std::vector{0}}; + + EvaluateSplitInputs input{0, + total_gpair_, + GPUTrainingParam{param_}, + dh::ToSpan(feature_set), + dh::ToSpan(ft), + cuts_.cut_ptrs_.ConstDeviceSpan(), + cuts_.cut_values_.ConstDeviceSpan(), + cuts_.min_vals_.ConstDeviceSpan(), + dh::ToSpan(d_hist)}; + auto split = evaluator.EvaluateSingleSplit(input, 0, ObjInfo{ObjInfo::kRegression}).split; + ASSERT_NEAR(split.loss_chg, best_score_, 1e-16); +} } // namespace tree } // namespace xgboost diff --git a/tests/cpp/tree/hist/test_evaluate_splits.cc b/tests/cpp/tree/hist/test_evaluate_splits.cc index 7819ec307375..f3760534b4ab 100644 --- a/tests/cpp/tree/hist/test_evaluate_splits.cc +++ b/tests/cpp/tree/hist/test_evaluate_splits.cc @@ -3,9 +3,11 @@ */ #include #include + +#include "../../../../src/common/hist_util.h" #include "../../../../src/tree/hist/evaluate_splits.h" #include 
"../../../../src/tree/updater_quantile_hist.h" -#include "../../../../src/common/hist_util.h" +#include "../test_evaluate_splits.h" #include "../../helpers.h" namespace xgboost { @@ -108,80 +110,17 @@ TEST(HistEvaluator, Apply) { ASSERT_EQ(tree.Stat(tree[0].RightChild()).sum_hess, 0.7f); } -TEST(HistEvaluator, CategoricalPartition) { - int static constexpr kRows = 128, kCols = 1; - using GradientSumT = double; - std::vector ft(kCols, FeatureType::kCategorical); - - TrainParam param; - param.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}}); - - size_t n_cats{8}; - - auto dmat = - RandomDataGenerator(kRows, kCols, 0).Seed(3).Type(ft).MaxCategory(n_cats).GenerateDMatrix(); - - int32_t n_threads = 16; +TEST_F(TestPartitionBasedSplit, CPUHist) { + // check the evaluator is returning the optimal split + std::vector ft{FeatureType::kCategorical}; auto sampler = std::make_shared(); - auto evaluator = HistEvaluator{ - param, dmat->Info(), n_threads, sampler, ObjInfo{ObjInfo::kRegression}}; - - for (auto const &gmat : dmat->GetBatches({32, param.sparse_threshold})) { - common::HistCollection hist; - - std::vector entries(1); - entries.front().nid = 0; - entries.front().depth = 0; - - hist.Init(gmat.cut.TotalBins()); - hist.AddHistRow(0); - hist.AllocateAllData(); - auto node_hist = hist[0]; - ASSERT_EQ(node_hist.size(), n_cats); - ASSERT_EQ(node_hist.size(), gmat.cut.Ptrs().back()); - - GradientPairPrecise total_gpair; - for (size_t i = 0; i < node_hist.size(); ++i) { - node_hist[i] = {static_cast(node_hist.size() - i), 1.0}; - total_gpair += node_hist[i]; - } - SimpleLCG lcg; - std::shuffle(node_hist.begin(), node_hist.end(), lcg); - - RegTree tree; - evaluator.InitRoot(GradStats{total_gpair}); - evaluator.EvaluateSplits(hist, gmat.cut, ft, tree, &entries); - ASSERT_TRUE(entries.front().split.is_cat); - - auto run_eval = [&](auto fn) { - for (size_t i = 1; i < gmat.cut.Ptrs().size(); ++i) { - GradStats left, right; - for (size_t j = gmat.cut.Ptrs()[i - 1]; j < gmat.cut.Ptrs()[i]; ++j) { - auto loss_chg = evaluator.Evaluator().CalcSplitGain(param, 0, i - 1, left, right) - - evaluator.Stats().front().root_gain; - fn(loss_chg); - left.Add(node_hist[j].GetGrad(), node_hist[j].GetHess()); - right.SetSubstract(GradStats{total_gpair}, left); - } - } - }; - // Assert that's the best split - auto best_loss_chg = entries.front().split.loss_chg; - run_eval([&](auto loss_chg) { - // Approximated test that gain returned by optimal partition is greater than - // numerical split. - ASSERT_GT(best_loss_chg, loss_chg); - }); - // node_hist is captured in lambda. - std::sort(node_hist.begin(), node_hist.end(), [&](auto l, auto r) { - return evaluator.Evaluator().CalcWeightCat(param, l) < - evaluator.Evaluator().CalcWeightCat(param, r); - }); - - double reimpl = 0; - run_eval([&](auto loss_chg) { reimpl = std::max(loss_chg, reimpl); }); - CHECK_EQ(reimpl, best_loss_chg); - } + HistEvaluator evaluator{param_, info_, common::OmpGetNumThreads(0), + sampler, ObjInfo{ObjInfo::kRegression}}; + evaluator.InitRoot(GradStats{total_gpair_}); + RegTree tree; + std::vector entries(1); + evaluator.EvaluateSplits(hist_, cuts_, {ft}, tree, &entries); + ASSERT_NEAR(entries[0].split.loss_chg, best_score_, 1e-16); } namespace { diff --git a/tests/cpp/tree/test_evaluate_splits.h b/tests/cpp/tree/test_evaluate_splits.h new file mode 100644 index 000000000000..4b1a320319fb --- /dev/null +++ b/tests/cpp/tree/test_evaluate_splits.h @@ -0,0 +1,96 @@ +/*! 
+ * Copyright 2022 by XGBoost Contributors
+ */
+#include
+
+#include  // next_permutation
+#include  // iota
+
+#include "../../../src/tree/hist/evaluate_splits.h"
+#include "../helpers.h"
+
+namespace xgboost {
+namespace tree {
+/**
+ * \brief Enumerate all possible partitions for categorical split.
+ */
+class TestPartitionBasedSplit : public ::testing::Test {
+ protected:
+  size_t n_bins_ = 6;
+  std::vector sorted_idx_;
+  TrainParam param_;
+  MetaInfo info_;
+  float best_score_{-std::numeric_limits::infinity()};
+  common::HistogramCuts cuts_;
+  common::HistCollection hist_;
+  GradientPairPrecise total_gpair_;
+
+  void SetUp() override {
+    param_.UpdateAllowUnknown(Args{{"min_child_weight", "0"}, {"reg_lambda", "0"}});
+    sorted_idx_.resize(n_bins_);
+    std::iota(sorted_idx_.begin(), sorted_idx_.end(), 0);
+
+    info_.num_col_ = 1;
+
+    cuts_.cut_ptrs_.Resize(2);
+    cuts_.SetCategorical(true, n_bins_);
+    auto &h_cuts = cuts_.cut_ptrs_.HostVector();
+    h_cuts[0] = 0;
+    h_cuts[1] = n_bins_;
+    auto &h_vals = cuts_.cut_values_.HostVector();
+    h_vals.resize(n_bins_);
+    std::iota(h_vals.begin(), h_vals.end(), 0.0);
+
+    hist_.Init(cuts_.TotalBins());
+    hist_.AddHistRow(0);
+    hist_.AllocateAllData();
+    auto node_hist = hist_[0];
+
+    SimpleLCG lcg;
+    SimpleRealUniformDistribution grad_dist{-4.0, 4.0};
+    SimpleRealUniformDistribution hess_dist{0.0, 4.0};
+
+    for (auto &e : node_hist) {
+      e = GradientPairPrecise{grad_dist(&lcg), hess_dist(&lcg)};
+      total_gpair_ += e;
+    }
+
+    auto enumerate = [this, n_feat = info_.num_col_](common::GHistRow hist,
+                                                     GradientPairPrecise parent_sum) {
+      int32_t best_thresh = -1;
+      float best_score{-std::numeric_limits::infinity()};
+      TreeEvaluator evaluator{param_, static_cast(n_feat), -1};
+      auto tree_evaluator = evaluator.GetEvaluator();
+      GradientPairPrecise left_sum;
+      auto parent_gain = tree_evaluator.CalcGain(0, param_, GradStats{total_gpair_});
+      for (size_t i = 0; i < hist.size() - 1; ++i) {
+        left_sum += hist[i];
+        auto right_sum = parent_sum - left_sum;
+        auto gain =
+            tree_evaluator.CalcSplitGain(param_, 0, 0, GradStats{left_sum}, GradStats{right_sum}) -
+            parent_gain;
+        if (gain > best_score) {
+          best_score = gain;
+          best_thresh = i;
+        }
+      }
+      return std::make_tuple(best_thresh, best_score);
+    };
+
+    // enumerate all possible partitions to find the optimal split
+    do {
+      int32_t thresh;
+      float score;
+      std::vector sorted_hist(node_hist.size());
+      for (size_t i = 0; i < sorted_hist.size(); ++i) {
+        sorted_hist[i] = node_hist[sorted_idx_[i]];
+      }
+      std::tie(thresh, score) = enumerate({sorted_hist}, total_gpair_);
+      if (score > best_score_) {
+        best_score_ = score;
+      }
+    } while (std::next_permutation(sorted_idx_.begin(), sorted_idx_.end()));
+  }
+};
+}  // namespace tree
+}  // namespace xgboost
diff --git a/tests/cpp/tree/test_gpu_hist.cu b/tests/cpp/tree/test_gpu_hist.cu
index 5639c2f003bc..39c68eec469b 100644
--- a/tests/cpp/tree/test_gpu_hist.cu
+++ b/tests/cpp/tree/test_gpu_hist.cu
@@ -262,7 +262,8 @@ TEST(GpuHist, EvaluateRootSplit) {
   info.num_row_ = kNRows;
   info.num_col_ = kNCols;
 
-  DeviceSplitCandidate res = maker.EvaluateRootSplit({6.4f, 12.8f});
+  DeviceSplitCandidate res =
+      maker.EvaluateRootSplit({6.4f, 12.8f}, 0, ObjInfo{ObjInfo::kRegression}).split;
 
   ASSERT_EQ(res.findex, 7);
   ASSERT_NEAR(res.fvalue, 0.26, xgboost::kRtEps);
@@ -300,11 +301,11 @@ void TestHistogramIndexImpl() {
   const auto &maker = hist_maker.maker;
   auto grad = GenerateRandomGradients(kNRows);
   grad.SetDevice(0);
-  maker->Reset(&grad, hist_maker_dmat.get(), kNCols);
+  maker->Reset(&grad, hist_maker_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
   std::vector h_gidx_buffer(maker->page->gidx_buffer.HostVector());
 
   const auto &maker_ext = hist_maker_ext.maker;
-  maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols);
+  maker_ext->Reset(&grad, hist_maker_ext_dmat.get(), kNCols, ObjInfo{ObjInfo::kRegression});
   std::vector h_gidx_buffer_ext(maker_ext->page->gidx_buffer.HostVector());
 
   ASSERT_EQ(maker->page->Cuts().TotalBins(), maker_ext->page->Cuts().TotalBins());
diff --git a/tests/python/test_updaters.py b/tests/python/test_updaters.py
index df8adcc424da..3e80d273899f 100644
--- a/tests/python/test_updaters.py
+++ b/tests/python/test_updaters.py
@@ -211,6 +211,34 @@ def run_categorical_basic(self, rows, cols, rounds, cats, tree_method):
         )
         assert tm.non_increasing(by_builtin_results["Train"]["rmse"])
 
+        by_grouping: xgb.callback.TrainingCallback.EvalsLog = {}
+        parameters["max_cat_to_onehot"] = 1
+        parameters["reg_lambda"] = 0
+        m = xgb.DMatrix(cat, label, enable_categorical=True)
+        xgb.train(
+            parameters,
+            m,
+            num_boost_round=rounds,
+            evals=[(m, "Train")],
+            evals_result=by_grouping,
+        )
+        rmse_oh = by_builtin_results["Train"]["rmse"]
+        rmse_group = by_grouping["Train"]["rmse"]
+        # always better or equal to onehot when there's no regularization.
+        for a, b in zip(rmse_oh, rmse_group):
+            assert a >= b
+
+        parameters["reg_lambda"] = 1.0
+        by_grouping = {}
+        xgb.train(
+            parameters,
+            m,
+            num_boost_round=32,
+            evals=[(m, "Train")],
+            evals_result=by_grouping,
+        )
+        assert tm.non_increasing(by_grouping["Train"]["rmse"]), by_grouping
+
     @given(strategies.integers(10, 400), strategies.integers(3, 8),
            strategies.integers(1, 2), strategies.integers(4, 7))
     @settings(deadline=None)
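// --------------------------------------------------------------------------
// Editor's note: the TestPartitionBasedSplit fixture above validates the
// sorted-scan search against an exhaustive enumeration of category
// permutations, and the Python test checks that with reg_lambda=0 the
// partitioned split is never worse than one-hot per boosting round.  Below is
// a stand-alone sketch of the same oracle idea, assumed and simplified for
// illustration (it is not the fixture's code): enumerate every bipartition of
// the categories with a bitmask and keep the best gain, which the sorted scan
// is expected to match.  Stat, Gain and the gain formula are assumptions.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace sketch {
struct Stat {
  double grad{0.0};
  double hess{0.0};
};

inline double Gain(const Stat& s, double lambda) {
  return (s.grad * s.grad) / (s.hess + lambda);
}

// Brute-force oracle: try every non-trivial subset of categories as the left
// child.  Exponential in the number of categories, so only practical for tiny
// histograms such as the 6-bin fixture above.
inline double BestPartitionGainBruteForce(const std::vector<Stat>& hist, double lambda) {
  Stat parent;
  for (const auto& s : hist) {
    parent.grad += s.grad;
    parent.hess += s.hess;
  }
  double best = 0.0;
  const std::uint64_t n_subsets = std::uint64_t{1} << hist.size();
  // Skip the empty set (mask == 0) and the full set (mask == n_subsets - 1).
  for (std::uint64_t mask = 1; mask + 1 < n_subsets; ++mask) {
    Stat left;
    for (size_t i = 0; i < hist.size(); ++i) {
      if (mask & (std::uint64_t{1} << i)) {
        left.grad += hist[i].grad;
        left.hess += hist[i].hess;
      }
    }
    Stat right{parent.grad - left.grad, parent.hess - left.hess};
    double gain = Gain(left, lambda) + Gain(right, lambda) - Gain(parent, lambda);
    best = std::max(best, gain);
  }
  return best;
}
}  // namespace sketch
// --------------------------------------------------------------------------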