New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mrr calculation #886
Merged
Merged
Add mrr calculation #886
Changes from 37 commits
Commits
Show all changes
38 commits
Select commit
Hold shift + click to select a range
8163316
mrr implementation
17a4b03
add mrr
0f1151a
edit codestyle
93c637e
Add changelog
f6d7496
updated changelog
4c46f25
add docstring to mrr
21a36b0
fixed commit and pep8
ae9089c
removed clones
6b765aa
Add batch tests
a69273f
edit changelog
c425ccf
Add callbacks
668196f
make codestyle
b626eeb
small issues
7f33538
add newline at the end of the file
06e6761
small issues
04caa31
add movielens
d9d7764
minor improvements
8bf6700
minor improvements
49b15a4
removed activation
e870498
Merge remote-tracking branch 'upstream/master'
ba4c180
updated changelog
80920f0
add at k support
cf8828f
Merge remote-tracking branch 'upstream/master' into recsys-mrr
84b7b52
add mrr computations
28c5148
commit before merge
8911528
Merge branch 'master' into recsys-mrr
3acb341
deleted dataset from another branch
5a16261
WIP merr calcback tests
e192837
minor changes
2230a5b
alphabetical order of the imports
f632cc1
add new line at the end of py file
470cc99
fixed small issues
8127eb8
changed the codestyle
034b832
Merge remote-tracking branch 'upstream/master'
6dad3f4
Merge branch 'master' into recsys-mrr
b2fa81a
moved files to metrics
5f51a38
fixed typos
4947d4b
fixed docs
File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
from catalyst.core import MetricCallback | ||
from catalyst.utils import metrics | ||
|
||
|
||
class MRRCallback(MetricCallback): | ||
"""Calculates the AUC per class for each loader. | ||
|
||
.. note:: | ||
Currently, supports binary and multi-label cases. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
input_key: str = "targets", | ||
output_key: str = "logits", | ||
prefix: str = "mrr", | ||
): | ||
""" | ||
Args: | ||
input_key (str): input key to use for mrr calculation | ||
specifies our ``y_true`` | ||
output_key (str): output key to use for mrr calculation; | ||
specifies our ``y_pred`` | ||
prefix (str): name to display for mrr when printing | ||
""" | ||
super().__init__( | ||
prefix=prefix, | ||
metric_fn=metrics.mrr, | ||
input_key=input_key, | ||
output_key=output_key, | ||
) | ||
|
||
|
||
__all__ = ["MRRCallback"] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# flake8: noqa | ||
from catalyst.metrics.mrr import mrr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
""" | ||
MRR metric. | ||
""" | ||
|
||
import torch | ||
|
||
|
||
def mrr(outputs: torch.Tensor, targets: torch.Tensor, k=100) -> torch.Tensor: | ||
""" | ||
Calculate the Mean Reciprocal Rank (MRR) | ||
score given model ouptputs and targets | ||
User's data aggreagtesd in batches. | ||
|
||
The MRR@k is the mean overall user of the | ||
reciprocal rank, that is the rank of the highest | ||
ranked relevant item, if any in the top *k*, 0 otherwise. | ||
https://en.wikipedia.org/wiki/Mean_reciprocal_rank | ||
|
||
Args: | ||
outputs (torch.Tensor): | ||
Tensor weith predicted score | ||
size: [batch_size, slate_length] | ||
model outputs, logits | ||
targets (torch.Tensor): | ||
Binary tensor with ground truth. | ||
1 means the item is relevant | ||
for the user and 0 not relevant | ||
size: [batch_szie, slate_length] | ||
ground truth, labels | ||
k (int): | ||
Parameter fro evaluation on top-k items | ||
|
||
Returns: | ||
result (torch.Tensor): | ||
The mrr score for each user. | ||
""" | ||
k = min(outputs.size(1), k) | ||
_, indices_for_sort = outputs.sort(descending=True, dim=-1) | ||
true_sorted_by_preds = torch.gather( | ||
targets, dim=-1, index=indices_for_sort | ||
) | ||
true_sorted_by_pred_shrink = true_sorted_by_preds[:, :k] | ||
|
||
values, indices = torch.max(true_sorted_by_pred_shrink, dim=1) | ||
indices = indices.type_as(values).unsqueeze(dim=0).t() | ||
result = torch.tensor(1.0) / (indices + torch.tensor(1.0)) | ||
|
||
zero_sum_mask = values == 0.0 | ||
result[zero_sum_mask] = 0.0 | ||
return result | ||
|
||
|
||
__all__ = ["mrr"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import torch | ||
|
||
from catalyst import metrics | ||
|
||
|
||
def test_mrr(): | ||
""" | ||
Tests for catalyst.metrics.mrr metric. | ||
""" | ||
# # check 0 simple case | ||
y_pred = [0.5, 0.2] | ||
y_true = [1.0, 0.0] | ||
|
||
mrr = metrics.mrr(torch.Tensor([y_pred]), torch.Tensor([y_true])) | ||
assert mrr[0][0] == 1 | ||
|
||
# check 1 simple case | ||
y_pred = [0.5, 0.2] | ||
y_true = [0.0, 1.0] | ||
|
||
mrr = metrics.mrr(torch.Tensor([y_pred]), torch.Tensor([y_true])) | ||
# mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true)) | ||
assert mrr[0][0] == 0.5 | ||
# assert mrr == 0.5 | ||
|
||
# check 2 simple case | ||
y_pred = [0.2, 0.5] | ||
y_true = [0.0, 1.0] | ||
|
||
mrr = metrics.mrr(torch.Tensor([y_pred]), torch.Tensor([y_true])) | ||
assert mrr[0][0] == 1.0 | ||
|
||
# check 3 test multiple users | ||
y_pred1 = [0.2, 0.5] | ||
y_pred05 = [0.5, 0.2] | ||
y_true = [0.0, 1.0] | ||
|
||
mrr = metrics.mrr( | ||
torch.Tensor([y_pred1, y_pred05]), torch.Tensor([y_true, y_true]) | ||
) | ||
assert mrr[0][0] == 1.0 | ||
assert mrr[1][0] == 0.5 | ||
|
||
# check 4 test with k | ||
y_pred1 = [4.0, 2.0, 3.0, 1.0] | ||
y_pred2 = [1.0, 2.0, 3.0, 4.0] | ||
y_true1 = [0, 0, 1.0, 1.0] | ||
y_true2 = [0, 0, 1.0, 1.0] | ||
|
||
y_pred_torch = torch.Tensor([y_pred1, y_pred2]) | ||
y_true_torch = torch.Tensor([y_true1, y_true2]) | ||
|
||
mrr = metrics.mrr(y_pred_torch, y_true_torch, k=3) | ||
|
||
assert mrr[0][0] == 0.5 | ||
assert mrr[1][0] == 1.0 | ||
|
||
# check 5 test with k | ||
|
||
y_pred1 = [4.0, 2.0, 3.0, 1.0] | ||
y_pred2 = [1.0, 2.0, 3.0, 4.0] | ||
y_true1 = [0, 0, 1.0, 1.0] | ||
y_true2 = [0, 0, 1.0, 1.0] | ||
|
||
y_pred_torch = torch.Tensor([y_pred1, y_pred2]) | ||
y_true_torch = torch.Tensor([y_true1, y_true2]) | ||
|
||
mrr = metrics.mrr(y_pred_torch, y_true_torch, k=1) | ||
|
||
assert mrr[0][0] == 0.0 | ||
assert mrr[1][0] == 1.0 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -290,3 +290,10 @@ Functional | |
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
MRR | ||
~~~~~~~~~~~~~~~~~~~~~~ | ||
.. automodule:: catalyst.utils.metrics.mrr | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like here is an error with docs :) |
||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please add this callback to the docs?
https://github.com/catalyst-team/catalyst/blob/master/docs/api/dl.rst#metrics
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As for tests, I think better return to the question when we implement at least one Learning to Rank models.