Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #971 from elephantmipt/metrics
Metrics
- Loading branch information
Showing 8 changed files with 328 additions and 45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import Optional, Tuple | ||
|
||
import torch | ||
|
||
from catalyst.metrics.functional import get_multiclass_statistics | ||
|
||
|
||
def precision_recall_fbeta_support(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    beta: float = 1,
    eps: float = 1e-6,
    argmax_dim: int = -1,
    num_classes: Optional[int] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Computes precision, recall, fbeta score and support for each class.

    Args:
        outputs: A list of predicted elements
        targets: A list of elements that are to be predicted
        beta: beta param for f_score
        eps: epsilon to avoid zero division
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        num_classes: int, that specifies number of classes if it known

    Returns:
        tuple of precision, recall, fbeta_score, support
    """
    # Per-class confusion-matrix statistics; true negatives are not
    # needed for these metrics.
    _tn, false_pos, false_neg, true_pos, support = get_multiclass_statistics(
        outputs=outputs,
        targets=targets,
        argmax_dim=argmax_dim,
        num_classes=num_classes,
    )
    # eps in both numerator and denominator keeps the metrics finite for
    # classes with no predictions or no ground-truth samples.
    precision = (true_pos + eps) / (true_pos + false_pos + eps)
    recall = (true_pos + eps) / (true_pos + false_neg + eps)
    beta_sq = beta ** 2
    fbeta = ((1 + beta_sq) * precision * recall) / (
        beta_sq * precision + recall
    )

    return precision, recall, fbeta, support
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,53 +1,84 @@ | ||
""" | ||
F1 score. | ||
""" | ||
from typing import Optional, Union | ||
|
||
import torch | ||
|
||
from catalyst.utils.torch import get_activation_fn | ||
from catalyst.metrics.classification import precision_recall_fbeta_support | ||
|
||
|
||
def fbeta_score(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    beta: float = 1.0,
    eps: float = 1e-7,
    argmax_dim: int = -1,
    num_classes: Optional[int] = None,
) -> Union[float, torch.Tensor]:
    """
    Counts fbeta score for given ``outputs`` and ``targets``.

    Args:
        outputs: A list of predicted elements
        targets: A list of elements that are to be predicted
        beta: beta param for f_score
        eps: epsilon to avoid zero division
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        num_classes: int, that specifies number of classes if it known

    Raises:
        Exception: If ``beta`` is a negative number.

    Returns:
        float: F_beta score.
    """
    # A negative beta has no meaning in the F-beta formula.
    if beta < 0:
        raise Exception("beta parameter should be non-negative")

    # Delegate all statistics computation to the shared helper; only the
    # fbeta component of the returned tuple is needed here.
    _p, _r, fbeta, _ = precision_recall_fbeta_support(
        outputs=outputs,
        targets=targets,
        beta=beta,
        eps=eps,
        argmax_dim=argmax_dim,
        num_classes=num_classes,
    )
    return fbeta
def f1_score(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    eps: float = 1e-7,
    argmax_dim: int = -1,
    num_classes: Optional[int] = None,
) -> Union[float, torch.Tensor]:
    """
    Fbeta_score with beta=1.

    Args:
        outputs: A list of predicted elements
        targets: A list of elements that are to be predicted
        eps: epsilon to avoid zero division
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        num_classes: int, that specifies number of classes if it known

    Returns:
        float: F_1 score
    """
    # F1 is the harmonic mean of precision and recall, i.e. F-beta at beta=1.
    score = fbeta_score(
        outputs=outputs,
        targets=targets,
        beta=1,
        eps=eps,
        argmax_dim=argmax_dim,
        num_classes=num_classes,
    )

    return score


__all__ = ["f1_score", "fbeta_score"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
from typing import Optional, Union | ||
|
||
import torch | ||
|
||
from catalyst.metrics import precision_recall_fbeta_support | ||
|
||
|
||
def recall(
    outputs: torch.Tensor,
    targets: torch.Tensor,
    argmax_dim: int = -1,
    eps: float = 1e-7,
    num_classes: Optional[int] = None,
) -> Union[float, torch.Tensor]:
    """
    Multiclass recall metric.

    Args:
        outputs: estimated targets as predicted by a model
            with shape [bs; ..., (num_classes or 1)]
        targets: ground truth (correct) target values
            with shape [bs; ..., 1]
        argmax_dim: int, that specifies dimension for argmax transformation
            in case of scores/probabilities in ``outputs``
        eps: float. Epsilon to avoid zero division.
        num_classes: int, that specifies number of classes if it known.

    Returns:
        Tensor: recall for every class
    """
    # Only the recall component of the shared helper's output is needed.
    _, recall_score, _, _ = precision_recall_fbeta_support(
        outputs=outputs,
        targets=targets,
        argmax_dim=argmax_dim,
        eps=eps,
        num_classes=num_classes,
    )

    return recall_score
Oops, something went wrong.