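When gaussian_kernel is set to False and the images are on the GPU, the SSIM metric cannot be computed because of a tensor type mismatch: the images are CUDA tensors, but the convolution weight (the filter kernel) is a CPU tensor.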
Code sample
import torch
from torchmetrics import StructuralSimilarityIndexMeasure as SSIM
img = torch.rand(10, 3, 64, 64).to(torch.device("cuda"))
ssim = SSIM(sigma=(1.5, 1.5), kernel_size=(11, 11), gaussian_kernel=False, compute_on_cpu=False, device=torch.device("cuda")).to(torch.device("cuda"))  # force computation on gpu
ssim(img, img)
Expected behavior
Compute StructuralSimilarityIndexMeasure between img and img
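For reference, under the standard SSIM definition (C_1 and C_2 are the usual small stabilizing constants), comparing an image with itself should give exactly 1. Setting x = y makes μ_x = μ_y and σ_xy = σ_x², so the numerator and denominator coincide. This holds whether the local statistics come from a Gaussian or a uniform window. Since TorchMetrics reports the mean of the per-window SSIM values, the expected result here is 1.0 up to floating-point rounding.

```latex
\mathrm{SSIM}(x, y) = \frac{(2\mu_x\mu_y + C_1)(2\sigma_{xy} + C_2)}{(\mu_x^2 + \mu_y^2 + C_1)(\sigma_x^2 + \sigma_y^2 + C_2)}

x = y \implies \mathrm{SSIM}(x, x) = \frac{(2\mu_x^2 + C_1)(2\sigma_x^2 + C_2)}{(2\mu_x^2 + C_1)(2\sigma_x^2 + C_2)} = 1
```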
Environment
TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.9.2
Python & PyTorch Version (e.g., 1.0): 3.9.7
Any other relevant information such as OS (e.g., Linux):
Additional context
RuntimeError Traceback (most recent call last)
<ipython-input-4-4817dbad2239> in <module>
----> 1 ssim(img, img), mean_squared_error(img, img)
~/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
~/anaconda3/lib/python3.8/site-packages/torchmetrics/metric.py in forward(self, *args, **kwargs)
235 self._forward_cache = self._forward_full_state_update(*args, **kwargs)
236 else:
--> 237 self._forward_cache = self._forward_reduce_state_update(*args, **kwargs)
238
239 return self._forward_cache
~/anaconda3/lib/python3.8/site-packages/torchmetrics/metric.py in _forward_reduce_state_update(self, *args, **kwargs)
300 # calculate batch state and compute batch value
301 self.update(*args, **kwargs)
--> 302 batch_val = self.compute()
303
304 # reduce batch and global state
...
--> 160 outputs = F.conv2d(input_list, kernel, groups=channel)
161
162 output_list = outputs.split(preds.shape[0])
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
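F.conv2d requires its input and weight to share the same device and dtype. The error above means the kernel tensor stayed on the CPU while the images were on CUDA. The sketch below shows the general remedy: move the kernel onto the input's device and dtype before the convolution. `align_kernel` is a hypothetical helper name, not a TorchMetrics function. Because a GPU may not be available, the sketch reproduces the analogous dtype mismatch on the CPU instead of a device mismatch.

```python
import torch
import torch.nn.functional as F


def align_kernel(kernel: torch.Tensor, inputs: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `kernel` on the same device and dtype as `inputs`."""
    return kernel.to(device=inputs.device, dtype=inputs.dtype)


channel = 3
# Per-channel 11x11 averaging kernel, created on the CPU with the default float32 dtype.
kernel = torch.ones(channel, 1, 11, 11) / 121.0
# float64 inputs stand in for "a different device" when no GPU is present.
inputs = torch.rand(2, channel, 64, 64, dtype=torch.float64)

try:
    # Weight dtype differs from input dtype, so conv2d refuses to run.
    F.conv2d(inputs, kernel, groups=channel)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# After aligning the kernel with the inputs, the grouped (depthwise) convolution runs.
out = F.conv2d(inputs, align_kernel(kernel, inputs), groups=channel)
print(mismatch_raised, out.shape)
```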
Hi, @KameniAlexNea - Thanks for the issue! I've attempted a fix in #1149 and am waiting for the CI. I have since tested it locally as well. The cause seems to be missing device propagation when gaussian_kernel is False.
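The fix described above amounts to creating the uniform (box) kernel directly on the images' device, and with their dtype, as the Gaussian path presumably already does. Below is a minimal sketch of that idea. It assumes a per-channel kernel of shape `(channel, 1, kh, kw)` used with a grouped `F.conv2d`. `uniform_kernel_2d` and `local_mean` are hypothetical names, and this is not the code from #1149.

```python
import torch
import torch.nn.functional as F


def uniform_kernel_2d(kernel_size, channel, dtype, device):
    """Hypothetical helper: box-filter weights created on the requested device/dtype."""
    kh, kw = kernel_size
    # Each channel gets its own (1, kh, kw) filter whose weights sum to 1.
    return torch.full((channel, 1, kh, kw), 1.0 / (kh * kw), dtype=dtype, device=device)


def local_mean(images, kernel_size=(11, 11)):
    """Hypothetical helper: per-channel local mean, computed wherever `images` lives."""
    channel = images.shape[1]
    # Propagate the input's device and dtype instead of defaulting to CPU float32.
    kernel = uniform_kernel_2d(kernel_size, channel, images.dtype, images.device)
    return F.conv2d(images, kernel, groups=channel)


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
img = torch.rand(10, 3, 64, 64, device=device)
mu = local_mean(img)
print(mu.device, mu.shape)
```

With the kernel built on `images.device`, the same call works unchanged on the CPU and on CUDA.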