Skip to content

Commit

Permalink
Fix missing reset in classwise wrapper (#1129)
Browse files Browse the repository at this point in the history
* missing reset
* changelog

(cherry picked from commit 412002c)
  • Loading branch information
SkafteNicki authored and Borda committed Jul 22, 2022
1 parent cac8bb2 commit c9faed4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 15 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -40,7 +40,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

-
- Fixed missing reset in `ClasswiseWrapper` ([#1129](https://github.com/Lightning-AI/metrics/pull/1129))


- Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125))
Expand Down
36 changes: 22 additions & 14 deletions tests/wrappers/test_classwise.py
Expand Up @@ -15,27 +15,35 @@ def test_raises_error_on_wrong_input():

def test_output_no_labels():
"""Test that wrapper works with no label input."""
base = Accuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None))
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i in range(3):
assert f"accuracy_{i}" in val
for _ in range(2):
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i in range(3):
assert f"accuracy_{i}" in val
assert val[f"accuracy_{i}"] == val_base[i]


def test_output_with_labels():
"""Test that wrapper works with label input."""
labels = ["horse", "fish", "cat"]
base = Accuracy(num_classes=3, average=None)
metric = ClasswiseWrapper(Accuracy(num_classes=3, average=None), labels=labels)
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for lab in labels:
assert f"accuracy_{lab}" in val
for _ in range(2):
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
val = metric(preds, target)
val_base = base(preds, target)
assert isinstance(val, dict)
assert len(val) == 3
for i, lab in enumerate(labels):
assert f"accuracy_{lab}" in val
assert val[f"accuracy_{lab}"] == val_base[i]


@pytest.mark.parametrize("prefix", [None, "pre_"])
Expand Down
5 changes: 5 additions & 0 deletions torchmetrics/wrappers/classwise.py
Expand Up @@ -51,6 +51,8 @@ class ClasswiseWrapper(Metric):
'recall_horse': tensor(0.), 'recall_fish': tensor(0.3333), 'recall_dog': tensor(0.4000)}
"""

full_state_update: Optional[bool] = True

def __init__(self, metric: Metric, labels: Optional[List[str]] = None) -> None:
super().__init__()
if not isinstance(metric, Metric):
Expand All @@ -71,3 +73,6 @@ def update(self, *args: Any, **kwargs: Any) -> None:

def compute(self) -> Dict[str, Tensor]:
return self._convert(self.metric.compute())

def reset(self) -> None:
self.metric.reset()

0 comments on commit c9faed4

Please sign in to comment.