Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mrr calculation #886

Merged
merged 38 commits into from
Oct 11, 2020
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8163316
mrr implementation
Jul 14, 2020
17a4b03
add mrr
Jul 14, 2020
0f1151a
edit codestyle
Jul 15, 2020
93c637e
Add changelog
Jul 15, 2020
f6d7496
updated changelog
Jul 15, 2020
4c46f25
add docstring to mrr
Jul 16, 2020
21a36b0
fixed commit and pep8
Jul 25, 2020
ae9089c
removed clones
Jul 25, 2020
6b765aa
Add batch tests
Jul 25, 2020
a69273f
edit changelog
Jul 25, 2020
c425ccf
Add callbacks
Jul 25, 2020
668196f
make codestyle
Jul 27, 2020
b626eeb
small issues
Jul 27, 2020
7f33538
add newline at the end of the file
Jul 28, 2020
06e6761
small issues
Jul 28, 2020
04caa31
add movielens
Jul 28, 2020
d9d7764
minor improvements
Jul 30, 2020
8bf6700
minor improvements
Jul 30, 2020
49b15a4
removed activation
Aug 2, 2020
e870498
Merge remote-tracking branch 'upstream/master'
Aug 5, 2020
ba4c180
updated changelog
Aug 5, 2020
80920f0
add at k support
Sep 2, 2020
cf8828f
Merge remote-tracking branch 'upstream/master' into recsys-mrr
Sep 6, 2020
84b7b52
add mrr computations
Sep 6, 2020
28c5148
commit before merge
Sep 6, 2020
8911528
Merge branch 'master' into recsys-mrr
Sep 6, 2020
3acb341
deleted dataset from another branch
Sep 6, 2020
5a16261
WIP merr calcback tests
Sep 6, 2020
e192837
minor changes
Sep 10, 2020
2230a5b
alphabetical order of the imports
Sep 12, 2020
f632cc1
add new line at the end of py file
Sep 12, 2020
470cc99
fixed small issues
Sep 22, 2020
8127eb8
changed the codestyle
Sep 29, 2020
034b832
Merge remote-tracking branch 'upstream/master'
Oct 8, 2020
6dad3f4
Merge branch 'master' into recsys-mrr
Oct 8, 2020
b2fa81a
moved files to metrics
Oct 8, 2020
5f51a38
fixed typos
Oct 8, 2020
4947d4b
fixed docs
Oct 11, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [20.07.1] - YYYY-MM-DD

### Added

- MRR metrics calculation ([#886](https://github.com/catalyst-team/catalyst/pull/886))
-

### Changed
Expand Down
36 changes: 36 additions & 0 deletions catalyst/utils/metrics/mrr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import torch


def mrr(outputs: torch.Tensor, targets: torch.Tensor):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@zkid18 zkid18 Jul 16, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I'll have a look at callbacks.


"""
Calculate the MRR score given model ouptputs and targets
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd appreciate it if you could extend documentation with the explanation of the metric or add link users will be able to read more about this score.

Args:
outputs [batch_size, slate_length] (torch.Tensor): model outputs, logits
targets [batch_szie, slate_length] (torch.Tensor): ground truth, labels
Returns:
mrr (float): the mrr score
"""
outputs = outputs.clone()
targets = targets.clone()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why do you need clone?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried to follow the 'shared clones', which is a common pattern in torch. But it seems it not so necessary here.

max_rank = targets.shape[0]

_, indices = outputs.sort(descending=True, dim=-1)
true_sorted_by_preds = torch.gather(targets, dim=0, index=indices)
values, indices = torch.max(true_sorted_by_preds, dim=0)
indices = indices.type_as(values).unsqueeze(dim=0).t()
ats_rep = torch.tensor(
data=max_rank, device=indices.device, dtype=torch.float32
)
within_at_mask = (indices < ats_rep).type(torch.float32)

result = torch.tensor(1.0) / (indices + torch.tensor(1.0))

zero_sum_mask = torch.sum(values) == 0.0
result[zero_sum_mask] = 0.0

mrr = result * within_at_mask
return mrr[0]


__all__ = ["mrr"]
23 changes: 23 additions & 0 deletions catalyst/utils/metrics/tests/test_mrr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import torch

from catalyst.utils import metrics


def test_mrr():
"""
Tests for catalyst.utils.metrics.mrr metric.
"""

# check 0 simple case
y_pred = [0.5, 0.2]
y_true = [1.0, 0.0]

mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true))
assert mrr == 1

# check 1 simple case
y_pred = [0.5, 0.2]
y_true = [0.0, 1.0]

mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true))
assert mrr == 0.5