Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Metric retrieval recall at precision #951

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
57 commits
Select commit Hold shift + click to select a range
cb56a2a
add metric retrieval as precision
Apr 9, 2022
5e9e50d
recall_precision.py test_recall_precision.py
enuk1dze Apr 11, 2022
dee6aad
fix test
Apr 11, 2022
8b97a39
add adaptive_k add parameters
Apr 12, 2022
5804a12
test_test fixed
enuk1dze Apr 12, 2022
2900c12
del deprecated files
Apr 12, 2022
b11ba9a
fix
Apr 12, 2022
938f768
fix docs
Apr 12, 2022
5a3a93e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 12, 2022
3ce17bf
Merge branch 'master' into metric_retrieval_recall_at_precision
Apr 18, 2022
47c588d
fix bugs
Apr 18, 2022
b6f5b5b
added test sk_metric_function
Apr 18, 2022
f6997fe
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2022
e0ea387
WIP ~ added draft of test
enuk1dze Apr 18, 2022
c42ba88
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 18, 2022
4d1049a
added tests retrieval recall at precision
Apr 19, 2022
3830e0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2022
e45d8a0
added fix doctest
Apr 19, 2022
9545a3c
Merge branch 'metric_retrieval_recall_at_precision' of https://github…
Apr 19, 2022
1654af4
added torch max_k_range
Apr 19, 2022
529d51f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2022
11a6fcc
doctest
Apr 19, 2022
63346f1
add TestRetrievalRecallAtPrecisionError test
enuk1dze Apr 19, 2022
a18c0fe
fix auto max_k problem
enuk1dze Apr 19, 2022
5d074b0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 19, 2022
7e8ffe0
fix test min_precision_error
Apr 19, 2022
6ef2e98
Merge branch 'master' into metric_retrieval_recall_at_precision
Borda Apr 20, 2022
08553d7
Merge branch 'master' into metric_retrieval_recall_at_precision
Borda Apr 25, 2022
e5e3a46
added functional recall_precision_curve.py fixed value_error
Apr 26, 2022
2fd9c58
Merge branch 'master' into metric_retrieval_recall_at_precision
MrShevan Apr 26, 2022
0b6687b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 26, 2022
47fa721
fixed pep8
Apr 26, 2022
24c88f7
bug fixed
Apr 27, 2022
801fe93
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2022
15c5870
added RetrievalPrecisionRecallCurve
Apr 27, 2022
8539097
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 27, 2022
975116e
fixed doctest
Apr 27, 2022
3ac3bfc
Update torchmetrics/retrieval/recall_precision.py
MrShevan Apr 28, 2022
acfafdb
Update torchmetrics/retrieval/recall_precision.py
MrShevan Apr 28, 2022
be2ed5f
Update torchmetrics/retrieval/recall_precision.py
MrShevan Apr 28, 2022
5f59f75
Update torchmetrics/retrieval/recall_precision.py
MrShevan Apr 28, 2022
7cd3cc1
Update torchmetrics/retrieval/recall_precision.py
MrShevan Apr 28, 2022
6a4c5d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2022
f9d890d
bug fix
Apr 28, 2022
b78658b
Merge branch 'metric_retrieval_recall_at_precision' of https://github…
Apr 28, 2022
2e835b3
added dim_zero_cat
Apr 28, 2022
6295db7
Merge branch 'master' into metric_retrieval_recall_at_precision
MrShevan Apr 28, 2022
d94f510
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 28, 2022
f0aba5c
doctest
Apr 28, 2022
ebd9ee6
Merge branch 'master' into metric_retrieval_recall_at_precision
Borda Apr 28, 2022
7d1437a
Merge branch 'master' into metric_retrieval_recall_at_precision
MrShevan Apr 28, 2022
9a98548
Merge branch 'master' into metric_retrieval_recall_at_precision
MrShevan Apr 29, 2022
b9d0877
doctest
Apr 29, 2022
901e9fe
Merge branch 'master' into metric_retrieval_recall_at_precision
SkafteNicki Apr 29, 2022
b725e48
changelog
SkafteNicki Apr 29, 2022
8cf1c7c
docs
SkafteNicki Apr 29, 2022
f9dd87f
Update CHANGELOG.md
SkafteNicki Apr 29, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Expand Up @@ -11,7 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Added `RetrievalPrecisionRecallCurve` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))
SkafteNicki marked this conversation as resolved.
Show resolved Hide resolved

