From a5498e8843c7f2f0639578b88b600d52a29c7c62 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Sat, 16 Jul 2022 20:16:41 +0530 Subject: [PATCH 01/10] Fix: propagate device if gaussian_kernel is False, add test --- src/torchmetrics/functional/image/ssim.py | 2 +- tests/unittests/image/test_ssim.py | 18 ++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 2c775994018..c10afe52d7d 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -150,7 +150,7 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((1, 1, *kernel_size)) / torch.prod(Tensor(kernel_size)) + kernel = torch.ones((1, 1, *kernel_size), device=device) / torch.prod(Tensor(kernel_size, device=device)) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index 1ae4ff683cb..cfe771003f7 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -125,6 +125,24 @@ def test_ssim(self, preds, target, sigma, ddp, dist_sync_on_step): dist_sync_on_step=dist_sync_on_step, ) + @pytest.mark.parametrize("ddp", [True, False]) + @pytest.mark.parametrize("dist_sync_on_step", [True, False]) + def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp, dist_sync_on_step): + self.run_class_metric_test( + ddp, + preds, + target, + StructuralSimilarityIndexMeasure, + partial(_sk_ssim, data_range=1.0, sigma=sigma, kernel_size=None), + metric_args={ + "gaussian_kernel": False, + "data_range": 1.0, + "sigma": sigma, + }, + dist_sync_on_step=dist_sync_on_step, + ) + + def test_ssim_functional(self, preds, target, sigma): self.run_functional_metric_test( preds, From 6689814a8e2848d6234fea49a35d52eed962ac07 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 16 Jul 2022 14:49:16 +0000 Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/unittests/image/test_ssim.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/unittests/image/test_ssim.py b/tests/unittests/image/test_ssim.py index cfe771003f7..c5f5f4f6f33 100644 --- a/tests/unittests/image/test_ssim.py +++ b/tests/unittests/image/test_ssim.py @@ -142,7 +142,6 @@ def test_ssim_without_gaussian_kernel(self, preds, target, sigma, ddp, dist_sync dist_sync_on_step=dist_sync_on_step, ) - def test_ssim_functional(self, preds, target, sigma): self.run_functional_metric_test( preds, From d222dab1edfb6e5d3694e7937541cb011016a52b Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Sat, 16 Jul 2022 20:40:31 +0530 Subject: [PATCH 03/10] propagate dtype and channel if not gaussian_kernel --- src/torchmetrics/functional/image/ssim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index c10afe52d7d..ca8ca60c6a2 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -150,7 +150,7 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((1, 1, *kernel_size), device=device) / torch.prod(Tensor(kernel_size, device=device)) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(torch.tensor(kernel_size, dtype=dtype, device=device)) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) From 186193f740c4202370c36d29da6e0674b32feeae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 16 Jul 2022 15:12:58 +0000 Subject: [PATCH 04/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/ssim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index ca8ca60c6a2..57636469867 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -150,7 +150,9 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(torch.tensor(kernel_size, dtype=dtype, device=device)) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod( + torch.tensor(kernel_size, dtype=dtype, device=device) + ) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) From e9ceac831223e62902acc0225af85d1b0417e744 Mon Sep 17 00:00:00 2001 From: Jirka Date: Mon, 18 Jul 2022 13:12:01 +0200 Subject: [PATCH 05/10] chlog --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6eac7de9e84..5f8125863b0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -50,6 +50,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125)) +- Fix SSIM propagate device if `gaussian_kernel` is False, add test ([#1149](https://github.com/Lightning-AI/metrics/pull/1149)) + + ## [0.9.2] - 2022-06-29 From 77a3a9372cfa703700d1e574c6db7aca48a0e1e5 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 18 Jul 2022 14:31:21 +0200 Subject: [PATCH 06/10] Apply suggestions from code review Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com> --- src/torchmetrics/functional/image/ssim.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 57636469867..3172cf953be 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -150,9 +150,7 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod( - torch.tensor(kernel_size, dtype=dtype, device=device) - ) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / math.prod(kernel_size) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) From 244a62de8f1e7c3956f3776d805f3fe13d5cb404 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 18 Jul 2022 14:31:49 +0200 Subject: [PATCH 07/10] math --- src/torchmetrics/functional/image/ssim.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 3172cf953be..5449ad1efe6 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -13,6 +13,7 @@ # limitations under the License. from typing import List, Optional, Sequence, Tuple, Union +import math import torch from torch import Tensor from torch.nn import functional as F From aa5d804d1e578c3d928f91dfd6a2f24e3a1bd45f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Jul 2022 12:32:27 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/ssim.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index 5449ad1efe6..b8c8e8f2b54 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -11,9 +11,9 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import List, Optional, Sequence, Tuple, Union -import math import torch from torch import Tensor from torch.nn import functional as F From 84f347d17ca12d8ae82ab4037c198e9440f4eff2 Mon Sep 17 00:00:00 2001 From: Kushashwa Ravi Shrimali Date: Mon, 18 Jul 2022 22:36:42 +0530 Subject: [PATCH 09/10] Use torch.prod --- src/torchmetrics/functional/image/ssim.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index b8c8e8f2b54..ca8ca60c6a2 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import math from typing import List, Optional, Sequence, Tuple, Union import torch @@ -151,7 +150,7 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / math.prod(kernel_size) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(torch.tensor(kernel_size, dtype=dtype, device=device)) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W) From 2d6cb05c4fbf290b5b899da97ea7d9409c3d2f4a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 18 Jul 2022 17:07:19 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/torchmetrics/functional/image/ssim.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchmetrics/functional/image/ssim.py b/src/torchmetrics/functional/image/ssim.py index ca8ca60c6a2..57636469867 100644 --- a/src/torchmetrics/functional/image/ssim.py +++ b/src/torchmetrics/functional/image/ssim.py @@ -150,7 +150,9 @@ def _ssim_compute( kernel = _gaussian_kernel_2d(channel, gauss_kernel_size, sigma, dtype, device) if not gaussian_kernel: - kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod(torch.tensor(kernel_size, dtype=dtype, device=device)) + kernel = torch.ones((channel, 1, *kernel_size), dtype=dtype, device=device) / torch.prod( + torch.tensor(kernel_size, dtype=dtype, device=device) + ) input_list = torch.cat((preds, target, preds * preds, target * target, preds * target)) # (5 * B, C, H, W)