diff --git a/src/predictor/gpu_predictor.cu b/src/predictor/gpu_predictor.cu index dad92e630425..86352487f805 100644 --- a/src/predictor/gpu_predictor.cu +++ b/src/predictor/gpu_predictor.cu @@ -529,32 +529,22 @@ class GPUPredictor : public xgboost::Predictor { size_t entry_start = 0; SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(), num_features); + auto const kernel = [&](auto predict_fn) { + predict_fn(data, model.nodes.ConstDeviceSpan(), + predictions->DeviceSpan().subspan(batch_offset), + model.tree_segments.ConstDeviceSpan(), + model.tree_group.ConstDeviceSpan(), + model.split_types.ConstDeviceSpan(), + model.categories_tree_segments.ConstDeviceSpan(), + model.categories_node_segments.ConstDeviceSpan(), + model.categories.ConstDeviceSpan(), model.tree_beg_, + model.tree_end_, num_features, num_rows, entry_start, + use_shared, model.num_group, nan("")); + }; if (is_dense) { - dh::LaunchKernel{GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}( - PredictKernel, data, - model.nodes.ConstDeviceSpan(), - predictions->DeviceSpan().subspan(batch_offset), - model.tree_segments.ConstDeviceSpan(), - model.tree_group.ConstDeviceSpan(), - model.split_types.ConstDeviceSpan(), - model.categories_tree_segments.ConstDeviceSpan(), - model.categories_node_segments.ConstDeviceSpan(), - model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_, - num_features, num_rows, entry_start, use_shared, model.num_group, - nan("")); + kernel(PredictKernel); } else { - dh::LaunchKernel{GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}( - PredictKernel, data, - model.nodes.ConstDeviceSpan(), - predictions->DeviceSpan().subspan(batch_offset), - model.tree_segments.ConstDeviceSpan(), - model.tree_group.ConstDeviceSpan(), - model.split_types.ConstDeviceSpan(), - model.categories_tree_segments.ConstDeviceSpan(), - model.categories_node_segments.ConstDeviceSpan(), - model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_, - num_features, num_rows, entry_start, use_shared, model.num_group, - nan("")); + kernel(PredictKernel); } } void PredictInternal(EllpackDeviceAccessor const& batch,