diff --git a/CHANGELOG.md b/CHANGELOG.md index 93316f2ec29..6fcd68d4b8d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,7 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed aggregation metrics when input only contains zero ([#1070](https://github.com/PyTorchLightning/metrics/pull/1070)) - diff --git a/tests/bases/test_aggregation.py b/tests/bases/test_aggregation.py index 106621e9cb4..4a3757bcc61 100644 --- a/tests/bases/test_aggregation.py +++ b/tests/bases/test_aggregation.py @@ -137,6 +137,7 @@ def test_nan_error(value, nan_strategy, metric_class): (CatMetric, 2.0, _case1, torch.tensor([2.0, 2.0, 2.0, 2.0, 2.0])), (CatMetric, "ignore", _case2, torch.tensor([1.0, 2.0, 4.0, 5.0])), (CatMetric, 2.0, _case2, torch.tensor([1.0, 2.0, 2.0, 4.0, 5.0])), + (CatMetric, "ignore", torch.zeros(5), torch.zeros(5)), ], ) def test_nan_expected(metric_class, nan_strategy, value, expected): diff --git a/torchmetrics/aggregation.py b/torchmetrics/aggregation.py index 0833646d4eb..ac06452da57 100644 --- a/torchmetrics/aggregation.py +++ b/torchmetrics/aggregation.py @@ -43,6 +43,7 @@ class BaseAggregator(Metric): value: Tensor is_differentiable = None higher_is_better = None + full_state_update = False def __init__( self, @@ -116,6 +117,8 @@ class MaxMetric(BaseAggregator): tensor(3.) """ + full_state_update = True + def __init__( self, nan_strategy: Union[str, float] = "warn", @@ -136,7 +139,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore dimensions will be flattened """ value = self._cast_and_nan_check_input(value) - if any(value.flatten()): # make sure tensor not empty + if value.numel(): # make sure tensor not empty self.value = torch.max(self.value, torch.max(value)) @@ -165,6 +168,8 @@ class MinMetric(BaseAggregator): tensor(1.) """ + full_state_update = True + def __init__( self, nan_strategy: Union[str, float] = "warn", @@ -185,7 +190,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore dimensions will be flattened """ value = self._cast_and_nan_check_input(value) - if any(value.flatten()): # make sure tensor not empty + if value.numel(): # make sure tensor not empty self.value = torch.min(self.value, torch.min(value)) @@ -234,7 +239,8 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore dimensions will be flattened """ value = self._cast_and_nan_check_input(value) - self.value += value.sum() + if value.numel(): + self.value += value.sum() class CatMetric(BaseAggregator): @@ -277,7 +283,7 @@ def update(self, value: Union[float, Tensor]) -> None: # type: ignore dimensions will be flattened """ value = self._cast_and_nan_check_input(value) - if any(value.flatten()): + if value.numel(): self.value.append(value) def compute(self) -> Tensor: @@ -339,14 +345,16 @@ def update(self, value: Union[float, Tensor], weight: Union[float, Tensor] = 1.0 value = self._cast_and_nan_check_input(value) weight = self._cast_and_nan_check_input(weight) - # broadcast weight to values shape - if not hasattr(torch, "broadcast_to"): + if value.numel() == 0: + return + # broadcast weight to value shape + if hasattr(torch, "broadcast_to"): + weight = torch.broadcast_to(weight, value.shape) + else: if weight.shape == (): weight = torch.ones_like(value) * weight if weight.shape != value.shape: raise ValueError("Broadcasting not supported on PyTorch <1.8") - else: - weight = torch.broadcast_to(weight, value.shape) self.value += (value * weight).sum() self.weight += weight.sum()