From 69e93df512e8cfc6331f6d0c5a0189d19fd2605b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?=
Date: Mon, 6 Jun 2022 03:45:57 +0100
Subject: [PATCH 01/10] Fix for bug when providing superclass arguments as
 kwargs

---
 tests/classification/test_accuracy.py           |  6 ++++++
 tests/classification/test_dice.py               |  6 ++++++
 tests/classification/test_f_beta.py             |  8 ++++++++
 tests/classification/test_jaccard.py            |  5 +++++
 tests/classification/test_precision_recall.py   |  8 ++++++++
 tests/classification/test_specificity.py        |  6 ++++++
 torchmetrics/classification/accuracy.py         |  9 ++++++---
 torchmetrics/classification/dice.py             |  9 +++++++--
 torchmetrics/classification/f_beta.py           |  7 +++++--
 torchmetrics/classification/jaccard.py          |  4 +++-
 torchmetrics/classification/precision_recall.py | 16 ++++++++++++----
 torchmetrics/classification/specificity.py      |  9 +++++++--
 12 files changed, 79 insertions(+), 14 deletions(-)

diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index 40c27f20368..73125cfa834 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -448,3 +448,9 @@ def test_negmetric_noneavg(noneavg=_negmetric_noneavg):
     assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
     result2 = acc(noneavg["pred2"], noneavg["target2"])
     assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments.
+    """
+    acc = Accuracy(reduce='micro')
+    acc = Accuracy(mdmc_reduce='global')
diff --git a/tests/classification/test_dice.py b/tests/classification/test_dice.py
index 445f926d70e..a85c9a25008 100644
--- a/tests/classification/test_dice.py
+++ b/tests/classification/test_dice.py
@@ -163,3 +163,9 @@ def test_dice_fn(self, preds, target, ignore_index):
             sk_metric=partial(_sk_dice, ignore_index=ignore_index),
             metric_args={"ignore_index": ignore_index},
         )
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments.
+    """
+    dice = Dice(reduce='micro')
+    dice = Dice(mdmc_reduce='global')
diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py
index 1c1a9032b4a..569c0ee7743 100644
--- a/tests/classification/test_f_beta.py
+++ b/tests/classification/test_f_beta.py
@@ -460,3 +460,11 @@ def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_inde
 
     assert torch.allclose(class_res, torch.tensor(sk_res).float())
     assert torch.allclose(func_res, torch.tensor(sk_res).float())
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments.
+    """
+    fbeta = FBetaScore(reduce='micro')
+    fbeta = FBetaScore(mdmc_reduce='global')
+    f1 = F1Score(reduce='micro')
+    f1 = F1Score(mdmc_reduce='global')
diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py
index 80dc44c0b2a..3e93a8f4d2d 100644
--- a/tests/classification/test_jaccard.py
+++ b/tests/classification/test_jaccard.py
@@ -237,3 +237,8 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average,
         # reduction=reduction,
     )
     assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val))
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments.
+    """
+    jaccard = JaccardIndex(num_classes=1, normalize='true')
diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py
index 1c117d70591..e750df41353 100644
--- a/tests/classification/test_precision_recall.py
+++ b/tests/classification/test_precision_recall.py
@@ -468,3 +468,11 @@ def test_noneavg(metric_cls, noneavg=_negmetric_noneavg):
     assert torch.allclose(noneavg["res1"], result1, equal_nan=True)
     result2 = prec(noneavg["pred2"], noneavg["target2"])
     assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments.
