diff --git a/CHANGELOG.md b/CHANGELOG.md index c8f9fad25a9..93fc6fdf72a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Skip box conversion if no boxes are present in `MeanAveragePrecision` ([#1097](https://github.com/Lightning-AI/metrics/pull/1097)) +- Fixed inconsistency in docs and code when setting `average="none"` in `AvaragePrecision` metric ([#1116](https://github.com/Lightning-AI/metrics/pull/1116)) + + ## [0.9.1] - 2022-06-08 ### Added diff --git a/torchmetrics/functional/classification/average_precision.py b/torchmetrics/functional/classification/average_precision.py index db256cd00d0..aa3d101aec0 100644 --- a/torchmetrics/functional/classification/average_precision.py +++ b/torchmetrics/functional/classification/average_precision.py @@ -170,9 +170,9 @@ def _average_precision_compute_with_precision_recall( return res[~torch.isnan(res)].mean() weights = torch.ones_like(res) if weights is None else weights return (res * weights)[~torch.isnan(res)].sum() - if average is None: + if average is None or average == "none": return res - allowed_average = ("micro", "macro", "weighted", None) + allowed_average = ("micro", "macro", "weighted", "none", None) raise ValueError(f"Expected argument `average` to be one of {allowed_average}" f" but got {average}")