diff --git a/demo/guide-python/cat_in_the_dat.py b/demo/guide-python/cat_in_the_dat.py
index 27741f37abaf..29f55aba7de1 100644
--- a/demo/guide-python/cat_in_the_dat.py
+++ b/demo/guide-python/cat_in_the_dat.py
@@ -74,12 +74,12 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
     X_train, X_test, y_train, y_test = train_test_split(
         X, y, random_state=1994, test_size=0.2
     )
-    # Specify `enable_categorical`.
+    # Specify `enable_categorical` to True.
     clf = xgb.XGBClassifier(
         **params,
         eval_metric="auc",
         enable_categorical=True,
-        max_cat_to_onehot=1, # We use optimal partitioning exclusively
+        max_cat_to_onehot=1,  # We use optimal partitioning exclusively
     )
     clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)])
     clf.save_model(os.path.join(output_dir, "categorical.json"))
@@ -94,13 +94,12 @@ def onehot_encoding_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> Non
     X_train, X_test, y_train, y_test = train_test_split(
         X, y, random_state=42, test_size=0.2
     )
-    # Specify `enable_categorical`.
-    clf = xgb.XGBClassifier(**params, enable_categorical=False)
+    # Specify `enable_categorical` to False as we are using encoded data.
+    clf = xgb.XGBClassifier(**params, eval_metric="auc", enable_categorical=False)
     clf.fit(
         X_train,
         y_train,
         eval_set=[(X_test, y_test), (X_train, y_train)],
-        eval_metric="auc",
     )
     clf.save_model(os.path.join(output_dir, "one-hot.json"))
 
diff --git a/src/tree/gpu_hist/evaluate_splits.cuh b/src/tree/gpu_hist/evaluate_splits.cuh
index c22834e83860..b03fd7b41b51 100644
--- a/src/tree/gpu_hist/evaluate_splits.cuh
+++ b/src/tree/gpu_hist/evaluate_splits.cuh
@@ -51,6 +51,12 @@ class GPUHistEvaluator {
   dh::CUDAStream copy_stream_;
   // storage for sorted index of feature histogram, used for sort based splits.
   dh::device_vector<bst_feature_t> cat_sorted_idx_;
+  // cached input for sorting the histogram, used for sort based splits.
+  using SortPair = thrust::tuple<uint32_t, double>;
+  dh::device_vector<SortPair> sort_input_;
+  // cache for feature index
+  dh::device_vector<bst_feature_t> feature_idx_;
+  // Training param used for evaluation
   TrainParam param_;
   // whether the input data requires sort based split, which is more complicated so we try
   // to avoid it if possible.
@@ -95,6 +101,13 @@ class GPUHistEvaluator {
     return dh::ToSpan(cat_sorted_idx_);
   }
 
+  auto SortInput(EvaluateSplitInputs<GradientSumT> left) {
+    if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
+      return dh::ToSpan(sort_input_).first(left.feature_values.size());
+    }
+    return dh::ToSpan(sort_input_);
+  }
+
  public:
   GPUHistEvaluator(TrainParam const &param, bst_feature_t n_features, int32_t device)
       : tree_evaluator_{param, n_features, device}, param_{param} {}
diff --git a/src/tree/gpu_hist/evaluator.cu b/src/tree/gpu_hist/evaluator.cu
index 7566b2847505..bc2027489131 100644
--- a/src/tree/gpu_hist/evaluator.cu
+++ b/src/tree/gpu_hist/evaluator.cu
@@ -54,6 +54,21 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
           cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));
 
       cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2);  // evaluate 2 nodes at a time.
+      sort_input_.resize(cat_sorted_idx_.size());
+
+      /**
+       * cache feature index binary search result
+       */
+      feature_idx_.resize(cat_sorted_idx_.size());
+      auto d_fidxes = dh::ToSpan(feature_idx_);
+      auto it = thrust::make_counting_iterator(0ul);
+      auto values = cuts.cut_values_.ConstDeviceSpan();
+      auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
+      thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(),
+                        feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) {
+                          auto fidx = dh::SegmentId(ptrs, i);
+                          return fidx;
+                        });
     }
   }
 }
@@ -62,35 +77,55 @@ template <typename GradientSumT>
 common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
     EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
     TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
-  dh::XGBDeviceAllocator<char> alloc;
+  dh::XGBCachingDeviceAllocator<char> alloc;
   auto sorted_idx = this->SortedIdx(left);
   dh::Iota(sorted_idx);
-  // sort 2 nodes and all the features at the same time, disregarding colmun sampling.
-  thrust::stable_sort(
-      thrust::cuda::par(alloc), dh::tbegin(sorted_idx), dh::tend(sorted_idx),
-      [evaluator, left, right] XGBOOST_DEVICE(size_t l, size_t r) {
-        auto l_is_left = l < left.feature_values.size();
-        auto r_is_left = r < left.feature_values.size();
-        if (l_is_left != r_is_left) {
-          return l_is_left;  // not the same node
-        }
+  auto data = this->SortInput(left);
+  auto it = thrust::make_counting_iterator(0u);
+  auto d_feature_idx = dh::ToSpan(feature_idx_);
+  thrust::transform(thrust::cuda::par(alloc), it, it + data.size(), dh::tbegin(data),
+                    [=] XGBOOST_DEVICE(uint32_t i) {
+                      auto is_left = i < left.feature_values.size();
+                      auto const &input = is_left ? left : right;
+                      auto j = i - (is_left ? 0 : input.feature_values.size());
+                      auto fidx = d_feature_idx[j];
+                      if (common::IsCat(input.feature_types, fidx)) {
+                        auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[j]);
+                        return thrust::make_tuple(i, lw);
+                      }
+                      return thrust::make_tuple(i, 0.0);
+                    });
+  thrust::stable_sort_by_key(thrust::cuda::par(alloc), dh::tbegin(data), dh::tend(data),
+                             dh::tbegin(sorted_idx),
+                             [=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
+                               auto li = thrust::get<0>(l);
+                               auto ri = thrust::get<0>(r);
+
+                               auto l_is_left = li < left.feature_values.size();
+                               auto r_is_left = ri < left.feature_values.size();
+
+                               if (l_is_left != r_is_left) {
+                                 return l_is_left;  // not the same node
+                               }
+
+                               auto const &input = l_is_left ? left : right;
+                               li -= (l_is_left ? 0 : input.feature_values.size());
+                               ri -= (r_is_left ? 0 : input.feature_values.size());
+
+                               auto lfidx = d_feature_idx[li];
+                               auto rfidx = d_feature_idx[ri];
 
-        auto const &input = l_is_left ? left : right;
-        l -= (l_is_left ? 0 : input.feature_values.size());
-        r -= (r_is_left ? 0 : input.feature_values.size());
+                               if (lfidx != rfidx) {
+                                 return lfidx < rfidx;  // not the same feature
+                               }
 
-        auto lfidx = dh::SegmentId(input.feature_segments, l);
-        auto rfidx = dh::SegmentId(input.feature_segments, r);
-        if (lfidx != rfidx) {
-          return lfidx < rfidx;  // not the same feature
-        }
-        if (common::IsCat(input.feature_types, lfidx)) {
-          auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[l]);
-          auto rw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[r]);
-          return lw < rw;
-        }
-        return l < r;
-      });
+                               if (common::IsCat(input.feature_types, lfidx)) {
+                                 auto lw = thrust::get<1>(l);
+                                 auto rw = thrust::get<1>(r);
+                                 return lw < rw;
+                               }
+                               return li < ri;
+                             });
   return dh::ToSpan(cat_sorted_idx_);
 }
 
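The `feature_idx_` cache built in `Reset` maps every histogram bin to the feature that owns it, so `SortHistogram` can look the feature up with a plain array read instead of calling `dh::SegmentId` on `feature_segments` inside the comparator for every comparison. Below is a minimal host-side C++ sketch of that precomputation, assuming a CSR-style cut-pointer layout in which feature `f` owns bins `[cut_ptrs[f], cut_ptrs[f + 1])`; the function name `BuildFeatureIndexCache` and the sample data are hypothetical.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Map every histogram bin to its owning feature: the host-side analogue of the
// dh::SegmentId results that Reset caches into feature_idx_.  cut_ptrs follows a
// CSR-style layout: feature f owns bins [cut_ptrs[f], cut_ptrs[f + 1]).
std::vector<uint32_t> BuildFeatureIndexCache(std::vector<uint32_t> const &cut_ptrs) {
  std::size_t n_bins = cut_ptrs.empty() ? 0 : cut_ptrs.back();
  std::vector<uint32_t> feature_idx(n_bins);
  for (std::size_t i = 0; i < n_bins; ++i) {
    // The first cut pointer strictly greater than i, minus one, is the owning feature.
    auto it = std::upper_bound(cut_ptrs.cbegin(), cut_ptrs.cend(), i);
    feature_idx[i] = static_cast<uint32_t>(it - cut_ptrs.cbegin() - 1);
  }
  return feature_idx;
}

int main() {
  // Hypothetical cut pointers: feature 0 owns bins 0-2, feature 1 owns bins 3-4.
  std::vector<uint32_t> cut_ptrs = {0, 3, 5};
  for (auto fidx : BuildFeatureIndexCache(cut_ptrs)) {
    std::cout << fidx << " ";  // prints: 0 0 0 1 1
  }
  std::cout << "\n";
  return 0;
}
```

Because the bin-to-feature mapping depends only on the cuts, computing it once per `Reset` and reusing it for every node avoids repeating the binary search during split evaluation.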
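The rewritten `SortHistogram` splits the work in two: a `thrust::transform` fills the cached `SortPair` values (bin index plus the categorical weight from `evaluator.CalcWeightCat`) once per bin, and `thrust::stable_sort_by_key` then orders the bins with a comparator that only reads those cached values. The host-side sketch below approximates the same idea for a single node with categorical features only, using `std::stable_sort` over precomputed arrays rather than a by-key sort on the GPU; the weights are invented stand-ins for `CalcWeightCat` output.

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

int main() {
  // Mirrors the cached SortPair from the patch: (bin index, weight).
  using SortPair = std::tuple<uint32_t, double>;

  // Hypothetical data for one node with two categorical features:
  // feature 0 owns bins 0-2, feature 1 owns bins 3-4.
  std::vector<uint32_t> feature_idx = {0, 0, 0, 1, 1};       // cached owning feature per bin
  std::vector<double> weights = {0.7, -0.2, 0.4, 1.1, 0.3};  // stand-in for CalcWeightCat

  // Step 1: build the cached sort input once (the patch does this with thrust::transform).
  std::vector<SortPair> sort_input(weights.size());
  for (uint32_t i = 0; i < weights.size(); ++i) {
    sort_input[i] = SortPair{i, weights[i]};
  }

  // Step 2: sort bin indices, grouping by feature and ordering by weight within a
  // feature.  The comparator only reads cached values; nothing is recomputed.
  std::vector<uint32_t> sorted_idx(weights.size());
  std::iota(sorted_idx.begin(), sorted_idx.end(), 0u);
  std::stable_sort(sorted_idx.begin(), sorted_idx.end(), [&](uint32_t l, uint32_t r) {
    if (feature_idx[l] != feature_idx[r]) {
      return feature_idx[l] < feature_idx[r];  // not the same feature
    }
    return std::get<1>(sort_input[l]) < std::get<1>(sort_input[r]);
  });

  for (auto i : sorted_idx) {
    std::cout << i << " ";  // prints: 1 2 0 4 3
  }
  std::cout << "\n";
  return 0;
}
```

The payoff of the cached-key scheme is that each weight is computed once per bin in a linear pass instead of being recomputed inside the comparator of an O(n log n) sort.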