From bdb3e59ca7d9c737ff0d561bcec1dae239b79f8d Mon Sep 17 00:00:00 2001 From: Nicki Skafte Detlefsen Date: Tue, 28 Jun 2022 18:00:41 +0200 Subject: [PATCH] Add "none" option to AveragePrecision (#1116) --- CHANGELOG.md | 3 +++ .../functional/classification/average_precision.py | 4 ++-- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c8f9fad25a9..93fc6fdf72a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Skip box conversion if no boxes are present in `MeanAveragePrecision` ([#1097](https://github.com/Lightning-AI/metrics/pull/1097)) +- Fixed inconsistency in docs and code when setting `average="none"` in `AvaragePrecision` metric ([#1116](https://github.com/Lightning-AI/metrics/pull/1116)) + + ## [0.9.1] - 2022-06-08 ### Added diff --git a/src/torchmetrics/functional/classification/average_precision.py b/src/torchmetrics/functional/classification/average_precision.py index 7228d70f523..26e47a0633b 100644 --- a/src/torchmetrics/functional/classification/average_precision.py +++ b/src/torchmetrics/functional/classification/average_precision.py @@ -167,9 +167,9 @@ def _average_precision_compute_with_precision_recall( return res[~torch.isnan(res)].mean() weights = torch.ones_like(res) if weights is None else weights return (res * weights)[~torch.isnan(res)].sum() - if average is None: + if average is None or average == "none": return res - allowed_average = ("micro", "macro", "weighted", None) + allowed_average = ("micro", "macro", "weighted", "none", None) raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")