Skip to content

Commit

Permalink
Check tree weights.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 25, 2021
1 parent 41d5106 commit 5af6c2c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 6 deletions.
2 changes: 1 addition & 1 deletion gputreeshap
19 changes: 14 additions & 5 deletions src/predictor/gpu_predictor.cu
Expand Up @@ -814,11 +814,16 @@ class GPUPredictor : public xgboost::Predictor {
void PredictContribution(DMatrix* p_fmat,
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model, unsigned tree_end,
std::vector<bst_float> const*,
std::vector<bst_float> const* tree_weights,
bool approximate, int,
unsigned) const override {
std::string not_implemented{"contribution is not implemented in GPU "
"predictor, use `cpu_predictor` instead."};
if (approximate) {
LOG(FATAL) << "Approximated contribution is not implemented in GPU Predictor.";
LOG(FATAL) << "Approximated " << not_implemented;
}
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
Expand Down Expand Up @@ -869,11 +874,15 @@ class GPUPredictor : public xgboost::Predictor {
HostDeviceVector<bst_float>* out_contribs,
const gbm::GBTreeModel& model,
unsigned tree_end,
std::vector<bst_float> const*,
std::vector<bst_float> const* tree_weights,
bool approximate) const override {
std::string not_implemented{"contribution is not implemented in GPU "
"predictor, use `cpu_predictor` instead."};
if (approximate) {
LOG(FATAL) << "[Internal error]: " << __func__
<< " approximate is not implemented in GPU Predictor.";
LOG(FATAL) << "Approximated " << not_implemented;
}
if (tree_weights != nullptr) {
LOG(FATAL) << "Dart booster feature " << not_implemented;
}
dh::safe_cuda(cudaSetDevice(generic_param_->gpu_id));
out_contribs->SetDevice(generic_param_->gpu_id);
Expand Down

0 comments on commit 5af6c2c

Please sign in to comment.