Skip to content

Commit

Permalink
Simplify.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed May 27, 2021
1 parent fbd852b commit dab1750
Showing 1 changed file with 14 additions and 24 deletions.
38 changes: 14 additions & 24 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -529,32 +529,22 @@ class GPUPredictor : public xgboost::Predictor {
size_t entry_start = 0;
SparsePageView data(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
num_features);
auto const kernel = [&](auto predict_fn) {
predict_fn(data, model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(),
model.categories.ConstDeviceSpan(), model.tree_beg_,
model.tree_end_, num_features, num_rows, entry_start,
use_shared, model.num_group, nan(""));
};
if (is_dense) {
dh::LaunchKernel{GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}(
PredictKernel<SparsePageLoader, SparsePageView, false>, data,
model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(),
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
num_features, num_rows, entry_start, use_shared, model.num_group,
nan(""));
kernel(PredictKernel<SparsePageLoader, SparsePageView, false>);
} else {
dh::LaunchKernel{GRID_SIZE, BLOCK_THREADS, shared_memory_bytes}(
PredictKernel<SparsePageLoader, SparsePageView, true>, data,
model.nodes.ConstDeviceSpan(),
predictions->DeviceSpan().subspan(batch_offset),
model.tree_segments.ConstDeviceSpan(),
model.tree_group.ConstDeviceSpan(),
model.split_types.ConstDeviceSpan(),
model.categories_tree_segments.ConstDeviceSpan(),
model.categories_node_segments.ConstDeviceSpan(),
model.categories.ConstDeviceSpan(), model.tree_beg_, model.tree_end_,
num_features, num_rows, entry_start, use_shared, model.num_group,
nan(""));
kernel(PredictKernel<SparsePageLoader, SparsePageView, true>);
}
}
void PredictInternal(EllpackDeviceAccessor const& batch,
Expand Down

0 comments on commit dab1750

Please sign in to comment.