Add mrr calculation #886

Merged
merged 38 commits into from Oct 11, 2020
Changes from 11 commits
Commits
38 commits
8163316
mrr implementation
Jul 14, 2020
17a4b03
add mrr
Jul 14, 2020
0f1151a
edit codestyle
Jul 15, 2020
93c637e
Add changelog
Jul 15, 2020
f6d7496
updated changelog
Jul 15, 2020
4c46f25
add docstring to mrr
Jul 16, 2020
21a36b0
fixed commit and pep8
Jul 25, 2020
ae9089c
removed clones
Jul 25, 2020
6b765aa
Add batch tests
Jul 25, 2020
a69273f
edit changelog
Jul 25, 2020
c425ccf
Add callbacks
Jul 25, 2020
668196f
make codestyle
Jul 27, 2020
b626eeb
small issues
Jul 27, 2020
7f33538
add newline at the end of the file
Jul 28, 2020
06e6761
small issues
Jul 28, 2020
04caa31
add movielens
Jul 28, 2020
d9d7764
minor improvements
Jul 30, 2020
8bf6700
minor improvements
Jul 30, 2020
49b15a4
removed activation
Aug 2, 2020
e870498
Merge remote-tracking branch 'upstream/master'
Aug 5, 2020
ba4c180
updated changelog
Aug 5, 2020
80920f0
add at k support
Sep 2, 2020
cf8828f
Merge remote-tracking branch 'upstream/master' into recsys-mrr
Sep 6, 2020
84b7b52
add mrr computations
Sep 6, 2020
28c5148
commit before merge
Sep 6, 2020
8911528
Merge branch 'master' into recsys-mrr
Sep 6, 2020
3acb341
deleted dataset from another branch
Sep 6, 2020
5a16261
WIP merr calcback tests
Sep 6, 2020
e192837
minor changes
Sep 10, 2020
2230a5b
alphabetical order of the imports
Sep 12, 2020
f632cc1
add new line at the end of py file
Sep 12, 2020
470cc99
fixed small issues
Sep 22, 2020
8127eb8
changed the codestyle
Sep 29, 2020
034b832
Merge remote-tracking branch 'upstream/master'
Oct 8, 2020
6dad3f4
Merge branch 'master' into recsys-mrr
Oct 8, 2020
b2fa81a
moved files to metrics
Oct 8, 2020
5f51a38
fixed typos
Oct 8, 2020
4947d4b
fixed docs
Oct 11, 2020
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [20.07.1] - YYYY-MM-DD

### Added

- MRR metrics calculation ([#886](https://github.com/catalyst-team/catalyst/pull/886))
Member

looks like we need to move it up :)

- `CMCScoreCallback` ([#880](https://github.com/catalyst-team/catalyst/pull/880))
- kornia augmentations `BatchTransformCallback` ([#862](https://github.com/catalyst-team/catalyst/issues/862))
- `average_precision` and `mean_average_precision` metrics ([#883](https://github.com/catalyst-team/catalyst/pull/883))
2 changes: 2 additions & 0 deletions catalyst/dl/callbacks/metrics/__init__.py
@@ -26,3 +26,5 @@
    AveragePrecisionCallback,
    MeanAveragePrecisionCallback,
)

from catalyst.dl.callbacks.metrics.mrr import MRRCallback
40 changes: 40 additions & 0 deletions catalyst/dl/callbacks/metrics/mrr.py
@@ -0,0 +1,40 @@
from typing import List

from catalyst.core import MetricCallback
from catalyst.utils import metrics


class MRRCallback(MetricCallback):
Member

Contributor Author

As for tests, I think it is better to return to this question once we implement at least one Learning to Rank model.

    """Calculates the MRR for each loader.

    .. note::
        Currently, supports binary and multi-label cases.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "mrr",
        activation: str = "none",
    ):
        """
        Args:
            input_key (str): input key to use for mrr calculation;
                specifies our ``y_true``
            output_key (str): output key to use for mrr calculation;
                specifies our ``y_pred``
            prefix (str): name to display for mrr when printing
            activation (str): A torch.nn activation applied to the outputs.
Member

why do we need activation? mrr metric doesn't have activation support

                Must be one of ``'none'``, ``'Sigmoid'``, or ``'Softmax2d'``
        """
        super().__init__(
            prefix=prefix,
            metric_fn=metrics.mrr,
            input_key=input_key,
            output_key=output_key,
            activation=activation,
        )


__all__ = ["MRRCallback"]
1 change: 1 addition & 0 deletions catalyst/utils/metrics/__init__.py
@@ -9,3 +9,4 @@
from .focal import reduced_focal_loss, sigmoid_focal_loss
from .iou import iou, jaccard
from .precision import average_precision, mean_average_precision
from .mrr import mrr
42 changes: 42 additions & 0 deletions catalyst/utils/metrics/mrr.py
@@ -0,0 +1,42 @@
"""
MRR metric.
"""

import torch


def mrr(outputs: torch.Tensor,
        targets: torch.Tensor
        ) -> torch.Tensor:

    """
    Calculate the MRR score given model outputs and targets
Member

I'd appreciate it if you could extend the documentation with an explanation of the metric, or add a link where users can read more about this score.

    Args:
        outputs [batch_size, slate_length] (torch.Tensor):
            model outputs, logits
        targets [batch_size, slate_length] (torch.Tensor):
            ground truth, labels
    Returns:
        mrr (torch.Tensor): the mrr score for each slate
    """
    max_rank = targets.shape[0]

    _, indices_for_sort = outputs.sort(descending=True, dim=-1)
Member

could we use torch.topk here?

Contributor Author

I guess the comment is more relevant to the next part:
true_sorted_by_pred_shrink = true_sorted_by_preds[:, :k]

Anyway, I might be missing the advantage of torch.topk over the proposed approach.
We need to sort predictions by the corresponding indices of the outputs. Is there a way in PyTorch to sort in that fashion?
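
On the `torch.topk` question in this thread, a minimal sketch comparing the two approaches. The tensors, `k`, and both helper names are made up for illustration and are not part of this PR. `torch.topk` returns the indices of the `k` largest scores in descending order, so gathering the targets with those indices should match the sort, gather, slice approach whenever the scores contain no ties.

```python
import torch


def targets_at_k_by_sort(outputs, targets, k):
    # Approach from the diff: full descending sort, gather, then keep k columns.
    _, sort_idx = outputs.sort(descending=True, dim=-1)
    return torch.gather(targets, dim=-1, index=sort_idx)[:, :k]


def targets_at_k_by_topk(outputs, targets, k):
    # Alternative: topk yields the indices of the k highest scores directly.
    _, topk_idx = torch.topk(outputs, k=k, dim=-1)
    return torch.gather(targets, dim=-1, index=topk_idx)


# Hypothetical batch of two slates with three items each.
outputs = torch.tensor([[0.2, 0.5, 0.1], [0.9, 0.3, 0.4]])
targets = torch.tensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

by_sort = targets_at_k_by_sort(outputs, targets, k=2)
by_topk = targets_at_k_by_topk(outputs, targets, k=2)
print(torch.equal(by_sort, by_topk))
```

Both variants reorder the targets by predicted score; the `topk` one avoids sorting the tail of the slate that is discarded anyway.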

    true_sorted_by_preds = torch.gather(targets, dim=-1, index=indices_for_sort)
    values, indices = torch.max(true_sorted_by_preds, dim=0)
    indices = indices.type_as(values).unsqueeze(dim=0).t()
    max_rank_rep = torch.tensor(
        data=max_rank, device=indices.device, dtype=torch.float32
    )
    within_at_mask = (indices < max_rank_rep).type(torch.float32)

    result = torch.tensor(1.0) / (indices + torch.tensor(1.0))

    zero_sum_mask = torch.sum(values) == 0.0
    result[zero_sum_mask] = 0.0

    mrr = result * within_at_mask
    return mrr


__all__ = ["mrr"]
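
For reference, the reciprocal rank of a slate is 1 divided by the rank of the first relevant item when items are ordered by descending score, and 0 if the slate has no relevant item; MRR is the mean of that value over slates. The following is a sketch of a batched version under those assumptions, not the implementation merged in this PR. The name `reciprocal_rank_sketch` and the `k` argument are hypothetical, and targets are assumed to be binary.

```python
import torch


def reciprocal_rank_sketch(outputs, targets, k=None):
    """Reciprocal rank of the first relevant item in each slate.

    outputs, targets: tensors of shape [batch_size, slate_length].
    Returns a tensor of shape [batch_size]; 0.0 where nothing relevant
    appears in the top k positions.
    """
    slate_length = outputs.shape[-1]
    k = slate_length if k is None else min(k, slate_length)

    # Indices of the k highest-scored items per slate, best first.
    _, top_idx = torch.topk(outputs, k=k, dim=-1)
    hits = torch.gather(targets, dim=-1, index=top_idx) > 0

    # 1/rank at every hit, 0 elsewhere; the largest value is the first hit.
    ranks = torch.arange(1, k + 1, device=outputs.device, dtype=torch.float32)
    return (hits.to(torch.float32) / ranks).max(dim=-1).values


outputs = torch.tensor([[0.2, 0.5], [0.5, 0.2], [0.5, 0.2]])
targets = torch.tensor([[0.0, 1.0], [0.0, 1.0], [0.0, 0.0]])
per_slate = reciprocal_rank_sketch(outputs, targets)
print(per_slate, per_slate.mean())
```

Reducing along `dim=-1` keeps each slate independent. As far as the diff shows, the version above it calls `torch.max(..., dim=0)`, which coincides with the slate dimension only for 1-D inputs.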
39 changes: 39 additions & 0 deletions catalyst/utils/metrics/tests/test_mrr.py
@@ -0,0 +1,39 @@
import torch

from catalyst.utils import metrics


def test_mrr():
    """
    Tests for catalyst.utils.metrics.mrr metric.
    """

    # check 0 simple case
    y_pred = [0.5, 0.2]
    y_true = [1.0, 0.0]

    mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true))
    assert mrr == 1

    # check 1 simple case
    y_pred = [0.5, 0.2]
    y_true = [0.0, 1.0]

    mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true))
    assert mrr == 0.5

    # check 2 simple case
    y_pred = [0.2, 0.5]
    y_true = [0.0, 1.0]

    mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true))
    assert mrr == 1.0

    # test batched slates
    y_pred_1 = [0.2, 0.5]
    y_pred_05 = [0.5, 0.2]
    y_true = [0.0, 1.0]

    mrr = metrics.mrr(torch.Tensor([y_pred_1, y_pred_05]), torch.Tensor([y_true, y_true]))
    assert mrr[0][0] == 1.0
    assert mrr[1][0] == 0.5
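
As a cross-check for the expected values in these tests, a plain-Python reference can compute the reciprocal rank of a single slate by hand. This is a hypothetical helper (`reciprocal_rank_reference` is not part of catalyst), assuming binary relevance labels and returning 0.0 when the slate has no relevant item.

```python
def reciprocal_rank_reference(y_pred, y_true):
    # Item positions ordered by descending predicted score.
    order = sorted(range(len(y_pred)), key=lambda i: y_pred[i], reverse=True)
    # Walk the ranking and return 1/rank at the first relevant item.
    for rank, item in enumerate(order, start=1):
        if y_true[item] > 0:
            return 1.0 / rank
    return 0.0


# The three single-slate cases from the test above, plus an all-zero slate.
cases = [
    ([0.5, 0.2], [1.0, 0.0]),
    ([0.5, 0.2], [0.0, 1.0]),
    ([0.2, 0.5], [0.0, 1.0]),
    ([0.2, 0.5], [0.0, 0.0]),
]
results = [reciprocal_rank_reference(p, t) for p, t in cases]
print(results)
```

A reference like this could serve as an oracle for randomized or larger slates, where hand-written expected values become impractical.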