- Added `RetrievalRecallAtFixedPrecision` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))


-
Expand Down
22 changes: 22 additions & 0 deletions docs/source/retrieval/precision_recall_curve.rst
@@ -0,0 +1,22 @@
.. customcarditem::
:header: Precision Recall Curve
:image: https://pl-flash-data.s3.amazonaws.com/assets/thumbnails/text_classification.svg
:tags: Retrieval

.. include:: ../links.rst

######################
Precision Recall Curve
######################

Module Interface
________________

.. autoclass:: torchmetrics.RetrievalPrecisionRecallCurve
:noindex:

Functional Interface
____________________

.. autofunction:: torchmetrics.functional.retrieval_precision_recall_curve
:noindex:
180 changes: 180 additions & 0 deletions tests/retrieval/test_precision_recall_curve.py
@@ -0,0 +1,180 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import partial
from typing import Callable, Tuple, Union

import numpy as np
import pytest
import torch
from numpy import array
from torch import Tensor, tensor

from tests.helpers import seed_all
from tests.helpers.testers import Metric, MetricTester
from tests.retrieval.helpers import _default_metric_class_input_arguments, get_group_indexes
from tests.retrieval.test_precision import _precision_at_k
from tests.retrieval.test_recall import _recall_at_k
from torchmetrics import RetrievalPrecisionRecallCurve

seed_all(42)


def _compute_precision_recall_curve(
preds: Union[Tensor, array],
target: Union[Tensor, array],
indexes: Union[Tensor, array] = None,
max_k: int = None,
adaptive_k: bool = False,
ignore_index: int = None,
empty_target_action: str = "skip",
reverse: bool = False,
) -> Tuple[Tensor, Tensor, Tensor]:
"""Compute metric with multiple iterations over every query predictions set.

Didn't find a reliable implementation of precision-recall curve in Information Retrieval,
so, reimplementing here.

A good explanation can be found here:
`<https://nlp.stanford.edu/IR-book/pdf/08eval.pdf>_`. (part 8.4)
"""
recalls, precisions = [], []

if indexes is None:
indexes = np.full_like(preds, fill_value=0, dtype=np.int64)
if isinstance(indexes, Tensor):
indexes = indexes.cpu().numpy()
if isinstance(preds, Tensor):
preds = preds.cpu().numpy()
if isinstance(target, Tensor):
target = target.cpu().numpy()

assert isinstance(indexes, np.ndarray)
assert isinstance(preds, np.ndarray)
assert isinstance(target, np.ndarray)

if ignore_index is not None:
valid_positions = target != ignore_index
indexes, preds, target = indexes[valid_positions], preds[valid_positions], target[valid_positions]

indexes = indexes.flatten()
preds = preds.flatten()
target = target.flatten()
groups = get_group_indexes(indexes)

if max_k is None:
max_k = max(map(len, groups))

top_k = torch.arange(1, max_k + 1)

for group in groups:
trg, prd = target[group], preds[group]
r, p = [], []

if ((1 - trg) if reverse else trg).sum() == 0:
if empty_target_action == "skip":
pass
elif empty_target_action == "pos":
arr = [1.0] * max_k
recalls.append(arr)
precisions.append(arr)
elif empty_target_action == "neg":
arr = [0.0] * max_k
recalls.append(arr)
precisions.append(arr)

else:
for k in top_k:
r.append(_recall_at_k(trg, prd, k=k.item()))
p.append(_precision_at_k(trg, prd, k=k.item(), adaptive_k=adaptive_k))

recalls.append(r)
precisions.append(p)

if not recalls:
return torch.zeros(max_k), torch.zeros(max_k), top_k

recalls = tensor(recalls).mean(dim=0)
precisions = tensor(precisions).mean(dim=0)

return precisions, recalls, top_k


