diff --git a/CHANGELOG.md b/CHANGELOG.md index 18840193b2e..894903f48a8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125)) +- Fix SSIM propagate device if `gaussian_kernel` is False, add test ([#1149](https://github.com/Lightning-AI/metrics/pull/1149)) + + ## [0.9.2] - 2022-06-29 diff --git a/tests/image/test_ssim.py b/tests/image/test_ssim.py index e868fe96570..daed5445a30 100644 --- a/tests/image/test_ssim.py +++ b/tests/image/test_ssim.py @@ -125,6 +125,23 @@ def test_ssim(self, preds, target, sigma, ddp, dist_sync_on_step): dist_sync_on_step=dist_sync_on_step, ) + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + StructuralSimilarityIndexMeasure, + partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None), + metric_args={ + "gaussian_kernel": False, + "data_range": 1.0, + "sigma": sigma, + }, + dist_sync_on_step=dist_sync_on_step, + ) + def test_ssim_functional(self, preds, target, sigma): self.run_functional_metric_test( preds, diff --git a/torchmetrics/functional/image/ssim.py b/torchmetrics/functional/image/ssim.py index 2c775994018..57636469867 100644 --- a/torchmetrics/functional/image/ssim.py +++ b/torchmetrics/functional/image/ssim.py @@ -150,7 +150,9 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(Tensor(kernel_size)) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod( + torch.tensor(kernel_size, dtype=dtype, device=device) + ) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)