Implement Explained Variance Metric + metric fix (#4013)
* metric fix, explained variance

* one more test

* pep8

* remove comment

* fix add_state condition

Co-authored-by: ananyahjha93 <ananya@pytorchlightning.ai>
teddykoker and ananyahjha93 committed Oct 9, 2020
1 parent 7db26a9 commit b961e12
Showing 6 changed files with 134 additions and 3 deletions.
7 changes: 7 additions & 0 deletions docs/source/metrics.rst
@@ -181,6 +181,13 @@ MeanSquaredLogError
    :noindex:


ExplainedVariance
^^^^^^^^^^^^^^^^^

.. autoclass:: pytorch_lightning.metrics.regression.ExplainedVariance
    :noindex:


Functional Metrics
==================

7 changes: 6 additions & 1 deletion pytorch_lightning/metrics/__init__.py
@@ -1,4 +1,9 @@
 from pytorch_lightning.metrics.metric import Metric

 from pytorch_lightning.metrics.classification.accuracy import Accuracy
-from pytorch_lightning.metrics.regression import MeanSquaredError, MeanAbsoluteError, MeanSquaredLogError
+from pytorch_lightning.metrics.regression import (
+    MeanSquaredError,
+    MeanAbsoluteError,
+    MeanSquaredLogError,
+    ExplainedVariance,
+)
8 changes: 6 additions & 2 deletions pytorch_lightning/metrics/metric.py
@@ -96,7 +96,11 @@ def add_state(self, name: str, default, dist_reduce_fx: Optional[Union[str, Call
         the format discussed in the above note.
         """
-        if not isinstance(default, torch.Tensor) or (isinstance(default, list) and len(default) != 0):
+        if (
+            not isinstance(default, torch.Tensor)
+            and not isinstance(default, list)  # noqa: W503
+            or (isinstance(default, list) and len(default) != 0)  # noqa: W503
+        ):
             raise ValueError(
                 "state variable must be a tensor or any empty list (where you can append tensors)"
             )
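
For context, the one-line check this replaces raised for every non-tensor default, so the empty-list state that list-based metrics such as ExplainedVariance rely on was rejected. The expanded condition accepts a tensor or an empty list and rejects everything else. A minimal sketch of the resulting behavior (the _Dummy subclass is illustrative, not part of the commit):

import torch

from pytorch_lightning.metrics.metric import Metric


class _Dummy(Metric):
    def update(self):
        pass

    def compute(self):
        pass


m = _Dummy()
m.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")  # tensor: accepted
m.add_state("preds", default=[], dist_reduce_fx=None)  # empty list: accepted (raised before this fix)
for bad_default in ([torch.tensor(1.0)], 0.0):  # non-empty list and plain float: rejected
    try:
        m.add_state("bad", default=bad_default)
    except ValueError as err:
        print(err)  # state variable must be a tensor or any empty list (where you can append tensors)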
@@ -163,7 +167,7 @@ def _sync_dist(self):
             elif isinstance(output_dict[attr][0], list):
                 output_dict[attr] = _flatten(output_dict[attr])

-            assert isinstance(reduction_fn, (Callable, None))
+            assert isinstance(reduction_fn, (Callable)) or reduction_fn is None
             reduced = reduction_fn(output_dict[attr]) if reduction_fn is not None else output_dict[attr]
             setattr(self, attr, reduced)
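
The replaced assertion was itself broken in exactly the case this commit introduces: None inside an isinstance class tuple is invalid (only type(None) would be), so the check raised TypeError whenever reduction_fn was None, which is what add_state(..., dist_reduce_fx=None) produces for list states like those of ExplainedVariance. A standalone illustration (not part of the diff):

from typing import Callable

reduction_fn = None  # what dist_reduce_fx=None yields for list states

try:
    assert isinstance(reduction_fn, (Callable, None))  # old check
except TypeError as err:
    print(err)  # isinstance() arg 2 must be a type or tuple of types ...

assert isinstance(reduction_fn, Callable) or reduction_fn is None  # fixed check passes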

1 change: 1 addition & 0 deletions pytorch_lightning/metrics/regression/__init__.py
@@ -1,3 +1,4 @@
from pytorch_lightning.metrics.regression.mean_squared_error import MeanSquaredError
from pytorch_lightning.metrics.regression.mean_absolute_error import MeanAbsoluteError
from pytorch_lightning.metrics.regression.mean_squared_log_error import MeanSquaredLogError
from pytorch_lightning.metrics.regression.explained_variance import ExplainedVariance
63 changes: 63 additions & 0 deletions pytorch_lightning/metrics/regression/explained_variance.py
@@ -0,0 +1,63 @@
import torch
from typing import Any, Callable, Optional, Union

from pytorch_lightning.metrics.metric import Metric


class ExplainedVariance(Metric):
"""
Computes explained variance.
Example:
>>> from pytorch_lightning.metrics import ExplainedVariance
>>> target = torch.tensor([3, -0.5, 2, 7])
>>> preds = torch.tensor([2.5, 0.0, 2, 8])
>>> explained_variance = ExplainedVariance()
>>> explained_variance(preds, target)
tensor(0.9572)
"""

def __init__(
self,
compute_on_step: bool = True,
ddp_sync_on_step: bool = False,
process_group: Optional[Any] = None,
):
super().__init__(
compute_on_step=compute_on_step,
ddp_sync_on_step=ddp_sync_on_step,
process_group=process_group,
)

self.add_state("y", default=[], dist_reduce_fx=None)
self.add_state("y_pred", default=[], dist_reduce_fx=None)

def update(self, preds: torch.Tensor, target: torch.Tensor):
"""
Update state with predictions and targets.
Args:
preds: Predictions from model
target: Ground truth values
"""
self.y.append(target)
self.y_pred.append(preds)

def compute(self):
"""
Computes explained variance over state.
"""
y_true = torch.cat(self.y, dim=0)
y_pred = torch.cat(self.y_pred, dim=0)

y_diff_avg = torch.mean(y_true - y_pred, dim=0)
numerator = torch.mean((y_true - y_pred - y_diff_avg) ** 2, dim=0)

y_true_avg = torch.mean(y_true, dim=0)
denominator = torch.mean((y_true - y_true_avg) ** 2, dim=0)

# TODO: multioutput
return 1.0 - torch.mean(numerator / denominator)
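
For reference, compute() implements the standard definition, explained variance = 1 - Var(y - y_pred) / Var(y), evaluated per output and then averaged uniformly (the same convention as sklearn's explained_variance_score with its default multioutput='uniform_average'). A minimal standalone check of the docstring example, using only torch (not part of the commit):

import torch

target = torch.tensor([3.0, -0.5, 2.0, 7.0])
preds = torch.tensor([2.5, 0.0, 2.0, 8.0])

# Biased variance of the residuals, matching the torch.mean reductions above
residual_var = torch.mean((target - preds - torch.mean(target - preds)) ** 2)
# Biased variance of the targets
target_var = torch.mean((target - torch.mean(target)) ** 2)

print(1.0 - residual_var / target_var)  # tensor(0.9572)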
51 changes: 51 additions & 0 deletions tests/metrics/regression/test_explained_variance.py
@@ -0,0 +1,51 @@
import torch
import pytest
from collections import namedtuple
from functools import partial

from pytorch_lightning.metrics.regression import ExplainedVariance
from sklearn.metrics import explained_variance_score

from tests.metrics.utils import compute_batch, setup_ddp
from tests.metrics.utils import NUM_BATCHES, NUM_PROCESSES, BATCH_SIZE

torch.manual_seed(42)

num_targets = 5

Input = namedtuple('Input', ["preds", "target"])

_single_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE),
)

_multi_target_inputs = Input(
    preds=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
    target=torch.rand(NUM_BATCHES, BATCH_SIZE, num_targets),
)


def _single_target_sk_metric(preds, target, sk_fn=explained_variance_score):
    sk_preds = preds.view(-1).numpy()
    sk_target = target.view(-1).numpy()
    return sk_fn(sk_target, sk_preds)


def _multi_target_sk_metric(preds, target, sk_fn=explained_variance_score):
    sk_preds = preds.view(-1, num_targets).numpy()
    sk_target = target.view(-1, num_targets).numpy()
    return sk_fn(sk_target, sk_preds)


@pytest.mark.parametrize("ddp", [True, False])
@pytest.mark.parametrize("ddp_sync_on_step", [True, False])
@pytest.mark.parametrize(
    "preds, target, sk_metric",
    [
        (_single_target_inputs.preds, _single_target_inputs.target, _single_target_sk_metric),
        (_multi_target_inputs.preds, _multi_target_inputs.target, _multi_target_sk_metric),
    ],
)
def test_explained_variance(ddp, ddp_sync_on_step, preds, target, sk_metric):
    compute_batch(preds, target, ExplainedVariance, sk_metric, ddp_sync_on_step, ddp)
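
Roughly, for the ddp=False case the parametrized test amounts to the following single-process check (compute_batch itself lives in tests/metrics/utils.py and also spawns NUM_PROCESSES workers for the ddp=True path; this sketch reuses the definitions above and is an assumption about the helper's behavior, not its actual code):

metric = ExplainedVariance()
for i in range(NUM_BATCHES):  # feed batches one at a time, as the Lightning loop would
    metric.update(_single_target_inputs.preds[i], _single_target_inputs.target[i])

expected = _single_target_sk_metric(_single_target_inputs.preds, _single_target_inputs.target)
assert torch.allclose(metric.compute(), torch.tensor(expected, dtype=torch.float32))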