class RetrievalPrecisionRecallCurveTester(MetricTester):
def run_class_metric_test(
self,
ddp: bool,
indexes: Tensor,
preds: Tensor,
target: Tensor,
metric_class: Metric,
sk_metric: Callable,
dist_sync_on_step: bool,
metric_args: dict,
reverse: bool = False,
):
_sk_metric_adapted = partial(sk_metric, reverse=reverse, **metric_args)

super().run_class_metric_test(
ddp=ddp,
preds=preds,
target=target,
metric_class=metric_class,
sk_metric=_sk_metric_adapted,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
fragment_kwargs=True,
indexes=indexes, # every additional argument will be passed to metric_class and _sk_metric_adapted
)


@pytest.mark.parametrize("ddp", [False])
@pytest.mark.parametrize("dist_sync_on_step", [False])
@pytest.mark.parametrize("empty_target_action", ["neg", "skip", "pos"])
@pytest.mark.parametrize("ignore_index", [None, 1]) # avoid setting 0, otherwise test with all 0 targets will fail
@pytest.mark.parametrize("max_k", [None, 1, 2, 5, 10])
@pytest.mark.parametrize("adaptive_k", [False, True])
@pytest.mark.parametrize(**_default_metric_class_input_arguments)
class TestRetrievalPrecisionRecallCurve(RetrievalPrecisionRecallCurveTester):
atol = 0.02

def test_class_metric(
self,
indexes,
preds,
target,
ddp,
dist_sync_on_step,
empty_target_action,
ignore_index,
max_k,
adaptive_k,
):
metric_args = dict(
max_k=max_k,
adaptive_k=adaptive_k,
empty_target_action=empty_target_action,
ignore_index=ignore_index,
)

