Skip to content

Commit

Permalink
Optimize GPU evaluation function for categorical data. (#7705)
Browse files Browse the repository at this point in the history
* Use transform and cache.
  • Loading branch information
trivialfis committed Feb 28, 2022
1 parent 18a4af6 commit 1d468e2
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 30 deletions.
9 changes: 4 additions & 5 deletions demo/guide-python/cat_in_the_dat.py
Expand Up @@ -74,12 +74,12 @@ def categorical_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> None:
X_train, X_test, y_train, y_test = train_test_split(
X, y, random_state=1994, test_size=0.2
)
# Specify `enable_categorical`.
# Specify `enable_categorical` to True.
clf = xgb.XGBClassifier(
**params,
eval_metric="auc",
enable_categorical=True,
max_cat_to_onehot=1, # We use optimal partitioning exclusively
max_cat_to_onehot=1, # We use optimal partitioning exclusively
)
clf.fit(X_train, y_train, eval_set=[(X_test, y_test), (X_train, y_train)])
clf.save_model(os.path.join(output_dir, "categorical.json"))
Expand All @@ -94,13 +94,12 @@ def onehot_encoding_model(X: pd.DataFrame, y: pd.Series, output_dir: str) -> Non
X_train, X_test, y_train, y_test = train_test_split(
X, y, random_state=42, test_size=0.2
)
# Specify `enable_categorical`.
clf = xgb.XGBClassifier(**params, enable_categorical=False)
# Specify `enable_categorical` to False as we are using encoded data.
clf = xgb.XGBClassifier(**params, eval_metric="auc", enable_categorical=False)
clf.fit(
X_train,
y_train,
eval_set=[(X_test, y_test), (X_train, y_train)],
eval_metric="auc",
)
clf.save_model(os.path.join(output_dir, "one-hot.json"))

Expand Down
13 changes: 13 additions & 0 deletions src/tree/gpu_hist/evaluate_splits.cuh
Expand Up @@ -51,6 +51,12 @@ class GPUHistEvaluator {
dh::CUDAStream copy_stream_;
// storage for sorted index of feature histogram, used for sort based splits.
dh::device_vector<bst_feature_t> cat_sorted_idx_;
// cached input for sorting the histogram, used for sort based splits.
using SortPair = thrust::tuple<uint32_t, double>;
dh::device_vector<SortPair> sort_input_;
// cache for feature index
dh::device_vector<bst_feature_t> feature_idx_;
// Training param used for evaluation
TrainParam param_;
// whether the input data requires sort based split, which is more complicated so we try
// to avoid it if possible.
Expand Down Expand Up @@ -95,6 +101,13 @@ class GPUHistEvaluator {
return dh::ToSpan(cat_sorted_idx_);
}

auto SortInput(EvaluateSplitInputs<GradientSumT> left) {
if (left.nidx == RegTree::kRoot && !cat_sorted_idx_.empty()) {
return dh::ToSpan(sort_input_).first(left.feature_values.size());
}
return dh::ToSpan(sort_input_);
}

public:
GPUHistEvaluator(TrainParam const &param, bst_feature_t n_features, int32_t device)
: tree_evaluator_{param, n_features, device}, param_{param} {}
Expand Down
85 changes: 60 additions & 25 deletions src/tree/gpu_hist/evaluator.cu
Expand Up @@ -54,6 +54,21 @@ void GPUHistEvaluator<GradientSumT>::Reset(common::HistogramCuts const &cuts,
cudaMemsetAsync(split_cats_.data().get(), '\0', split_cats_.size() * sizeof(CatST)));

cat_sorted_idx_.resize(cuts.cut_values_.Size() * 2); // evaluate 2 nodes at a time.
sort_input_.resize(cat_sorted_idx_.size());

/**
* cache feature index binary search result
*/
feature_idx_.resize(cat_sorted_idx_.size());
auto d_fidxes = dh::ToSpan(feature_idx_);
auto it = thrust::make_counting_iterator(0ul);
auto values = cuts.cut_values_.ConstDeviceSpan();
auto ptrs = cuts.cut_ptrs_.ConstDeviceSpan();
thrust::transform(thrust::cuda::par(alloc), it, it + feature_idx_.size(),
feature_idx_.begin(), [=] XGBOOST_DEVICE(size_t i) {
auto fidx = dh::SegmentId(ptrs, i);
return fidx;
});
}
}
}
Expand All @@ -62,35 +77,55 @@ template <typename GradientSumT>
common::Span<bst_feature_t const> GPUHistEvaluator<GradientSumT>::SortHistogram(
EvaluateSplitInputs<GradientSumT> const &left, EvaluateSplitInputs<GradientSumT> const &right,
TreeEvaluator::SplitEvaluator<GPUTrainingParam> evaluator) {
dh::XGBDeviceAllocator<char> alloc;
dh::XGBCachingDeviceAllocator<char> alloc;
auto sorted_idx = this->SortedIdx(left);
dh::Iota(sorted_idx);
// sort 2 nodes and all the features at the same time, disregarding colmun sampling.
thrust::stable_sort(
thrust::cuda::par(alloc), dh::tbegin(sorted_idx), dh::tend(sorted_idx),
[evaluator, left, right] XGBOOST_DEVICE(size_t l, size_t r) {
auto l_is_left = l < left.feature_values.size();
auto r_is_left = r < left.feature_values.size();
if (l_is_left != r_is_left) {
return l_is_left; // not the same node
}
auto data = this->SortInput(left);
auto it = thrust::make_counting_iterator(0u);
auto d_feature_idx = dh::ToSpan(feature_idx_);
thrust::transform(thrust::cuda::par(alloc), it, it + data.size(), dh::tbegin(data),
[=] XGBOOST_DEVICE(uint32_t i) {
auto is_left = i < left.feature_values.size();
auto const &input = is_left ? left : right;
auto j = i - (is_left ? 0 : input.feature_values.size());
auto fidx = d_feature_idx[j];
if (common::IsCat(input.feature_types, fidx)) {
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[j]);
return thrust::make_tuple(i, lw);
}
return thrust::make_tuple(i, 0.0);
});
thrust::stable_sort_by_key(thrust::cuda::par(alloc), dh::tbegin(data), dh::tend(data),
dh::tbegin(sorted_idx),
[=] XGBOOST_DEVICE(SortPair const &l, SortPair const &r) {
auto li = thrust::get<0>(l);
auto ri = thrust::get<0>(r);

auto l_is_left = li < left.feature_values.size();
auto r_is_left = ri < left.feature_values.size();

if (l_is_left != r_is_left) {
return l_is_left; // not the same node
}

auto const &input = l_is_left ? left : right;
li -= (l_is_left ? 0 : input.feature_values.size());
ri -= (r_is_left ? 0 : input.feature_values.size());

auto lfidx = d_feature_idx[li];
auto rfidx = d_feature_idx[ri];

auto const &input = l_is_left ? left : right;
l -= (l_is_left ? 0 : input.feature_values.size());
r -= (r_is_left ? 0 : input.feature_values.size());
if (lfidx != rfidx) {
return lfidx < rfidx; // not the same feature
}

auto lfidx = dh::SegmentId(input.feature_segments, l);
auto rfidx = dh::SegmentId(input.feature_segments, r);
if (lfidx != rfidx) {
return lfidx < rfidx; // not the same feature
}
if (common::IsCat(input.feature_types, lfidx)) {
auto lw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[l]);
auto rw = evaluator.CalcWeightCat(input.param, input.gradient_histogram[r]);
return lw < rw;
}
return l < r;
});
if (common::IsCat(input.feature_types, lfidx)) {
auto lw = thrust::get<1>(l);
auto rw = thrust::get<1>(r);
return lw < rw;
}
return li < ri;
});
return dh::ToSpan(cat_sorted_idx_);
}

Expand Down

0 comments on commit 1d468e2

Please sign in to comment.