Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: (SSIM) propagate device if gaussian_kernel is False, add test #1149

Merged
merged 10 commits into from Jul 19, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125))


- Fix SSIM propagate device if `gaussian_kernel` is False, add test ([#1149](https://github.com/Lightning-AI/metrics/pull/1149))



## [0.9.2] - 2022-06-29

Expand Down
4 changes: 3 additions & 1 deletion src/torchmetrics/functional/image/ssim.py
Expand Up @@ -150,7 +150,9 @@ def _ssim_compute(
kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device)

if not gaussian_kernel:
kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(Tensor(kernel_size))
kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(
torch.tensor(kernel_size, dtype=dtype, device=device)
)
Borda marked this conversation as resolved.
Show resolved Hide resolved

input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)

Expand Down
17 changes: 17 additions & 0 deletions tests/unittests/image/test_ssim.py
Expand Up @@ -125,6 +125,23 @@ def test_ssim(self, preds, target, sigma, ddp, dist_sync_on_step):
dist_sync_on_step=dist_sync_on_step,
)

@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("dist_sync_on_step", [True, False])
def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp, dist_sync_on_step):
self.run_class_metric_test(
ddp,
preds,
target,
StructuralSimilarityIndexMeasure,
partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None),
metric_args={
"gaussian_kernel": False,
"data_range": 1.0,
"sigma": sigma,
},
dist_sync_on_step=dist_sync_on_step,
)

def test_ssim_functional(self, preds, target, sigma):
self.run_functional_metric_test(
preds,
Expand Down