From cac8bb21238277a0ea2612b70133b6cfceb91037 Mon Sep 17 00:00:00 2001 From: KeVoyer1 <106331090+kevoyer1@users.noreply.github.com> Date: Sun, 10 Jul 2022 13:52:17 +0200 Subject: [PATCH] Fix: JaccardIndex multi-label compute (#1125) (cherry picked from commit 883254e2563c61ada9a1b57aa15414893fc7b05e) --- CHANGELOG.md | 9 ++++++++ torchmetrics/classification/jaccard.py | 29 +++++++++++++++++++------- 2 files changed, 31 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 976534fe6ef..9b1113b0fa1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -38,6 +38,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - +### Fixed + +- + + +- Fixed JaccardIndex multi-label compute ([#1125](https://github.com/Lightning-AI/metrics/pull/1125)) + + + ## [0.9.2] - 2022-06-29 ### Fixed diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index d088f6e0702..354a73c762d 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -104,10 +104,25 @@ def __init__( def compute(self) -> Tensor: """Computes intersection over union (IoU)""" - return _jaccard_from_confmat( - self.confmat, - self.num_classes, - self.average, - self.ignore_index, - self.absent_score, - ) + + if self.multilabel: + return torch.stack( + [ + _jaccard_from_confmat( + confmat, + 2, + self.average, + self.ignore_index, + self.absent_score, + )[1] + for confmat in self.confmat + ] + ) + else: + return _jaccard_from_confmat( + self.confmat, + self.num_classes, + self.average, + self.ignore_index, + self.absent_score, + )