+    """
+    precision = Precision(reduce='micro')
+    precision = Precision(mdmc_reduce='global')
+    recall = Recall(reduce='micro')
+    recall = Recall(mdmc_reduce='global')
diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py
index dc6f38df5ce..0905a333992 100644
--- a/tests/classification/test_specificity.py
+++ b/tests/classification/test_specificity.py
@@ -410,3 +410,9 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected):
     cl_metric(preds, target)
     result_cl = cl_metric.compute()
     assert torch.allclose(expected, result_cl, equal_nan=True)
+
+def test_provide_superclass_kwargs():
+    """Test instantiating class providing superclass arguments.
+    """
+    specificity = Specificity(reduce='micro')
+    specificity = Specificity(mdmc_reduce='global')
diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py
index 8f77235dc39..a3d977d7648 100644
--- a/torchmetrics/classification/accuracy.py
+++ b/torchmetrics/classification/accuracy.py
@@ -23,7 +23,7 @@
     _subset_accuracy_compute,
     _subset_accuracy_update,
 )
-from torchmetrics.utilities.enums import DataType
+from torchmetrics.utilities.enums import AverageMethod, DataType
 
 from torchmetrics.classification.stat_scores import StatScores  # isort:skip
 
@@ -176,9 +176,12 @@ def __init__(
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
         super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
             threshold=threshold,
             top_k=top_k,
             num_classes=num_classes,
diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py
index 09ff5be205b..d193762c63d 100644
--- a/torchmetrics/classification/dice.py
+++ b/torchmetrics/classification/dice.py
@@ -15,6 +15,8 @@
 
 from torch import Tensor
 
+from torchmetrics.utilities.enums import AverageMethod, DataType
+
 from torchmetrics.classification.stat_scores import StatScores
 from torchmetrics.functional.classification.dice import _dice_compute
 
@@ -134,9 +136,12 @@ def __init__(
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
         super().__init__(
-            reduce="macro" if average in ("weighted", "none", None) else average,
-            mdmc_reduce=mdmc_average,
             threshold=threshold,
             top_k=top_k,
             num_classes=num_classes,
diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py
index 754436aac04..864bf06ff29 100644
--- a/torchmetrics/classification/f_beta.py
+++ b/torchmetrics/classification/f_beta.py
@@ -137,9 +137,12 @@ def __init__(
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
         super().__init__(
-            reduce="macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE] else average,
-            mdmc_reduce=mdmc_average,
             threshold=threshold,
             top_k=top_k,
             num_classes=num_classes,
diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py
index 58f2438f671..0dfdf6e99d2 100644
--- a/torchmetrics/classification/jaccard.py
+++ b/torchmetrics/classification/jaccard.py
@@ -90,9 +90,11 @@ def __init__(
         multilabel: bool = False,
         **kwargs: Dict[str, Any],
     ) -> None:
+        if "normalize" not in kwargs:
+            kwargs["normalize"] = None
+
         super().__init__(
             num_classes=num_classes,
-            normalize=None,
             threshold=threshold,
             multilabel=multilabel,
             **kwargs,
diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py
index 69d6cfe0ac1..424fca19358 100644
--- a/torchmetrics/classification/precision_recall.py
+++ b/torchmetrics/classification/precision_recall.py
@@ -15,6 +15,8 @@
 
 from torch import Tensor
 
+from torchmetrics.utilities.enums import AverageMethod
+
 from torchmetrics.classification.stat_scores import StatScores
 from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
 
@@ -127,9 +129,12 @@ def __init__(
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
         super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
             threshold=threshold,
             top_k=top_k,
             num_classes=num_classes,
@@ -262,9 +267,12 @@ def __init__(
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
         super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
             threshold=threshold,
             top_k=top_k,
             num_classes=num_classes,
diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py
index 9f8eabb89f6..2a2498912c3 100644
--- a/torchmetrics/classification/specificity.py
+++ b/torchmetrics/classification/specificity.py
@@ -16,6 +16,8 @@
 import torch
 from torch import Tensor
 
+from torchmetrics.utilities.enums import AverageMethod
+
 from torchmetrics.classification.stat_scores import StatScores
 from torchmetrics.functional.classification.specificity import _specificity_compute
 
@@ -129,9 +131,12 @@ def __init__(
         if average not in allowed_average:
             raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")
 
+        if "reduce" not in kwargs or kwargs["reduce"] is None:
+            kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
+        if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
+            kwargs["mdmc_reduce"] = mdmc_average
+
         super().__init__(
-            reduce="macro" if average in ["weighted", "none", None] else average,
-            mdmc_reduce=mdmc_average,
             threshold=threshold,
             top_k=top_k,
             num_classes=num_classes,

From e5d3adbb672749209fa5d5bb04509e25453c9c6d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 6 Jun 2022 02:50:33 +0000
Subject: [PATCH 02/10] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 tests/classification/test_accuracy.py           |  8 ++++----
 tests/classification/test_dice.py               |  8 ++++----
 tests/classification/test_f_beta.py             | 12 ++++++------
 tests/classification/test_jaccard.py            |  6 +++---
 tests/classification/test_precision_recall.py   | 12 ++++++------
 tests/classification/test_specificity.py        |  8 ++++----
 torchmetrics/classification/accuracy.py         |  2 +-
 torchmetrics/classification/dice.py             |  5 ++---
 torchmetrics/classification/f_beta.py           |  2 +-
 torchmetrics/classification/jaccard.py          |  2 +-
 torchmetrics/classification/precision_recall.py |  7 +++----
 torchmetrics/classification/specificity.py      |  5 ++---
 12 files changed, 37 insertions(+), 40 deletions(-)

diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py
index 73125cfa834..3690873cbef 100644
--- a/tests/classification/test_accuracy.py
+++ b/tests/classification/test_accuracy.py
@@ -449,8 +449,8 @@ def test_negmetric_noneavg(noneavg=_negmetric_noneavg):
     result2 = acc(noneavg["pred2"], noneavg["target2"])
     assert torch.allclose(noneavg["res2"], result2, equal_nan=True)
 
+
 def test_provide_superclass_kwargs():
-    """Test instantiating class providing superclass arguments.
-    """
-    acc = Accuracy(reduce='micro')
-    acc = Accuracy(mdmc_reduce='global')
+    """Test instantiating class providing superclass arguments."""
+    acc = Accuracy(reduce="micro")
+    acc = Accuracy(mdmc_reduce="global")
diff --git a/tests/classification/test_dice.py b/tests/classification/test_dice.py
index a85c9a25008..eaec4aa9272 100644
--- a/tests/classification/test_dice.py
+++ b/tests/classification/test_dice.py
@@ -164,8 +164,8 @@ def test_dice_fn(self, preds, target, ignore_index):
             metric_args={"ignore_index": ignore_index},
         )
 
+
 def test_provide_superclass_kwargs():
-    """Test instantiating class providing superclass arguments.
- """ - dice = Dice(reduce='micro') - dice = Dice(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + dice = Dice(reduce="micro") + dice = Dice(mdmc_reduce="global") diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 569c0ee7743..b44872b790e 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -461,10 +461,10 @@ def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_inde assert torch.allclose(class_res, torch.tensor(sk_res).float()) assert torch.allclose(func_res, torch.tensor(sk_res).float()) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. - """ - fbeta = FBetaScore(reduce='micro') - fbeta = FBetaScore(mdmc_reduce='global') - f1 = F1Score(reduce='micro') - f1 = F1Score(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + fbeta = FBetaScore(reduce="micro") + fbeta = FBetaScore(mdmc_reduce="global") + f1 = F1Score(reduce="micro") + f1 = F1Score(mdmc_reduce="global") diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py index 3e93a8f4d2d..7808af74ddb 100644 --- a/tests/classification/test_jaccard.py +++ b/tests/classification/test_jaccard.py @@ -238,7 +238,7 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, ) assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. - """ - jaccard = JaccardIndex(num_classes=1, normalize='true') + """Test instantiating class providing superclass arguments.""" + jaccard = JaccardIndex(num_classes=1, normalize="true") diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index e750df41353..d6dfcd66001 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -469,10 +469,10 @@ def test_noneavg(metric_cls, noneavg=_negmetric_noneavg): result2 = prec(noneavg["pred2"], noneavg["target2"]) assert torch.allclose(noneavg["res2"], result2, equal_nan=True) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. - """ - precision = Precision(reduce='micro') - precision = Precision(mdmc_reduce='global') - recall = Recall(reduce='micro') - recall = Recall(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + precision = Precision(reduce="micro") + precision = Precision(mdmc_reduce="global") + recall = Recall(reduce="micro") + recall = Recall(mdmc_reduce="global") diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index 0905a333992..d3dfedce03d 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -411,8 +411,8 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected): result_cl = cl_metric.compute() assert torch.allclose(expected, result_cl, equal_nan=True) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. 
- """ - specificity = Specificity(reduce='micro') - specificity = Specificity(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + specificity = Specificity(reduce="micro") + specificity = Specificity(mdmc_reduce="global") diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index a3d977d7648..15d9f3c1646 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -180,7 +180,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index d193762c63d..f61a4fdb1ba 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -15,10 +15,9 @@ from torch import Tensor -from torchmetrics.utilities.enums import AverageMethod, DataType - from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.dice import _dice_compute +from torchmetrics.utilities.enums import AverageMethod, DataType class Dice(StatScores): @@ -140,7 +139,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 864bf06ff29..6577badad65 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -141,7 +141,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 0dfdf6e99d2..801d1332be9 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -92,7 +92,7 @@ def __init__( ) -> None: if "normalize" not in kwargs: kwargs["normalize"] = None - + super().__init__( num_classes=num_classes, threshold=threshold, diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 424fca19358..8ddf7d01c0f 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -15,10 +15,9 @@ from torch import Tensor -from torchmetrics.utilities.enums import AverageMethod - from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute +from torchmetrics.utilities.enums import AverageMethod class Precision(StatScores): @@ -133,7 +132,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, @@ -271,7 +270,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, 
AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 2a2498912c3..37af6ffcd80 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -16,10 +16,9 @@ import torch from torch import Tensor -from torchmetrics.utilities.enums import AverageMethod - from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.specificity import _specificity_compute +from torchmetrics.utilities.enums import AverageMethod class Specificity(StatScores): @@ -135,7 +134,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, From c4544cb0c7cf4a9b18c27957fcfd91ff154bebc4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?= Date: Mon, 6 Jun 2022 04:40:54 +0100 Subject: [PATCH 03/10] Fix code formatting MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: João Lourenço Silva --- tests/classification/test_accuracy.py | 8 ++++---- tests/classification/test_dice.py | 8 ++++---- tests/classification/test_f_beta.py | 12 ++++++------ tests/classification/test_precision_recall.py | 12 ++++++------ tests/classification/test_specificity.py | 8 ++++---- torchmetrics/classification/accuracy.py | 2 +- torchmetrics/classification/dice.py | 5 ++--- torchmetrics/classification/f_beta.py | 2 +- torchmetrics/classification/jaccard.py | 2 +- torchmetrics/classification/precision_recall.py | 7 +++---- torchmetrics/classification/specificity.py | 5 ++--- torchmetrics/functional/audio/pesq.py | 6 +++--- torchmetrics/functional/audio/stoi.py | 5 ++--- 13 files changed, 39 insertions(+), 43 deletions(-) diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index 73125cfa834..53515f9ef37 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -449,8 +449,8 @@ def test_negmetric_noneavg(noneavg=_negmetric_noneavg): result2 = acc(noneavg["pred2"], noneavg["target2"]) assert torch.allclose(noneavg["res2"], result2, equal_nan=True) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. - """ - acc = Accuracy(reduce='micro') - acc = Accuracy(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + Accuracy(reduce="micro") + Accuracy(mdmc_reduce="global") diff --git a/tests/classification/test_dice.py b/tests/classification/test_dice.py index a85c9a25008..921e0b39274 100644 --- a/tests/classification/test_dice.py +++ b/tests/classification/test_dice.py @@ -164,8 +164,8 @@ def test_dice_fn(self, preds, target, ignore_index): metric_args={"ignore_index": ignore_index}, ) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. 
- """ - dice = Dice(reduce='micro') - dice = Dice(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + Dice(reduce="micro") + Dice(mdmc_reduce="global") diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 569c0ee7743..452d25c1e4f 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -461,10 +461,10 @@ def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_inde assert torch.allclose(class_res, torch.tensor(sk_res).float()) assert torch.allclose(func_res, torch.tensor(sk_res).float()) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. - """ - fbeta = FBetaScore(reduce='micro') - fbeta = FBetaScore(mdmc_reduce='global') - f1 = F1Score(reduce='micro') - f1 = F1Score(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + FBetaScore(reduce="micro") + FBetaScore(mdmc_reduce="global") + F1Score(reduce="micro") + F1Score(mdmc_reduce="global") diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index e750df41353..45d64fc5901 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -469,10 +469,10 @@ def test_noneavg(metric_cls, noneavg=_negmetric_noneavg): result2 = prec(noneavg["pred2"], noneavg["target2"]) assert torch.allclose(noneavg["res2"], result2, equal_nan=True) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. - """ - precision = Precision(reduce='micro') - precision = Precision(mdmc_reduce='global') - recall = Recall(reduce='micro') - recall = Recall(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + Precision(reduce="micro") + Precision(mdmc_reduce="global") + Recall(reduce="micro") + Recall(mdmc_reduce="global") diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index 0905a333992..3d1ee3dcbe3 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -411,8 +411,8 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, expected): result_cl = cl_metric.compute() assert torch.allclose(expected, result_cl, equal_nan=True) + def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments. 
- """ - specificity = Specificity(reduce='micro') - specificity = Specificity(mdmc_reduce='global') + """Test instantiating class providing superclass arguments.""" + Specificity(reduce="micro") + Specificity(mdmc_reduce="global") diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index a3d977d7648..15d9f3c1646 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -180,7 +180,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index d193762c63d..85082182dd7 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -15,10 +15,9 @@ from torch import Tensor -from torchmetrics.utilities.enums import AverageMethod, DataType - from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.dice import _dice_compute +from torchmetrics.utilities.enums import AverageMethod class Dice(StatScores): @@ -140,7 +139,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 864bf06ff29..6577badad65 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -141,7 +141,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 0dfdf6e99d2..801d1332be9 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -92,7 +92,7 @@ def __init__( ) -> None: if "normalize" not in kwargs: kwargs["normalize"] = None - + super().__init__( num_classes=num_classes, threshold=threshold, diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 424fca19358..8ddf7d01c0f 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -15,10 +15,9 @@ from torch import Tensor -from torchmetrics.utilities.enums import AverageMethod - from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute +from torchmetrics.utilities.enums import AverageMethod class Precision(StatScores): @@ -133,7 +132,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, @@ -271,7 +270,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if 
"mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 2a2498912c3..37af6ffcd80 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -16,10 +16,9 @@ import torch from torch import Tensor -from torchmetrics.utilities.enums import AverageMethod - from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.specificity import _specificity_compute +from torchmetrics.utilities.enums import AverageMethod class Specificity(StatScores): @@ -135,7 +134,7 @@ def __init__( kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: kwargs["mdmc_reduce"] = mdmc_average - + super().__init__( threshold=threshold, top_k=top_k, diff --git a/torchmetrics/functional/audio/pesq.py b/torchmetrics/functional/audio/pesq.py index 73fcf541d34..7456edf9763 100644 --- a/torchmetrics/functional/audio/pesq.py +++ b/torchmetrics/functional/audio/pesq.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np +import torch +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _PESQ_AVAILABLE if _PESQ_AVAILABLE: import pesq as pesq_backend else: pesq_backend = None -import torch -from torch import Tensor -from torchmetrics.utilities.checks import _check_same_shape __doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]} diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py index edd1673afa9..fe9cb5a9586 100644 --- a/torchmetrics/functional/audio/stoi.py +++ b/torchmetrics/functional/audio/stoi.py @@ -13,7 +13,9 @@ # limitations under the License. 
import numpy as np import torch +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE if _PYSTOI_AVAILABLE: @@ -21,9 +23,6 @@ else: stoi_backend = None __doctest_skip__ = ["short_time_objective_intelligibility"] -from torch import Tensor - -from torchmetrics.utilities.checks import _check_same_shape def short_time_objective_intelligibility( From a9e0c5b507bdb3f0a09f685b2b4c0afcc70a9dd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?= Date: Tue, 7 Jun 2022 00:58:15 +0100 Subject: [PATCH 04/10] Moved tests to the superclass test files --- tests/classification/test_accuracy.py | 6 ----- tests/classification/test_confusion_matrix.py | 20 ++++++++++++++++ tests/classification/test_dice.py | 6 ----- tests/classification/test_f_beta.py | 8 ------- tests/classification/test_precision_recall.py | 8 ------- tests/classification/test_specificity.py | 6 ----- tests/classification/test_stat_scores.py | 23 +++++++++++++++++-- 7 files changed, 41 insertions(+), 36 deletions(-) diff --git a/tests/classification/test_accuracy.py b/tests/classification/test_accuracy.py index 53515f9ef37..40c27f20368 100644 --- a/tests/classification/test_accuracy.py +++ b/tests/classification/test_accuracy.py @@ -448,9 +448,3 @@ def test_negmetric_noneavg(noneavg=_negmetric_noneavg): assert torch.allclose(noneavg["res1"], result1, equal_nan=True) result2 = acc(noneavg["pred2"], noneavg["target2"]) assert torch.allclose(noneavg["res2"], result2, equal_nan=True) - - -def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments.""" - Accuracy(reduce="micro") - Accuracy(mdmc_reduce="global") diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 08ce3e3fb4d..402ffb73f98 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial +from itertools import product +from typing import Any, Dict import numpy as np import pytest @@ -30,6 +32,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from torchmetrics import JaccardIndex from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.confusion_matrix import confusion_matrix @@ -186,3 +189,20 @@ def test_warning_on_nan(tmpdir): match=".* nan values found in confusion matrix have been replaced with zeros.", ): confusion_matrix(preds, target, num_classes=5, normalize="true") + + +kwarg_options = [ + {"num_classes": 1, "normalize": "true"}, + {"num_classes": 1, "normalize": "pred"}, + {"num_classes": 1, "normalize": "all"}, + {"num_classes": 1, "normalize": "none"}, + {"num_classes": 1, "normalize": None}, +] +subclasses = [JaccardIndex] +params = list(product(subclasses, kwarg_options)) + + +@pytest.mark.parametrize("metric_cls, kwargs", params) +def test_provide_superclass_kwargs(metric_cls: ConfusionMatrix, kwargs: Dict[str, Any]): + """Test instantiating subclasses with superclass arguments as kwargs.""" + metric_cls(**kwargs) diff --git a/tests/classification/test_dice.py b/tests/classification/test_dice.py index 921e0b39274..445f926d70e 100644 --- a/tests/classification/test_dice.py +++ b/tests/classification/test_dice.py @@ -163,9 +163,3 @@ def test_dice_fn(self, preds, target, ignore_index): sk_metric=partial(_sk_dice, ignore_index=ignore_index), metric_args={"ignore_index": ignore_index}, ) - - -def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments.""" - Dice(reduce="micro") - Dice(mdmc_reduce="global") diff --git a/tests/classification/test_f_beta.py b/tests/classification/test_f_beta.py index 452d25c1e4f..1c1a9032b4a 100644 --- a/tests/classification/test_f_beta.py +++ b/tests/classification/test_f_beta.py @@ -460,11 +460,3 @@ def test_same_input(metric_class, metric_functional, sk_fn, average, ignore_inde assert torch.allclose(class_res, torch.tensor(sk_res).float()) assert torch.allclose(func_res, torch.tensor(sk_res).float()) - - -def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments.""" - FBetaScore(reduce="micro") - FBetaScore(mdmc_reduce="global") - F1Score(reduce="micro") - F1Score(mdmc_reduce="global") diff --git a/tests/classification/test_precision_recall.py b/tests/classification/test_precision_recall.py index 45d64fc5901..1c117d70591 100644 --- a/tests/classification/test_precision_recall.py +++ b/tests/classification/test_precision_recall.py @@ -468,11 +468,3 @@ def test_noneavg(metric_cls, noneavg=_negmetric_noneavg): assert torch.allclose(noneavg["res1"], result1, equal_nan=True) result2 = prec(noneavg["pred2"], noneavg["target2"]) assert torch.allclose(noneavg["res2"], result2, equal_nan=True) - - -def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments.""" - Precision(reduce="micro") - Precision(mdmc_reduce="global") - Recall(reduce="micro") - Recall(mdmc_reduce="global") diff --git a/tests/classification/test_specificity.py b/tests/classification/test_specificity.py index 3d1ee3dcbe3..dc6f38df5ce 100644 --- a/tests/classification/test_specificity.py +++ b/tests/classification/test_specificity.py @@ -410,9 +410,3 @@ def test_class_not_present(metric_class, metric_fn, ignore_index, 
expected): cl_metric(preds, target) result_cl = cl_metric.compute() assert torch.allclose(expected, result_cl, equal_nan=True) - - -def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments.""" - Specificity(reduce="micro") - Specificity(mdmc_reduce="global") diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 6f19e1c148b..73b451876cd 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -12,7 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from typing import Callable, Optional +from itertools import product +from typing import Any, Callable, Dict, Optional import numpy as np import pytest @@ -30,7 +31,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester -from torchmetrics import StatScores +from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores from torchmetrics.functional import stat_scores from torchmetrics.utilities.checks import _input_format_classification @@ -326,3 +327,21 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten assert torch.equal(class_metric.compute(), expected.T) assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) + + +kwarg_options = [ + {"reduce": "micro"}, + {"num_classes": 1, "reduce": "macro"}, + {"reduce": "samples"}, + {"mdmc_reduce": None}, + {"mdmc_reduce": "samplewise"}, + {"mdmc_reduce": "global"}, +] +subclasses = [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity] +params = list(product(subclasses, kwarg_options)) + + +@pytest.mark.parametrize("metric_cls, kwargs", params) +def test_provide_superclass_kwargs(metric_cls: StatScores, kwargs: Dict[str, Any]): + """Test instantiating subclasses with superclass arguments as kwargs.""" + metric_cls(**kwargs) From 3414a7b482a48d2ed334e965eda9f533f486d280 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jo=C3=A3o=20Louren=C3=A7o=20Silva?= Date: Tue, 7 Jun 2022 01:03:51 +0100 Subject: [PATCH 05/10] Moved tests to the superclass test files --- tests/classification/test_jaccard.py | 5 ----- tests/text/test_mer.py | 5 ++--- tests/text/test_wer.py | 5 ++--- 3 files changed, 4 insertions(+), 11 deletions(-) diff --git a/tests/classification/test_jaccard.py b/tests/classification/test_jaccard.py index 7808af74ddb..80dc44c0b2a 100644 --- a/tests/classification/test_jaccard.py +++ b/tests/classification/test_jaccard.py @@ -237,8 +237,3 @@ def test_jaccard_ignore_index(pred, target, ignore_index, num_classes, average, # reduction=reduction, ) assert torch.allclose(jaccard_val, tensor(expected).to(jaccard_val)) - - -def test_provide_superclass_kwargs(): - """Test instantiating class providing superclass arguments.""" - jaccard = JaccardIndex(num_classes=1, normalize="true") diff --git a/tests/text/test_mer.py b/tests/text/test_mer.py index 8153a915317..eaa79618e32 100644 --- a/tests/text/test_mer.py +++ b/tests/text/test_mer.py @@ -4,6 +4,8 @@ from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 +from torchmetrics.functional.text.mer import match_error_rate +from torchmetrics.text.mer import MatchErrorRate from torchmetrics.utilities.imports import _JIWER_AVAILABLE 
if _JIWER_AVAILABLE: @@ -11,9 +13,6 @@ else: compute_measures: Callable -from torchmetrics.functional.text.mer import match_error_rate -from torchmetrics.text.mer import MatchErrorRate - def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): return compute_measures(target, preds)["mer"] diff --git a/tests/text/test_wer.py b/tests/text/test_wer.py index 00c5d2ddc64..c70cac7c72f 100644 --- a/tests/text/test_wer.py +++ b/tests/text/test_wer.py @@ -4,6 +4,8 @@ from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 +from torchmetrics.functional.text.wer import word_error_rate +from torchmetrics.text.wer import WordErrorRate from torchmetrics.utilities.imports import _JIWER_AVAILABLE if _JIWER_AVAILABLE: @@ -11,9 +13,6 @@ else: compute_measures: Callable -from torchmetrics.functional.text.wer import word_error_rate -from torchmetrics.text.wer import WordErrorRate - def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): return compute_measures(target, preds)["wer"] From b65dfc59952e772f3fda4e06a1e0ef2fb03161c1 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 11:05:33 +0200 Subject: [PATCH 06/10] simple --- tests/classification/test_confusion_matrix.py | 27 ++++++++--------- tests/classification/test_stat_scores.py | 30 +++++++++---------- 2 files changed, 26 insertions(+), 31 deletions(-) diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 402ffb73f98..41370f3f7dd 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial -from itertools import product from typing import Any, Dict import numpy as np @@ -191,18 +190,16 @@ def test_warning_on_nan(tmpdir): confusion_matrix(preds, target, num_classes=5, normalize="true") -kwarg_options = [ - {"num_classes": 1, "normalize": "true"}, - {"num_classes": 1, "normalize": "pred"}, - {"num_classes": 1, "normalize": "all"}, - {"num_classes": 1, "normalize": "none"}, - {"num_classes": 1, "normalize": None}, -] -subclasses = [JaccardIndex] -params = list(product(subclasses, kwarg_options)) - - -@pytest.mark.parametrize("metric_cls, kwargs", params) -def test_provide_superclass_kwargs(metric_cls: ConfusionMatrix, kwargs: Dict[str, Any]): +@pytest.mark.parametrize( + "metric_args", + [ + {"num_classes": 1, "normalize": "true"}, + {"num_classes": 1, "normalize": "pred"}, + {"num_classes": 1, "normalize": "all"}, + {"num_classes": 1, "normalize": "none"}, + {"num_classes": 1, "normalize": None}, + ], +) +def test_provide_superclass_kwargs(metric_args: Dict[str, Any]): """Test instantiating subclasses with superclass arguments as kwargs.""" - metric_cls(**kwargs) + JaccardIndex(**metric_args) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 73b451876cd..aa49a76b180 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from itertools import product from typing import Any, Callable, Dict, Optional import numpy as np @@ -329,19 +328,18 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) -kwarg_options = [ - {"reduce": "micro"}, - {"num_classes": 1, "reduce": "macro"}, - {"reduce": "samples"}, - {"mdmc_reduce": None}, - {"mdmc_reduce": "samplewise"}, - {"mdmc_reduce": "global"}, -] -subclasses = [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity] -params = list(product(subclasses, kwarg_options)) - - -@pytest.mark.parametrize("metric_cls, kwargs", params) -def test_provide_superclass_kwargs(metric_cls: StatScores, kwargs: Dict[str, Any]): +@pytest.mark.parametrize( + "metric_args", + [ + {"reduce": "micro"}, + {"num_classes": 1, "reduce": "macro"}, + {"reduce": "samples"}, + {"mdmc_reduce": None}, + {"mdmc_reduce": "samplewise"}, + {"mdmc_reduce": "global"}, + ], +) +@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity]) +def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]): """Test instantiating subclasses with superclass arguments as kwargs.""" - metric_cls(**kwargs) + metric_cls(**metric_args) From d45537fc8ece85e52556425d58c632d58efea5f6 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 11:07:15 +0200 Subject: [PATCH 07/10] chlog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index ee1bc72d827..78e26e65b74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -41,7 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - -- + +- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069)) ## [0.9.0] - 2022-05-30 From ba0b25f3ebb481b7cff0c1f3c04946243e53f662 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 21:23:30 +0200 Subject: [PATCH 08/10] simple --- torchmetrics/classification/accuracy.py | 7 ++++--- torchmetrics/classification/dice.py | 7 ++++--- torchmetrics/classification/f_beta.py | 7 ++++--- torchmetrics/classification/jaccard.py | 3 +-- torchmetrics/classification/precision_recall.py | 14 ++++++++------ torchmetrics/classification/specificity.py | 7 ++++--- 6 files changed, 25 insertions(+), 20 deletions(-) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 15d9f3c1646..1d238ebf6a7 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -176,9 +176,10 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if "reduce" not in kwargs or kwargs["reduce"] is None: - kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average - if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average super().__init__( diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index 85082182dd7..3d5aa827afd 100644 --- a/torchmetrics/classification/dice.py +++ 
b/torchmetrics/classification/dice.py @@ -135,9 +135,10 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if "reduce" not in kwargs or kwargs["reduce"] is None: - kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average - if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average super().__init__( diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 6577badad65..90ab0ad1f59 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -137,9 +137,10 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if "reduce" not in kwargs or kwargs["reduce"] is None: - kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average - if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average super().__init__( diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 801d1332be9..c43dc40f21a 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -90,8 +90,7 @@ def __init__( multilabel: bool = False, **kwargs: Dict[str, Any], ) -> None: - if "normalize" not in kwargs: - kwargs["normalize"] = None + kwargs["normalize"] = kwargs.get("normalize") super().__init__( num_classes=num_classes, diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 8ddf7d01c0f..698f12b6a84 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -128,9 +128,10 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if "reduce" not in kwargs or kwargs["reduce"] is None: - kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average - if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average super().__init__( @@ -266,9 +267,10 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if "reduce" not in kwargs or kwargs["reduce"] is None: - kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average - if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if 
"mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average super().__init__( diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 37af6ffcd80..d82bf6d4cc6 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -130,9 +130,10 @@ def __init__( if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") - if "reduce" not in kwargs or kwargs["reduce"] is None: - kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average - if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None: + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: kwargs["mdmc_reduce"] = mdmc_average super().__init__( From f77d181064724fdec2c96df443105eda06a81e94 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 21:53:54 +0200 Subject: [PATCH 09/10] try Any --- torchmetrics/classification/accuracy.py | 2 +- torchmetrics/classification/dice.py | 2 +- torchmetrics/classification/f_beta.py | 2 +- torchmetrics/classification/jaccard.py | 2 +- torchmetrics/classification/precision_recall.py | 4 ++-- torchmetrics/classification/specificity.py | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 1d238ebf6a7..5fb8e12b543 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -170,7 +170,7 @@ def __init__( top_k: Optional[int] = None, multiclass: Optional[bool] = None, subset_accuracy: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index 3d5aa827afd..e444b94002f 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -129,7 +129,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 90ab0ad1f59..01de148e8cf 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -130,7 +130,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: self.beta = beta allowed_average = list(AverageMethod) diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index c43dc40f21a..0328b614e66 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -88,7 +88,7 @@ def __init__( absent_score: float = 0.0, threshold: float = 0.5, multilabel: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: kwargs["normalize"] = kwargs.get("normalize") diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 698f12b6a84..a932faf4653 100644 --- 
a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -122,7 +122,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: @@ -261,7 +261,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index d82bf6d4cc6..ac8c539920c 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -124,7 +124,7 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: From 04381ade6c2e03cc9041649c818c5bc99e8ba318 Mon Sep 17 00:00:00 2001 From: Jirka Date: Tue, 7 Jun 2022 21:58:34 +0200 Subject: [PATCH 10/10] flake8 --- torchmetrics/classification/accuracy.py | 2 +- torchmetrics/classification/dice.py | 2 +- torchmetrics/classification/jaccard.py | 2 +- torchmetrics/classification/precision_recall.py | 2 +- torchmetrics/classification/specificity.py | 2 +- 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 5fb8e12b543..472643d3991 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index e444b94002f..bc19c44dbff 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 0328b614e66..d088f6e0702 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import torch
 from torch import Tensor
diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py
index a932faf4653..c0c93e68a88 100644
--- a/torchmetrics/classification/precision_recall.py
+++ b/torchmetrics/classification/precision_recall.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 from torch import Tensor
 
diff --git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py
index ac8c539920c..b1bfeb8badb 100644
--- a/torchmetrics/classification/specificity.py
+++ b/torchmetrics/classification/specificity.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Any, Dict, Optional
+from typing import Any, Optional
 
 import torch
 from torch import Tensor