New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add mrr calculation #886
Add mrr calculation #886
Changes from 11 commits
8163316
17a4b03
0f1151a
93c637e
f6d7496
4c46f25
21a36b0
ae9089c
6b765aa
a69273f
c425ccf
668196f
b626eeb
7f33538
06e6761
04caa31
d9d7764
8bf6700
49b15a4
e870498
ba4c180
80920f0
cf8828f
84b7b52
28c5148
8911528
3acb341
5a16261
e192837
2230a5b
f632cc1
470cc99
8127eb8
034b832
6dad3f4
b2fa81a
5f51a38
4947d4b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
from typing import List | ||
|
||
from catalyst.core import MetricCallback | ||
from catalyst.utils import metrics | ||
|
||
|
||
class MRRCallback(MetricCallback):
    """Calculates the MRR (Mean Reciprocal Rank) for each loader.

    Thin wrapper over :py:class:`MetricCallback` that plugs
    ``catalyst.utils.metrics.mrr`` in as the metric function.
    """

    def __init__(
        self,
        input_key: str = "targets",
        output_key: str = "logits",
        prefix: str = "mrr",
        activation: str = "none",
    ):
        """
        Args:
            input_key (str): input key to use for mrr calculation;
                specifies our ``y_true``
            output_key (str): output key to use for mrr calculation;
                specifies our ``y_pred``
            prefix (str): name to display for mrr when printing
            activation (str): a torch.nn activation applied to the outputs.
                Must be one of ``'none'``, ``'Sigmoid'``, or ``'Softmax2d'``
        """
        # All the actual work (reading batch keys, applying the
        # activation, logging the value) happens in MetricCallback.
        super().__init__(
            prefix=prefix,
            metric_fn=metrics.mrr,
            input_key=input_key,
            output_key=output_key,
            activation=activation,
        )
|
||
|
||
__all__ = ["MRRCallback"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
""" | ||
MRR metric. | ||
""" | ||
|
||
import torch | ||
|
||
|
||
def mrr(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """
    Calculate the MRR (Mean Reciprocal Rank) score for each slate.

    Items in every slate are ranked by ``outputs`` in descending order;
    the score for a slate is ``1 / rank`` of its highest-ranked relevant
    item (``targets > 0``), or ``0.0`` when the slate contains no
    relevant items. See
    https://en.wikipedia.org/wiki/Mean_reciprocal_rank for details.

    Args:
        outputs (torch.Tensor): model outputs (logits),
            shape ``[batch_size, slate_length]`` or ``[slate_length]``
        targets (torch.Tensor): ground truth relevance labels,
            same shape as ``outputs``

    Returns:
        torch.Tensor: the reciprocal-rank score per slate,
        shape ``[batch_size, 1]``
    """
    # Slate length is the LAST dimension; shape[0] would be the batch
    # size for batched [batch_size, slate_length] input.
    max_rank = targets.shape[-1]

    # Rank items within each slate by predicted score, descending.
    _, indices_for_sort = outputs.sort(descending=True, dim=-1)
    true_sorted_by_preds = torch.gather(
        targets, dim=-1, index=indices_for_sort
    )

    # Position of the first relevant item in each sorted slate.
    # dim=-1 keeps the reduction per-slate; dim=0 would (incorrectly)
    # reduce across the batch for batched input.
    values, indices = torch.max(true_sorted_by_preds, dim=-1)
    indices = indices.type_as(values).unsqueeze(dim=0).t()

    # Mask of ranks that fall within the slate; always true when the
    # cutoff equals the slate length (kept for future mrr@k support).
    max_rank_rep = torch.tensor(
        data=max_rank, device=indices.device, dtype=torch.float32
    )
    within_at_mask = (indices < max_rank_rep).type(torch.float32)

    # Reciprocal rank, 1-based.
    result = torch.tensor(1.0) / (indices + torch.tensor(1.0))

    # Slates without any relevant item score 0.0. The check is
    # per-slate (values == 0), not a global sum over the whole batch.
    result[values == 0.0] = 0.0

    return result * within_at_mask
|
||
|
||
__all__ = ["mrr"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
|
||
from catalyst.utils import metrics | ||
|
||
|
||
def test_mrr():
    """
    Tests for catalyst.utils.metrics.mrr metric.
    """
    # Single-slate cases: (outputs, targets, expected reciprocal rank).
    single_slate_cases = [
        ([0.5, 0.2], [1.0, 0.0], 1.0),  # relevant item ranked first
        ([0.5, 0.2], [0.0, 1.0], 0.5),  # relevant item ranked second
        ([0.2, 0.5], [0.0, 1.0], 1.0),  # relevant item ranked first
    ]
    for outputs, targets, expected in single_slate_cases:
        score = metrics.mrr(torch.Tensor(outputs), torch.Tensor(targets))
        assert score == expected

    # Batched slates: one slate per row, a score per slate.
    outputs_batch = [[0.2, 0.5], [0.5, 0.2]]
    targets_batch = [[0.0, 1.0], [0.0, 1.0]]
    score = metrics.mrr(
        torch.Tensor(outputs_batch), torch.Tensor(targets_batch)
    )
    assert score[0][0] == 1.0
    assert score[1][0] == 0.5
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks like we need to move it up :)