Skip to content

Commit

Permalink
Specify the allocator.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 23, 2021
1 parent 957d778 commit 3050cfa
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gputreeshap
11 changes: 6 additions & 5 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -788,8 +788,9 @@ class GPUPredictor : public xgboost::Predictor {
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
gpu_treeshap::GPUTreeShap(X, device_paths.begin(), device_paths.end(),
ngroup, begin, dh::tend(phis));
gpu_treeshap::GPUTreeShap<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin,
dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
Expand Down Expand Up @@ -844,9 +845,9 @@ class GPUPredictor : public xgboost::Predictor {
SparsePageView X(batch.data.DeviceSpan(), batch.offset.DeviceSpan(),
model.learner_model_param->num_feature);
auto begin = dh::tbegin(phis) + batch.base_rowid * contributions_columns;
gpu_treeshap::GPUTreeShapInteractions(X, device_paths.begin(),
device_paths.end(), ngroup, begin,
dh::tend(phis));
gpu_treeshap::GPUTreeShapInteractions<dh::XGBDeviceAllocator<int>>(
X, device_paths.begin(), device_paths.end(), ngroup, begin,
dh::tend(phis));
}
// Add the base margin term to last column
p_fmat->Info().base_margin_.SetDevice(generic_param_->gpu_id);
Expand Down

0 comments on commit 3050cfa

Please sign in to comment.