diff --git a/CHANGELOG.md b/CHANGELOG.md index 6eac7de9e84..5f8125863b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125)) +- Fix SSIM propagate device if `gaussian_kernel` is False, add test ([#1149](https://github.com/Lightning-AI/metrics/pull/1149)) + + ## [0.9.2] - 2022-06-29 diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 2c775994018..57636469867 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -150,7 +150,9 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(Tensor(kernel_size)) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod( + torch.tensor(kernel_size, dtype=dtype, device=device) + ) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 1ae4ff683cb..c5f5f4f6f33 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -125,6 +125,23 @@ def test_ssim(self, preds, target, sigma, ddp, dist_sync_on_step): dist_sync_on_step=dist_sync_on_step, ) + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + StructuralSimilarityIndexMeasure, + partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None), + metric_args={ + "gaussian_kernel": False, + "data_range": 1.0, + "sigma": sigma, + }, + dist_sync_on_step=dist_sync_on_step, + ) + def test_ssim_functional(self, preds, target, sigma): self.run_functional_metric_test( preds,