self.run_class_metric_test(
ddp=ddp,
indexes=indexes,
preds=preds,
target=target,
metric_class=RetrievalPrecisionRecallCurve,
sk_metric=_compute_precision_recall_curve,
dist_sync_on_step=dist_sync_on_step,
metric_args=metric_args,
)
4 changes: 4 additions & 0 deletions torchmetrics/__init__.py
Expand Up @@ -80,7 +80,9 @@
RetrievalMRR,
RetrievalNormalizedDCG,
RetrievalPrecision,
RetrievalPrecisionRecallCurve,
RetrievalRecall,
RetrievalRecallAtFixedPrecision,
RetrievalRPrecision,
)
from torchmetrics.text import ( # noqa: E402
Expand Down Expand Up @@ -166,6 +168,8 @@
"RetrievalPrecision",
"RetrievalRecall",
"RetrievalRPrecision",
"RetrievalPrecisionRecallCurve",
"RetrievalRecallAtFixedPrecision",
"ROC",
"SacreBLEUScore",
"SignalDistortionRatio",
Expand Down
2 changes: 2 additions & 0 deletions torchmetrics/functional/__init__.py
Expand Up @@ -69,6 +69,7 @@
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg
from torchmetrics.functional.retrieval.precision import retrieval_precision
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision
from torchmetrics.functional.retrieval.recall import retrieval_recall
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank
Expand Down Expand Up @@ -144,6 +145,7 @@
"retrieval_r_precision",
"retrieval_recall",
"retrieval_reciprocal_rank",
"retrieval_precision_recall_curve",
"roc",
"rouge_score",
"sacre_bleu_score",
Expand Down
1 change: 1 addition & 0 deletions torchmetrics/functional/retrieval/__init__.py
Expand Up @@ -17,6 +17,7 @@
from torchmetrics.functional.retrieval.hit_rate import retrieval_hit_rate # noqa: F401
from torchmetrics.functional.retrieval.ndcg import retrieval_normalized_dcg # noqa: F401
from torchmetrics.functional.retrieval.precision import retrieval_precision # noqa: F401
from torchmetrics.functional.retrieval.precision_recall_curve import retrieval_precision_recall_curve # noqa: F401
from torchmetrics.functional.retrieval.r_precision import retrieval_r_precision # noqa: F401
from torchmetrics.functional.retrieval.recall import retrieval_recall # noqa: F401
from torchmetrics.functional.retrieval.reciprocal_rank import retrieval_reciprocal_rank # noqa: F401
97 changes: 97 additions & 0 deletions torchmetrics/functional/retrieval/precision_recall_curve.py
@@ -0,0 +1,97 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Tuple

import torch
from torch import Tensor, cumsum, tensor
from torch.nn.functional import pad

from torchmetrics.utilities.checks import _check_retrieval_functional_inputs


def retrieval_precision_recall_curve(
preds: Tensor, target: Tensor, max_k: Optional[int] = None, adaptive_k: bool = False
) -> Tuple[Tensor, Tensor, Tensor]:
"""Computes precision-recall pairs for different k (from 1 to `max_k`).

In a ranked retrieval context, appropriate sets of retrieved documents are naturally given by
the top k retrieved documents.

Recall is the fraction of relevant documents retrieved among all the relevant documents.
Precision is the fraction of relevant documents among all the retrieved documents.

For each such set, precision and recall values can be plotted to give a recall-precision
curve.

``preds`` and ``target`` should be of the same shape and live on the same device. If no ``target`` is ``True``,
``0`` is returned. ``target`` must be either `bool` or `integers` and ``preds`` must be ``float``,
otherwise an error is raised.

Args:
preds: estimated probabilities of each document to be relevant.
target: ground truth about each document being relevant or not.
max_k: Calculate recall and precision for all possible top k from 1 to max_k
(default: `None`, which considers all possible top k)
adaptive_k: adjust `max_k` to `min(max_k, number of documents)` for each query

Returns:
tensor with the precision values for each k (at ``k``) from 1 to `max_k`
tensor with the recall values for each k (at ``k``) from 1 to `max_k`
tensor with all possibles k

Raises:
ValueError:
If ``max_k`` is not `None` or an integer larger than 0.
ValueError:
If ``adaptive_k`` is not boolean.

Example:
>>> from torchmetrics.functional import retrieval_precision_recall_curve
>>> preds = tensor([0.2, 0.3, 0.5])
>>> target = tensor([True, False, True])
>>> precisions, recalls, top_k = retrieval_precision_recall_curve(preds, target, max_k=2)
>>> precisions
tensor([1.0000, 0.5000])
>>> recalls
tensor([0.5000, 0.5000])
>>> top_k
tensor([1, 2])
"""
preds, target = _check_retrieval_functional_inputs(preds, target)

if not isinstance(adaptive_k, bool):
raise ValueError("`adaptive_k` has to be a boolean")

if max_k is None:
max_k = preds.shape[-1]

if not (isinstance(max_k, int) and max_k > 0):
raise ValueError("`max_k` has to be a positive integer or None")

if adaptive_k and max_k > preds.shape[-1]:
topk = torch.arange(1, preds.shape[-1] + 1, device=preds.device)
topk = pad(topk, (0, max_k - preds.shape[-1]), "constant", float(preds.shape[-1]))
else:
topk = torch.arange(1, max_k + 1, device=preds.device)

if not target.sum():
return torch.zeros(max_k, device=preds.device), torch.zeros(max_k, device=preds.device), topk

relevant = target[preds.topk(min(max_k, preds.shape[-1]), dim=-1)[1]].float()
relevant = cumsum(pad(relevant, (0, max(0, max_k - len(relevant))), "constant", 0.0), dim=0)

recall = relevant / target.sum()
precision = relevant / topk

return precision, recall, topk
4 changes: 4 additions & 0 deletions torchmetrics/retrieval/__init__.py
Expand Up @@ -17,6 +17,10 @@
from torchmetrics.retrieval.hit_rate import RetrievalHitRate # noqa: F401
from torchmetrics.retrieval.ndcg import RetrievalNormalizedDCG # noqa: F401
from torchmetrics.retrieval.precision import RetrievalPrecision # noqa: F401
from torchmetrics.retrieval.precision_recall_curve import ( # noqa: F401
RetrievalPrecisionRecallCurve,
RetrievalRecallAtFixedPrecision,
)
from torchmetrics.retrieval.r_precision import RetrievalRPrecision # noqa: F401
from torchmetrics.retrieval.recall import RetrievalRecall # noqa: F401
from torchmetrics.retrieval.reciprocal_rank import RetrievalMRR # noqa: F401