Skip to content

Commit

Permalink
fix(deep): torch model outputs need to be transferred to cpu before c…
Browse files Browse the repository at this point in the history
…hecking for additivity #3280 (#3281)
  • Loading branch information
noxthot committed Sep 27, 2023
1 parent 1f05e2c commit e51a87d
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion shap/explainers/_deep/deep_pytorch.py
Expand Up @@ -215,7 +215,7 @@ def shap_values(self, X, ranked_outputs=None, output_rank_order="max", check_add
with torch.no_grad():
model_output_values = self.model(*X)

_check_additivity(self, model_output_values, output_phis)
_check_additivity(self, model_output_values.cpu(), output_phis)

if not self.multi_output:
return output_phis[0]
Expand Down

0 comments on commit e51a87d

Please sign in to comment.