From 44becaa313cbc4f81702571a9d99b6130cb1d11e Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Mon, 6 Jun 2022 11:16:01 +0200
Subject: [PATCH 1/4] update

---
 tests/bases/test_aggregation.py |  1 +
 torchmetrics/aggregation.py     | 37 ++++++++++++++++++++-------------
 2 files changed, 23 insertions(+), 15 deletions(-)

diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py
index 106621e9cb4..4a3757bcc61 100644
--- a/tests/bases/test_aggregation.py
+++ b/tests/bases/test_aggregation.py
@@ -137,6 +137,7 @@ def test_nan_error(value, nan_strategy, metric_class):
         (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])),
         (CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])),
         (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])),
+        (CatMetric, "ignore", torch.zeros(5), torch.zeros(5)),
     ],
 )
 def test_nan_expected(metric_class, nan_strategy, value, expected):
diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py
index 0833646d4eb..8c7f92b2ed4 100644
--- a/torchmetrics/aggregation.py
+++ b/torchmetrics/aggregation.py
@@ -43,6 +43,7 @@ class BaseAggregator(Metric):
     value: Tensor
     is_differentiable = None
     higher_is_better = None
+    full_state_update = False
 
     def __init__(
         self,
@@ -116,6 +117,8 @@ class MaxMetric(BaseAggregator):
         tensor(3.)
     """
 
+    full_state_update = True
+
     def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
@@ -136,7 +139,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if any(value.flatten()):  # make sure tensor not empty
+        if value.numel() != 0:  # make sure tensor not empty
             self.value = torch.max(self.value, torch.max(value))
 
 
@@ -165,6 +168,8 @@ class MinMetric(BaseAggregator):
         tensor(1.)
     """
 
+    full_state_update = True
+
     def __init__(
         self,
         nan_strategy: Union[str, float] = "warn",
@@ -185,7 +190,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if any(value.flatten()):  # make sure tensor not empty
+        if value.numel() != 0:  # make sure tensor not empty
             self.value = torch.min(self.value, torch.min(value))
 
 
@@ -234,7 +239,8 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        self.value += value.sum()
+        if value.numel() != 0:
+            self.value += value.sum()
 
 
 class CatMetric(BaseAggregator):
@@ -277,7 +283,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if any(value.flatten()):
+        if value.numel() != 0:
             self.value.append(value)
 
     def compute(self) -> Tensor:
@@ -339,17 +345,18 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
         value = self._cast_and_nan_check_input(value)
         weight = self._cast_and_nan_check_input(weight)
 
-        # broadcast weight to values shape
-        if not hasattr(torch, "broadcast_to"):
-            if weight.shape == ():
-                weight = torch.ones_like(value) * weight
-            if weight.shape != value.shape:
-                raise ValueError("Broadcasting not supported on PyTorch <1.8")
-        else:
-            weight = torch.broadcast_to(weight, value.shape)
-
-        self.value += (value * weight).sum()
-        self.weight += weight.sum()
+        if value.numel() != 0:
+            # broadcast weight to values shape
+            if not hasattr(torch, "broadcast_to"):
+                if weight.shape == ():
+                    weight = torch.ones_like(value) * weight
+                if weight.shape != value.shape:
+                    raise ValueError("Broadcasting not supported on PyTorch <1.8")
+            else:
+                weight = torch.broadcast_to(weight, value.shape)
+
+            self.value += (value * weight).sum()
+            self.weight += weight.sum()
 
     def compute(self) -> Tensor:
         """Compute the aggregated value."""

From a6d8edd6d5e2d4377d5441aca98d07adc97b1b7f Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Mon, 6 Jun 2022 11:19:36 +0200
Subject: [PATCH 2/4] changelog

---
 CHANGELOG.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ee1bc72d827..53d07d5588d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
--
+- Fixed aggregation metrics when input only contains zero ([#1070](https://github.com/PyTorchLightning/metrics/pull/1070))
 
 -
From 1164ac255f99202dc13a72668ca099d4e7a350ca Mon Sep 17 00:00:00 2001
From: Jirka
Date: Mon, 6 Jun 2022 14:26:58 +0200
Subject: [PATCH 3/4] simple

---
 torchmetrics/aggregation.py | 25 +++++++++++++------------
 1 file changed, 13 insertions(+), 12 deletions(-)

diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py
index 8c7f92b2ed4..0e3d782f886 100644
--- a/torchmetrics/aggregation.py
+++ b/torchmetrics/aggregation.py
@@ -345,18 +345,19 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0
         value = self._cast_and_nan_check_input(value)
         weight = self._cast_and_nan_check_input(weight)
 
-        if value.numel() != 0:
-            # broadcast weight to values shape
-            if not hasattr(torch, "broadcast_to"):
-                if weight.shape == ():
-                    weight = torch.ones_like(value) * weight
-                if weight.shape != value.shape:
-                    raise ValueError("Broadcasting not supported on PyTorch <1.8")
-            else:
-                weight = torch.broadcast_to(weight, value.shape)
-
-            self.value += (value * weight).sum()
-            self.weight += weight.sum()
+        if value.numel() == 0:
+            return
+        # broadcast weight to value shape
+        if hasattr(torch, "broadcast_to"):
+            weight = torch.broadcast_to(weight, value.shape)
+        else:
+            if weight.shape == ():
+                weight = torch.ones_like(value) * weight
+            if weight.shape != value.shape:
+                raise ValueError("Broadcasting not supported on PyTorch <1.8")
+
+        self.value += (value * weight).sum()
+        self.weight += weight.sum()
 
     def compute(self) -> Tensor:
         """Compute the aggregated value."""

From 11db4f34eb101635c7ae72c636ac21ef4fcd12a8 Mon Sep 17 00:00:00 2001
From: Jirka Borovec
Date: Tue, 7 Jun 2022 12:26:07 +0200
Subject: [PATCH 4/4] Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
---
 torchmetrics/aggregation.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py
index 0e3d782f886..ac06452da57 100644
--- a/torchmetrics/aggregation.py
+++ b/torchmetrics/aggregation.py
@@ -139,7 +139,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if value.numel() != 0:  # make sure tensor not empty
+        if value.numel():  # make sure tensor not empty
             self.value = torch.max(self.value, torch.max(value))
 
 
@@ -190,7 +190,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if value.numel() != 0:  # make sure tensor not empty
+        if value.numel():  # make sure tensor not empty
             self.value = torch.min(self.value, torch.min(value))
 
 
@@ -239,7 +239,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if value.numel() != 0:
+        if value.numel():
             self.value += value.sum()
 
 
@@ -283,7 +283,7 @@ def update(self, value: Union[float, Tensor]) -> None:  # type: ignore
                 dimensions will be flattened
         """
         value = self._cast_and_nan_check_input(value)
-        if value.numel() != 0:
+        if value.numel():
             self.value.append(value)
 
     def compute(self) -> Tensor:
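With the whole series applied, an all-zero tensor is aggregated like any other input, while empty input is still ignored: PATCH 3/4 turns the guard in `MeanMetric.update` into an early `return`, and PATCH 4/4 shortens the remaining checks to the truthiness of `value.numel()`, an `int` that is falsy only for zero elements. A short usage sketch of the fixed behaviour; the expected outputs follow from the new test case and the metric definitions, assuming a torchmetrics build that includes these patches:

    import torch
    from torchmetrics import CatMetric, MaxMetric, MeanMetric

    # All-zero updates are no longer mistaken for empty ones.
    cat = CatMetric(nan_strategy="ignore")
    cat.update(torch.zeros(5))
    print(cat.compute())      # tensor([0., 0., 0., 0., 0.])

    maximum = MaxMetric()
    maximum.update(torch.zeros(3))
    print(maximum.compute())  # tensor(0.)

    # The weight is still broadcast to the value shape before accumulation.
    mean = MeanMetric()
    mean.update(torch.tensor([0.0, 0.0]), weight=2.0)
    print(mean.compute())     # tensor(0.)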