diff --git a/CHANGELOG.md b/CHANGELOG.md index 9b1113b0fa1..18840193b2e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Fixed -- +- Fixed missing reset in `ClasswiseWrapper` ([#1129](https://github.com/Lightning-AI/metrics/pull/1129)) - Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125)) diff --git a/tests/wrappers/test_classwise.py b/tests/wrappers/test_classwise.py index 3d7155eb59d..9e2891b1ad7 100644 --- a/tests/wrappers/test_classwise.py +++ b/tests/wrappers/test_classwise.py @@ -15,27 +15,35 @@ def test_raises_error_on_wrong_input(): def test_output_no_labels(): """Test that wrapper works with no label input.""" + base = Accuracy(num_classes=3, average=None) metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None)) - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - val = metric(preds, target) - assert isinstance(val, dict) - assert len(val) == 3 - for i in range(3): - assert f"accuracy_{i}" in val + for _ in range(2): + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + val = metric(preds, target) + val_base = base(preds, target) + assert isinstance(val, dict) + assert len(val) == 3 + for i in range(3): + assert f"accuracy_{i}" in val + assert val[f"accuracy_{i}"] == val_base[i] def test_output_with_labels(): """Test that wrapper works with label input.""" labels = ["horse", "fish", "cat"] + base = Accuracy(num_classes=3, average=None) metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels) - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - val = metric(preds, target) - assert isinstance(val, dict) - assert len(val) == 3 - for lab in labels: - assert f"accuracy_{lab}" in val + for _ in range(2): + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + val = metric(preds, target) + val_base = base(preds, target) + assert isinstance(val, dict) + assert len(val) == 3 + for i, lab in enumerate(labels): + assert f"accuracy_{lab}" in val + assert val[f"accuracy_{lab}"] == val_base[i] @pytest.mark.parametrize("prefix", [None, "pre_"]) diff --git a/torchmetrics/wrappers/classwise.py b/torchmetrics/wrappers/classwise.py index e652e8c1351..b24c3c9429e 100644 --- a/torchmetrics/wrappers/classwise.py +++ b/torchmetrics/wrappers/classwise.py @@ -51,6 +51,8 @@ class ClasswiseWrapper(Metric): 'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)} """ + full_state_update: Optional[bool] = True + def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None: super().__init__() if not isinstance(metric, Metric): @@ -71,3 +73,6 @@ def update(self, *args: Any, **kwargs: Any) -> None: def compute(self) -> Dict[str, Tensor]: return self._convert(self.metric.compute()) + + def reset(self) -> None: + self.metric.reset()