diff --git a/CHANGELOG.md b/CHANGELOG.md
index 38dc6b7a63..ff40730540 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ### Added
+- MRR metrics calculation ([#886](https://github.com/catalyst-team/catalyst/pull/886))
 - docs for MetricCallbacks ([#947](https://github.com/catalyst-team/catalyst/pull/947))
 - SoftMax, CosFace, ArcFace layers to contrib ([#939](https://github.com/catalyst-team/catalyst/pull/939))
@@ -90,7 +91,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 ## [20.08] - 2020-08-09
 ### Added
-- Full metric learning pipeline including training and validation stages ([#886](https://github.com/catalyst-team/catalyst/pull/876))
 - `CMCScoreCallback` ([#880](https://github.com/catalyst-team/catalyst/pull/880))
 - kornia augmentations `BatchTransformCallback` ([#862](https://github.com/catalyst-team/catalyst/issues/862))
 - `average_precision` and `mean_average_precision` metrics ([#883](https://github.com/catalyst-team/catalyst/pull/883))
diff --git a/catalyst/dl/callbacks/metrics/__init__.py b/catalyst/dl/callbacks/metrics/__init__.py
index fd61543a97..89ddc35268 100644
--- a/catalyst/dl/callbacks/metrics/__init__.py
+++ b/catalyst/dl/callbacks/metrics/__init__.py
@@ -18,6 +18,7 @@
     IouCallback,
     JaccardCallback,
 )
+from catalyst.dl.callbacks.metrics.mrr import MRRCallback
 from catalyst.dl.callbacks.metrics.ppv_tpr_f1 import (
     PrecisionRecallF1ScoreCallback,
 )
diff --git a/catalyst/dl/callbacks/metrics/mrr.py b/catalyst/dl/callbacks/metrics/mrr.py
new file mode 100644
index 0000000000..29ea0c003b
--- /dev/null
+++ b/catalyst/dl/callbacks/metrics/mrr.py
@@ -0,0 +1,34 @@
+from catalyst.core import MetricCallback
+from catalyst.utils import metrics
+
+
+class MRRCallback(MetricCallback):
+    """Calculates the Mean Reciprocal Rank (MRR) for each loader.
+
+    .. note::
+        Targets are expected to be binary relevance labels.
+    """
+
+    def __init__(
+        self,
+        input_key: str = "targets",
+        output_key: str = "logits",
+        prefix: str = "mrr",
+    ):
+        """
+        Args:
+            input_key (str): input key to use for mrr calculation;
+                specifies our ``y_true``
+            output_key (str): output key to use for mrr calculation;
+                specifies our ``y_pred``
+            prefix (str): name to display for mrr when printing
+        """
+        super().__init__(
+            prefix=prefix,
+            metric_fn=metrics.mrr,
+            input_key=input_key,
+            output_key=output_key,
+        )
+
+
+__all__ = ["MRRCallback"]
diff --git a/catalyst/dl/callbacks/metrics/tests/__init__.py b/catalyst/dl/callbacks/metrics/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/catalyst/metrics/__init__.py b/catalyst/metrics/__init__.py
index e69de29bb2..24f7b086b3 100644
--- a/catalyst/metrics/__init__.py
+++ b/catalyst/metrics/__init__.py
@@ -0,0 +1,2 @@
+# flake8: noqa
+from catalyst.metrics.mrr import mrr
diff --git a/catalyst/metrics/mrr.py b/catalyst/metrics/mrr.py
new file mode 100644
index 0000000000..cce322e27b
--- /dev/null
+++ b/catalyst/metrics/mrr.py
@@ -0,0 +1,57 @@
+"""
+MRR metric.
+"""
+
+import torch
+
+
+def mrr(outputs: torch.Tensor, targets: torch.Tensor, k=100) -> torch.Tensor:
+    """
+    Calculate the Mean Reciprocal Rank (MRR)
+    score given model outputs and targets.
+    The user's data is aggregated in batches.
+
+    MRR@k is the mean over all users of the
+    reciprocal rank of the highest-ranked relevant item;
+    if no relevant item appears in the top *k*, the score is 0.
+    https://en.wikipedia.org/wiki/Mean_reciprocal_rank
+
+    Args:
+        outputs (torch.Tensor):
+            Tensor with predicted scores,
+            size: [batch_size, slate_length];
+            model outputs, logits
+        targets (torch.Tensor):
+            Binary tensor with ground truth;
+            1 means the item is relevant
+            for the user, 0 means it is not;
+            size: [batch_size, slate_length];
+            ground truth, labels
+        k (int):
+            parameter for evaluation on top-k items
+
+    Returns:
+        result (torch.Tensor):
+            The MRR score for each user.
+    """
+    k = min(outputs.size(1), k)
+    # rank the targets by predicted score, best score first
+    _, indices_for_sort = outputs.sort(descending=True, dim=-1)
+    true_sorted_by_preds = torch.gather(
+        targets, dim=-1, index=indices_for_sort
+    )
+    # keep only the top-k ranked items
+    true_sorted_by_pred_shrink = true_sorted_by_preds[:, :k]
+
+    # position of the first relevant item for each user
+    values, indices = torch.max(true_sorted_by_pred_shrink, dim=1)
+    indices = indices.type_as(values).unsqueeze(dim=0).t()
+    result = torch.tensor(1.0) / (indices + torch.tensor(1.0))
+
+    # users with no relevant item in the top-k get a score of 0
+    zero_sum_mask = values == 0.0
+    result[zero_sum_mask] = 0.0
+    return result
+
+
+__all__ = ["mrr"]
diff --git a/catalyst/metrics/tests/__init__.py b/catalyst/metrics/tests/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/catalyst/metrics/tests/test_mrr.py b/catalyst/metrics/tests/test_mrr.py
new file mode 100644
index 0000000000..8e8bd800ab
--- /dev/null
+++ b/catalyst/metrics/tests/test_mrr.py
@@ -0,0 +1,68 @@
+import torch
+
+from catalyst import metrics
+
+
+def test_mrr():
+    """
+    Tests for catalyst.metrics.mrr metric.
+    """
+    # check 0: the relevant item is ranked first
+    y_pred = [0.5, 0.2]
+    y_true = [1.0, 0.0]
+
+    mrr = metrics.mrr(torch.Tensor([y_pred]), torch.Tensor([y_true]))
+    assert mrr[0][0] == 1
+
+    # check 1: the relevant item is ranked second
+    y_pred = [0.5, 0.2]
+    y_true = [0.0, 1.0]
+
+    mrr = metrics.mrr(torch.Tensor([y_pred]), torch.Tensor([y_true]))
+    assert mrr[0][0] == 0.5
+
+    # check 2: the relevant item is ranked first
+    y_pred = [0.2, 0.5]
+    y_true = [0.0, 1.0]
+
+    mrr = metrics.mrr(torch.Tensor([y_pred]), torch.Tensor([y_true]))
+    assert mrr[0][0] == 1.0
+
+    # check 3: multiple users in one batch
+    y_pred1 = [0.2, 0.5]
+    y_pred05 = [0.5, 0.2]
+    y_true = [0.0, 1.0]
+
+    mrr = metrics.mrr(
+        torch.Tensor([y_pred1, y_pred05]), torch.Tensor([y_true, y_true])
+    )
+    assert mrr[0][0] == 1.0
+    assert mrr[1][0] == 0.5
+
+    # check 4: top-k evaluation with k=3
+    y_pred1 = [4.0, 2.0, 3.0, 1.0]
+    y_pred2 = [1.0, 2.0, 3.0, 4.0]
+    y_true1 = [0, 0, 1.0, 1.0]
+    y_true2 = [0, 0, 1.0, 1.0]
+
+    y_pred_torch = torch.Tensor([y_pred1, y_pred2])
+    y_true_torch = torch.Tensor([y_true1, y_true2])
+
+    mrr = metrics.mrr(y_pred_torch, y_true_torch, k=3)
+
+    assert mrr[0][0] == 0.5
+    assert mrr[1][0] == 1.0
+
+    # check 5: top-k evaluation with k=1
+    y_pred1 = [4.0, 2.0, 3.0, 1.0]
+    y_pred2 = [1.0, 2.0, 3.0, 4.0]
+    y_true1 = [0, 0, 1.0, 1.0]
+    y_true2 = [0, 0, 1.0, 1.0]
+
+    y_pred_torch = torch.Tensor([y_pred1, y_pred2])
+    y_true_torch = torch.Tensor([y_true1, y_true2])
+
+    mrr = metrics.mrr(y_pred_torch, y_true_torch, k=1)
+
+    assert mrr[0][0] == 0.0
+    assert mrr[1][0] == 1.0
diff --git a/catalyst/utils/metrics/__init__.py b/catalyst/utils/metrics/__init__.py
index 2211d0dbc0..49cd570bcb 100644
--- a/catalyst/utils/metrics/__init__.py
+++ b/catalyst/utils/metrics/__init__.py
@@ -3,6 +3,7 @@
     accuracy,
     multi_label_accuracy,
 )
+
 from catalyst.utils.metrics.auc import auc
 from catalyst.utils.metrics.cmc_score import cmc_score_count, cmc_score
 from catalyst.utils.metrics.dice import dice, calculate_dice
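Usage note (illustrative, not part of the diff): a minimal sketch of calling the new metric directly, mirroring the `k=3` case from `test_mrr.py` above.

```python
import torch

from catalyst import metrics

# two users, four items each; a higher score means a higher rank
outputs = torch.Tensor([[4.0, 2.0, 3.0, 1.0],
                        [1.0, 2.0, 3.0, 4.0]])
targets = torch.Tensor([[0.0, 0.0, 1.0, 1.0],
                        [0.0, 0.0, 1.0, 1.0]])

# user 0: the first relevant item lands at rank 2 -> 1/2
# user 1: the first relevant item lands at rank 1 -> 1/1
print(metrics.mrr(outputs, targets, k=3))  # tensor([[0.5000], [1.0000]])
```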
diff --git a/docs/api/utils.rst b/docs/api/utils.rst
index 561f5fbaae..4cfe66902a 100644
--- a/docs/api/utils.rst
+++ b/docs/api/utils.rst
@@ -290,3 +290,10 @@ Functional
     :members:
     :undoc-members:
     :show-inheritance:
+
+MRR
+~~~~~~~~~~~~~~~~~~~~~~
+.. automodule:: catalyst.metrics.mrr
+    :members:
+    :undoc-members:
+    :show-inheritance:
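Usage note (illustrative, not part of the diff): a sketch of wiring `MRRCallback` into a training run. The toy data, the linear scorer, and the use of `SupervisedRunner` are assumptions for illustration; only `MRRCallback` and its `input_key`/`output_key` arguments come from this PR.

```python
from collections import OrderedDict

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

from catalyst.dl import SupervisedRunner  # assumed standard Catalyst runner
from catalyst.dl.callbacks.metrics import MRRCallback

# toy ranking data: 8 users, 16 features each, 4 items to score per user
features = torch.randn(8, 16)
relevance = (torch.rand(8, 4) > 0.5).float()  # binary relevance labels
loader = DataLoader(TensorDataset(features, relevance), batch_size=4)
loaders = OrderedDict(train=loader, valid=loader)

model = nn.Linear(16, 4)  # hypothetical scorer: one score per item
runner = SupervisedRunner()
runner.train(
    model=model,
    criterion=nn.BCEWithLogitsLoss(),
    optimizer=torch.optim.Adam(model.parameters()),
    loaders=loaders,
    callbacks=[MRRCallback(input_key="targets", output_key="logits")],
    num_epochs=1,
)
```

With this wiring, the callback should report an `mrr` value alongside the loss for each loader.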