diff --git a/CHANGELOG.md b/CHANGELOG.md index 418944d93d0..6e1d302ab52 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -45,7 +45,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed aggregation metrics when input only contains zero ([#1070](https://github.com/PyTorchLightning/metrics/pull/1070)) -- + +- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069)) ## [0.9.0] - 2022-05-30 diff --git a/tests/classification/test_confusion_matrix.py b/tests/classification/test_confusion_matrix.py index 08ce3e3fb4d..41370f3f7dd 100644 --- a/tests/classification/test_confusion_matrix.py +++ b/tests/classification/test_confusion_matrix.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from functools import partial +from typing import Any, Dict import numpy as np import pytest @@ -30,6 +31,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester +from torchmetrics import JaccardIndex from torchmetrics.classification.confusion_matrix import ConfusionMatrix from torchmetrics.functional.classification.confusion_matrix import confusion_matrix @@ -186,3 +188,18 @@ def test_warning_on_nan(tmpdir): match=".* nan values found in confusion matrix have been replaced with zeros.", ): confusion_matrix(preds, target, num_classes=5, normalize="true") + + +@pytest.mark.parametrize( + "metric_args", + [ + {"num_classes": 1, "normalize": "true"}, + {"num_classes": 1, "normalize": "pred"}, + {"num_classes": 1, "normalize": "all"}, + {"num_classes": 1, "normalize": "none"}, + {"num_classes": 1, "normalize": None}, + ], +) +def test_provide_superclass_kwargs(metric_args: Dict[str, Any]): + """Test instantiating subclasses with superclass arguments as kwargs.""" + JaccardIndex(**metric_args) diff --git a/tests/classification/test_stat_scores.py b/tests/classification/test_stat_scores.py index 6f19e1c148b..aa49a76b180 100644 --- a/tests/classification/test_stat_scores.py +++ b/tests/classification/test_stat_scores.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from functools import partial -from typing import Callable, Optional +from typing import Any, Callable, Dict, Optional import numpy as np import pytest @@ -30,7 +30,7 @@ from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob from tests.helpers import seed_all from tests.helpers.testers import NUM_CLASSES, MetricTester -from torchmetrics import StatScores +from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores from torchmetrics.functional import stat_scores from torchmetrics.utilities.checks import _input_format_classification @@ -326,3 +326,20 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten assert torch.equal(class_metric.compute(), expected.T) assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T) + + +@pytest.mark.parametrize( + "metric_args", + [ + {"reduce": "micro"}, + {"num_classes": 1, "reduce": "macro"}, + {"reduce": "samples"}, + {"mdmc_reduce": None}, + {"mdmc_reduce": "samplewise"}, + {"mdmc_reduce": "global"}, + ], +) +@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity]) +def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]): + """Test instantiating subclasses with superclass arguments as kwargs.""" + metric_cls(**metric_args) diff --git a/tests/text/test_mer.py b/tests/text/test_mer.py index 8153a915317..eaa79618e32 100644 --- a/tests/text/test_mer.py +++ b/tests/text/test_mer.py @@ -4,6 +4,8 @@ from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 +from torchmetrics.functional.text.mer import match_error_rate +from torchmetrics.text.mer import MatchErrorRate from torchmetrics.utilities.imports import _JIWER_AVAILABLE if _JIWER_AVAILABLE: @@ -11,9 +13,6 @@ else: compute_measures: Callable -from torchmetrics.functional.text.mer import match_error_rate -from torchmetrics.text.mer import MatchErrorRate - def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): return compute_measures(target, preds)["mer"] diff --git a/tests/text/test_wer.py b/tests/text/test_wer.py index 00c5d2ddc64..c70cac7c72f 100644 --- a/tests/text/test_wer.py +++ b/tests/text/test_wer.py @@ -4,6 +4,8 @@ from tests.text.helpers import TextTester from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2 +from torchmetrics.functional.text.wer import word_error_rate +from torchmetrics.text.wer import WordErrorRate from torchmetrics.utilities.imports import _JIWER_AVAILABLE if _JIWER_AVAILABLE: @@ -11,9 +13,6 @@ else: compute_measures: Callable -from torchmetrics.functional.text.wer import word_error_rate -from torchmetrics.text.wer import WordErrorRate - def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]): return compute_measures(target, preds)["wer"] diff --git a/torchmetrics/classification/accuracy.py b/torchmetrics/classification/accuracy.py index 8f77235dc39..472643d3991 100644 --- a/torchmetrics/classification/accuracy.py +++ b/torchmetrics/classification/accuracy.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor, tensor @@ -23,7 +23,7 @@ _subset_accuracy_compute, _subset_accuracy_update, ) -from torchmetrics.utilities.enums import DataType +from torchmetrics.utilities.enums import AverageMethod, DataType from torchmetrics.classification.stat_scores import StatScores # isort:skip @@ -170,15 +170,19 @@ def __init__( top_k: Optional[int] = None, multiclass: Optional[bool] = None, subset_accuracy: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, threshold=threshold, top_k=top_k, num_classes=num_classes, diff --git a/torchmetrics/classification/dice.py b/torchmetrics/classification/dice.py index 09ff5be205b..bc19c44dbff 100644 --- a/torchmetrics/classification/dice.py +++ b/torchmetrics/classification/dice.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.dice import _dice_compute +from torchmetrics.utilities.enums import AverageMethod class Dice(StatScores): @@ -128,15 +129,19 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ("micro", "macro", "weighted", "samples", "none", None) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + super().__init__( - reduce="macro" if average in ("weighted", "none", None) else average, - mdmc_reduce=mdmc_average, threshold=threshold, top_k=top_k, num_classes=num_classes, diff --git a/torchmetrics/classification/f_beta.py b/torchmetrics/classification/f_beta.py index 754436aac04..01de148e8cf 100644 --- a/torchmetrics/classification/f_beta.py +++ b/torchmetrics/classification/f_beta.py @@ -130,16 +130,20 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: self.beta = beta allowed_average = list(AverageMethod) if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + 
kwargs["mdmc_reduce"] = mdmc_average + super().__init__( - reduce="macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE] else average, - mdmc_reduce=mdmc_average, threshold=threshold, top_k=top_k, num_classes=num_classes, diff --git a/torchmetrics/classification/jaccard.py b/torchmetrics/classification/jaccard.py index 58f2438f671..d088f6e0702 100644 --- a/torchmetrics/classification/jaccard.py +++ b/torchmetrics/classification/jaccard.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor @@ -88,11 +88,12 @@ def __init__( absent_score: float = 0.0, threshold: float = 0.5, multilabel: bool = False, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: + kwargs["normalize"] = kwargs.get("normalize") + super().__init__( num_classes=num_classes, - normalize=None, threshold=threshold, multilabel=multilabel, **kwargs, diff --git a/torchmetrics/classification/precision_recall.py b/torchmetrics/classification/precision_recall.py index 69d6cfe0ac1..c0c93e68a88 100644 --- a/torchmetrics/classification/precision_recall.py +++ b/torchmetrics/classification/precision_recall.py @@ -11,12 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional from torch import Tensor from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute +from torchmetrics.utilities.enums import AverageMethod class Precision(StatScores): @@ -121,15 +122,19 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, threshold=threshold, top_k=top_k, num_classes=num_classes, @@ -256,15 +261,19 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, threshold=threshold, top_k=top_k, num_classes=num_classes, diff 
--git a/torchmetrics/classification/specificity.py b/torchmetrics/classification/specificity.py index 9f8eabb89f6..b1bfeb8badb 100644 --- a/torchmetrics/classification/specificity.py +++ b/torchmetrics/classification/specificity.py @@ -11,13 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Optional +from typing import Any, Optional import torch from torch import Tensor from torchmetrics.classification.stat_scores import StatScores from torchmetrics.functional.classification.specificity import _specificity_compute +from torchmetrics.utilities.enums import AverageMethod class Specificity(StatScores): @@ -123,15 +124,19 @@ def __init__( ignore_index: Optional[int] = None, top_k: Optional[int] = None, multiclass: Optional[bool] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: allowed_average = ["micro", "macro", "weighted", "samples", "none", None] if average not in allowed_average: raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.") + _reduce_options = (AverageMethod.WEIGHTED, AverageMethod.NONE, None) + if "reduce" not in kwargs: + kwargs["reduce"] = AverageMethod.MACRO if average in _reduce_options else average + if "mdmc_reduce" not in kwargs: + kwargs["mdmc_reduce"] = mdmc_average + super().__init__( - reduce="macro" if average in ["weighted", "none", None] else average, - mdmc_reduce=mdmc_average, threshold=threshold, top_k=top_k, num_classes=num_classes, diff --git a/torchmetrics/functional/audio/pesq.py b/torchmetrics/functional/audio/pesq.py index 73fcf541d34..7456edf9763 100644 --- a/torchmetrics/functional/audio/pesq.py +++ b/torchmetrics/functional/audio/pesq.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. import numpy as np +import torch +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _PESQ_AVAILABLE if _PESQ_AVAILABLE: import pesq as pesq_backend else: pesq_backend = None -import torch -from torch import Tensor -from torchmetrics.utilities.checks import _check_same_shape __doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]} diff --git a/torchmetrics/functional/audio/stoi.py b/torchmetrics/functional/audio/stoi.py index edd1673afa9..fe9cb5a9586 100644 --- a/torchmetrics/functional/audio/stoi.py +++ b/torchmetrics/functional/audio/stoi.py @@ -13,7 +13,9 @@ # limitations under the License. import numpy as np import torch +from torch import Tensor +from torchmetrics.utilities.checks import _check_same_shape from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE if _PYSTOI_AVAILABLE: @@ -21,9 +23,6 @@ else: stoi_backend = None __doctest_skip__ = ["short_time_objective_intelligibility"] -from torch import Tensor - -from torchmetrics.utilities.checks import _check_same_shape def short_time_objective_intelligibility(
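
Note (illustrative, not part of the patch): the hunks above all apply the same pattern, so a minimal self-contained sketch may help review. The `Base`/`Sub` classes below are hypothetical stand-ins for `StatScores` and its subclasses (`Accuracy`, `Dice`, `FBetaScore`, `Precision`, `Recall`, `Specificity`); they show why the added `if "reduce" not in kwargs` guards avoid the `TypeError` referenced in the changelog entry for #1069.

```python
# Sketch of the kwargs-forwarding pattern introduced in the diff.
# Previously the subclass passed `reduce=...` / `mdmc_reduce=...` explicitly
# to super().__init__ while the same names could also arrive via **kwargs,
# which raised "TypeError: __init__() got multiple values for keyword
# argument 'reduce'". The fix fills each kwarg only when the caller has not
# already supplied it.
from typing import Any, Optional


class Base:  # hypothetical stand-in for StatScores
    def __init__(self, reduce: str = "micro", mdmc_reduce: Optional[str] = None, **kwargs: Any) -> None:
        self.reduce = reduce
        self.mdmc_reduce = mdmc_reduce


class Sub(Base):  # hypothetical stand-in for a StatScores subclass
    def __init__(self, average: str = "micro", mdmc_average: Optional[str] = None, **kwargs: Any) -> None:
        # Mirror of the guards added in the diff: derive the superclass
        # argument only when the user did not pass it explicitly.
        if "reduce" not in kwargs:
            kwargs["reduce"] = "macro" if average in ("weighted", "none", None) else average
        if "mdmc_reduce" not in kwargs:
            kwargs["mdmc_reduce"] = mdmc_average
        super().__init__(**kwargs)


# Both calls previously failed with a TypeError; now the explicit kwarg wins
# over the value derived from `average` / `mdmc_average`.
Sub(reduce="micro")
Sub(mdmc_reduce="global")
```

The `JaccardIndex` change follows the same idea: `normalize` is taken from `kwargs.get("normalize")` instead of being hard-coded to `None` in the call to the `ConfusionMatrix` superclass, as exercised by the new `test_provide_superclass_kwargs` tests.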