Fix for bug when providing superclass arguments as kwargs #1069

Merged
merged 16 commits into from Jun 7, 2022
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -41,7 +41,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

-

-

- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069))


## [0.9.0] - 2022-05-30
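For context, the bug being fixed: subclasses such as `Accuracy` always forwarded `reduce` and `mdmc_reduce` explicitly to `StatScores.__init__` (and `JaccardIndex` hard-coded `normalize=None` for `ConfusionMatrix`) while also passing `**kwargs` through, so supplying any of those arguments as a kwarg raised a duplicate-keyword `TypeError`. A minimal sketch of the calls this PR makes work again, mirroring the new tests below (the error text is the typical Python wording, not an exact quote):

```python
from torchmetrics import Accuracy, JaccardIndex

# Before this fix, the next call failed with roughly:
#   TypeError: __init__() got multiple values for keyword argument 'reduce'
# because `reduce` was passed both explicitly and via **kwargs.
acc = Accuracy(num_classes=1, reduce="macro")

# Same failure mode for JaccardIndex, which hard-coded normalize=None
# when calling ConfusionMatrix.__init__; the kwarg is now respected.
jac = JaccardIndex(num_classes=1, normalize="true")
```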
17 changes: 17 additions & 0 deletions tests/classification/test_confusion_matrix.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Any, Dict

import numpy as np
import pytest
@@ -30,6 +31,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, THRESHOLD, MetricTester
from torchmetrics import JaccardIndex
from torchmetrics.classification.confusion_matrix import ConfusionMatrix
from torchmetrics.functional.classification.confusion_matrix import confusion_matrix

@@ -186,3 +188,18 @@ def test_warning_on_nan(tmpdir):
match=".* nan values found in confusion matrix have been replaced with zeros.",
):
confusion_matrix(preds, target, num_classes=5, normalize="true")


@pytest.mark.parametrize(
"metric_args",
[
{"num_classes": 1, "normalize": "true"},
{"num_classes": 1, "normalize": "pred"},
{"num_classes": 1, "normalize": "all"},
{"num_classes": 1, "normalize": "none"},
{"num_classes": 1, "normalize": None},
],
)
def test_provide_superclass_kwargs(metric_args: Dict[str, Any]):
"""Test instantiating subclasses with superclass arguments as kwargs."""
JaccardIndex(**metric_args)
21 changes: 19 additions & 2 deletions tests/classification/test_stat_scores.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Callable, Optional
from typing import Any, Callable, Dict, Optional

import numpy as np
import pytest
@@ -30,7 +30,7 @@
from tests.classification.inputs import _input_multilabel_prob as _input_mlb_prob
from tests.helpers import seed_all
from tests.helpers.testers import NUM_CLASSES, MetricTester
from torchmetrics import StatScores
from torchmetrics import Accuracy, Dice, FBetaScore, Precision, Recall, Specificity, StatScores
from torchmetrics.functional import stat_scores
from torchmetrics.utilities.checks import _input_format_classification

@@ -326,3 +326,20 @@ def test_top_k(k: int, preds: Tensor, target: Tensor, reduce: str, expected: Ten

assert torch.equal(class_metric.compute(), expected.T)
assert torch.equal(stat_scores(preds, target, top_k=k, reduce=reduce, num_classes=3), expected.T)


@pytest.mark.parametrize(
"metric_args",
[
{"reduce": "micro"},
{"num_classes": 1, "reduce": "macro"},
{"reduce": "samples"},
{"mdmc_reduce": None},
{"mdmc_reduce": "samplewise"},
{"mdmc_reduce": "global"},
],
)
@pytest.mark.parametrize("metric_cls", [Accuracy, Dice, FBetaScore, Precision, Recall, Specificity])
def test_provide_superclass_kwargs(metric_cls: StatScores, metric_args: Dict[str, Any]):
"""Test instantiating subclasses with superclass arguments as kwargs."""
metric_cls(**metric_args)
5 changes: 2 additions & 3 deletions tests/text/test_mer.py
@@ -4,16 +4,15 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures: Callable

from torchmetrics.functional.text.mer import match_error_rate
from torchmetrics.text.mer import MatchErrorRate


def _compute_mer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
return compute_measures(target, preds)["mer"]
5 changes: 2 additions & 3 deletions tests/text/test_wer.py
@@ -4,16 +4,15 @@

from tests.text.helpers import TextTester
from tests.text.inputs import _inputs_error_rate_batch_size_1, _inputs_error_rate_batch_size_2
from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.text.wer import WordErrorRate
from torchmetrics.utilities.imports import _JIWER_AVAILABLE

if _JIWER_AVAILABLE:
from jiwer import compute_measures
else:
compute_measures: Callable

from torchmetrics.functional.text.wer import word_error_rate
from torchmetrics.text.wer import WordErrorRate


def _compute_wer_metric_jiwer(preds: Union[str, List[str]], target: Union[str, List[str]]):
return compute_measures(target, preds)["wer"]
9 changes: 6 additions & 3 deletions torchmetrics/classification/accuracy.py
@@ -23,7 +23,7 @@
_subset_accuracy_compute,
_subset_accuracy_update,
)
from torchmetrics.utilities.enums import DataType
from torchmetrics.utilities.enums import AverageMethod, DataType

from torchmetrics.classification.stat_scores import StatScores # isort:skip

@@ -176,9 +176,12 @@ def __init__(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if "reduce" not in kwargs or kwargs["reduce"] is None:
kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
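The same guard is repeated in the `Dice`, `FBetaScore`, `Precision`, `Recall`, and `Specificity` diffs below: only default a superclass keyword when the caller has not already supplied it, then forward everything through `**kwargs`. A generic, self-contained sketch of that pattern (class and argument names here are illustrative, not torchmetrics API):

```python
from typing import Any, Optional


class Base:
    def __init__(self, reduce: str = "micro", **kwargs: Any) -> None:
        self.reduce = reduce


class Child(Base):
    def __init__(self, average: Optional[str] = "micro", **kwargs: Any) -> None:
        # Fill in a default only when the caller did not provide the superclass
        # argument; passing it explicitly *and* via **kwargs would raise
        # "got multiple values for keyword argument 'reduce'".
        if "reduce" not in kwargs or kwargs["reduce"] is None:
            kwargs["reduce"] = "macro" if average in ("weighted", "none", None) else average
        super().__init__(**kwargs)


Child(average="weighted")  # reduce defaults to "macro"
Child(reduce="micro")      # caller-provided kwarg is forwarded untouched
```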
8 changes: 6 additions & 2 deletions torchmetrics/classification/dice.py
@@ -17,6 +17,7 @@

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.dice import _dice_compute
from torchmetrics.utilities.enums import AverageMethod


class Dice(StatScores):
@@ -134,9 +135,12 @@ def __init__(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if "reduce" not in kwargs or kwargs["reduce"] is None:
kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ("weighted", "none", None) else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
7 changes: 5 additions & 2 deletions torchmetrics/classification/f_beta.py
@@ -137,9 +137,12 @@ def __init__(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if "reduce" not in kwargs or kwargs["reduce"] is None:
kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
4 changes: 3 additions & 1 deletion torchmetrics/classification/jaccard.py
@@ -90,9 +90,11 @@
multilabel: bool = False,
**kwargs: Dict[str, Any],
) -> None:
if "normalize" not in kwargs:
kwargs["normalize"] = None

super().__init__(
num_classes=num_classes,
normalize=None,
threshold=threshold,
multilabel=multilabel,
**kwargs,
15 changes: 11 additions & 4 deletions torchmetrics/classification/precision_recall.py
@@ -17,6 +17,7 @@

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.precision_recall import _precision_compute, _recall_compute
from torchmetrics.utilities.enums import AverageMethod


class Precision(StatScores):
@@ -127,9 +128,12 @@ def __init__(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if "reduce" not in kwargs or kwargs["reduce"] is None:
kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
@@ -262,9 +266,12 @@ def __init__(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if "reduce" not in kwargs or kwargs["reduce"] is None:
kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
8 changes: 6 additions & 2 deletions torchmetrics/classification/specificity.py
@@ -18,6 +18,7 @@

from torchmetrics.classification.stat_scores import StatScores
from torchmetrics.functional.classification.specificity import _specificity_compute
from torchmetrics.utilities.enums import AverageMethod


class Specificity(StatScores):
Expand Down Expand Up @@ -129,9 +130,12 @@ def __init__(
if average not in allowed_average:
raise ValueError(f"The `average` has to be one of {allowed_average}, got {average}.")

if "reduce" not in kwargs or kwargs["reduce"] is None:
kwargs["reduce"] = "macro" if average in [AverageMethod.WEIGHTED, AverageMethod.NONE, "none"] else average
if "mdmc_reduce" not in kwargs or kwargs["mdmc_reduce"] is None:
kwargs["mdmc_reduce"] = mdmc_average

super().__init__(
reduce="macro" if average in ["weighted", "none", None] else average,
mdmc_reduce=mdmc_average,
threshold=threshold,
top_k=top_k,
num_classes=num_classes,
6 changes: 3 additions & 3 deletions torchmetrics/functional/audio/pesq.py
@@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PESQ_AVAILABLE

if _PESQ_AVAILABLE:
import pesq as pesq_backend
else:
pesq_backend = None
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape

__doctest_requires__ = {("perceptual_evaluation_speech_quality",): ["pesq"]}

5 changes: 2 additions & 3 deletions torchmetrics/functional/audio/stoi.py
@@ -13,17 +13,16 @@
# limitations under the License.
import numpy as np
import torch
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape
from torchmetrics.utilities.imports import _PYSTOI_AVAILABLE

if _PYSTOI_AVAILABLE:
from pystoi import stoi as stoi_backend
else:
stoi_backend = None
__doctest_skip__ = ["short_time_objective_intelligibility"]
from torch import Tensor

from torchmetrics.utilities.checks import _check_same_shape


def short_time_objective_intelligibility(
Expand Down