-
-
Notifications
You must be signed in to change notification settings - Fork 385
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mrr calculation #886
Add mrr calculation #886
Changes from 5 commits
8163316
17a4b03
0f1151a
93c637e
f6d7496
4c46f25
21a36b0
ae9089c
6b765aa
a69273f
c425ccf
668196f
b626eeb
7f33538
06e6761
04caa31
d9d7764
8bf6700
49b15a4
e870498
ba4c180
80920f0
cf8828f
84b7b52
28c5148
8911528
3acb341
5a16261
e192837
2230a5b
f632cc1
470cc99
8127eb8
034b832
6dad3f4
b2fa81a
5f51a38
4947d4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import torch | ||
|
||
|
||
def mrr(outputs: torch.Tensor, targets: torch.Tensor): | ||
|
||
""" | ||
Calculate the MRR score given model ouptputs and targets | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'd appreciate it if you could extend documentation with the explanation of the metric or add link users will be able to read more about this score. |
||
Args: | ||
outputs [batch_size, slate_length] (torch.Tensor): model outputs, logits | ||
targets [batch_szie, slate_length] (torch.Tensor): ground truth, labels | ||
Returns: | ||
mrr (float): the mrr score | ||
""" | ||
outputs = outputs.clone() | ||
targets = targets.clone() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why do you need There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Tried to follow the 'shared clones', which is a common pattern in torch. But it seems it not so necessary here. |
||
max_rank = targets.shape[0] | ||
|
||
_, indices = outputs.sort(descending=True, dim=-1) | ||
true_sorted_by_preds = torch.gather(targets, dim=0, index=indices) | ||
values, indices = torch.max(true_sorted_by_preds, dim=0) | ||
indices = indices.type_as(values).unsqueeze(dim=0).t() | ||
ats_rep = torch.tensor( | ||
data=max_rank, device=indices.device, dtype=torch.float32 | ||
) | ||
within_at_mask = (indices < ats_rep).type(torch.float32) | ||
|
||
result = torch.tensor(1.0) / (indices + torch.tensor(1.0)) | ||
|
||
zero_sum_mask = torch.sum(values) == 0.0 | ||
result[zero_sum_mask] = 0.0 | ||
|
||
mrr = result * within_at_mask | ||
return mrr[0] | ||
|
||
|
||
__all__ = ["mrr"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import torch | ||
|
||
from catalyst.utils import metrics | ||
|
||
|
||
def test_mrr(): | ||
""" | ||
Tests for catalyst.utils.metrics.mrr metric. | ||
""" | ||
|
||
# check 0 simple case | ||
y_pred = [0.5, 0.2] | ||
y_true = [1.0, 0.0] | ||
|
||
mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true)) | ||
assert mrr == 1 | ||
|
||
# check 1 simple case | ||
y_pred = [0.5, 0.2] | ||
y_true = [0.0, 1.0] | ||
|
||
mrr = metrics.mrr(torch.Tensor(y_pred), torch.Tensor(y_true)) | ||
assert mrr == 0.5 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could you please add docs here
https://github.com/catalyst-team/catalyst/blob/master/docs/api/utils.rst
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
btw, what do you think about MRRCallback? like https://github.com/catalyst-team/catalyst/blob/master/catalyst/dl/callbacks/metrics/iou.py#L12 for example
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, I'll have a look at callbacks.