Refactor/remove double forward #984

Merged
merged 50 commits into master from refactor/remove_double_forward on May 10, 2022
Changes from 34 commits
Commits
50 commits
1a89f1d
something is working
SkafteNicki Apr 20, 2022
d18b411
better
SkafteNicki Apr 21, 2022
ce221a3
fix tests
SkafteNicki Apr 25, 2022
335021b
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 25, 2022
969da72
docstring
SkafteNicki Apr 25, 2022
d72dfa4
update docs
SkafteNicki Apr 25, 2022
77729b7
fix
SkafteNicki Apr 25, 2022
25fd83f
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 26, 2022
f46cfa0
fix issue
SkafteNicki Apr 26, 2022
30c356a
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 26, 2022
b411548
fix some tests
SkafteNicki Apr 26, 2022
285cba4
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 27, 2022
4f9d75b
introduce class property
SkafteNicki Apr 28, 2022
e2cca16
docs
SkafteNicki Apr 28, 2022
279a83c
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 28, 2022
967bf30
changelog
SkafteNicki Apr 28, 2022
ddc237b
fix docs
SkafteNicki Apr 28, 2022
ff011b8
update docs
SkafteNicki Apr 28, 2022
2abc26e
rename and re-order
SkafteNicki Apr 28, 2022
398521a
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
99ff43a
fix list
SkafteNicki Apr 28, 2022
3b0c8cf
Merge branch 'refactor/remove_double_forward' of https://github.com/P…
SkafteNicki Apr 28, 2022
52a4615
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
0ba7c37
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki Apr 28, 2022
ff75e70
change impl
SkafteNicki May 4, 2022
274078f
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 4, 2022
2eff3b6
fix tests
SkafteNicki May 4, 2022
87fa28f
regression
SkafteNicki May 4, 2022
fe58dc3
fix typing
SkafteNicki May 4, 2022
0670117
fix lightning integration
SkafteNicki May 4, 2022
08299ab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 4, 2022
416b057
remove from test metric
SkafteNicki May 4, 2022
90c2254
update count
SkafteNicki May 5, 2022
7058907
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 5, 2022
2790547
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
15a4616
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
a16b27b
Merge branch 'master' into refactor/remove_double_forward
Borda May 5, 2022
a310496
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 5, 2022
5d4977f
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 6, 2022
708c848
Merge branch 'master' into refactor/remove_double_forward
SkafteNicki May 6, 2022
c086f2c
audio
SkafteNicki May 6, 2022
5e8ebf1
fix mean reduction
SkafteNicki May 7, 2022
2ef10a6
refactor forward
SkafteNicki May 7, 2022
00d7fe8
classification
SkafteNicki May 7, 2022
aebff09
detection
SkafteNicki May 7, 2022
fcc5461
image
SkafteNicki May 7, 2022
3c5cacf
retrieval
SkafteNicki May 7, 2022
050cf86
text
SkafteNicki May 7, 2022
b8a7e17
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 7, 2022
fc1e2c4
Merge branch 'master' into refactor/remove_double_forward
mergify[bot] May 8, 2022
7 changes: 1 addition & 6 deletions CHANGELOG.md
@@ -15,18 +15,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Added `RetrievalRecallAtFixedPrecision` to retrieval package ([#951](https://github.com/PyTorchLightning/metrics/pull/951))


-

- Added class property `full_state_update` that determines whether `forward` should call `update` once or twice ([#984](https://github.com/PyTorchLightning/metrics/pull/984))

### Changed

-


-


### Deprecated

-
61 changes: 49 additions & 12 deletions docs/source/pages/implement.rst
@@ -1,5 +1,9 @@
.. _implement:

.. testsetup:: *

from typing import Optional

*********************
Implementing a Metric
*********************
@@ -40,6 +44,27 @@ Example implementation:
return self.correct.float() / self.total


Additionally, you may want to set the class properties `is_differentiable`, `higher_is_better` and
`full_state_update`. Note that none of them are strictly required for the metric to work.

.. testcode::

from torchmetrics import Metric

class MyMetric(Metric):
# Set to True if the metric is differentiable, otherwise set to False
is_differentiable: Optional[bool] = None

# Set to True if the metric reaches its optimal value when it is maximized.
# Set to False if it reaches its optimal value when minimized.
higher_is_better: Optional[bool] = True

# Set to True if the metric requires access to the global metric state during
# 'update' for its calculations. Setting this to False indicates that all
# batch states are independent, allowing the runtime of 'forward' to be optimized.
full_state_update: bool = True


Internal implementation details
-------------------------------

@@ -64,18 +89,30 @@ The cache is first emptied on the next call to ``update``.

``forward`` serves the dual purpose of both returning the metric on the current data and updating the internal
metric state for accumulating over multiple batches. The ``forward()`` method achieves this by combining calls
to ``update`` and ``compute`` in the following way:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
3. Calls ``reset()`` to clear global metric state.
4. Calls ``update()`` to update local metric state.
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.

This procedure has the consequence of calling the user-defined ``update`` **twice** during a single
forward call (once to update the global statistics and once to get the batch statistics).

to ``update``, ``compute`` and ``reset``. Depending on the class property ``full_state_update``, ``forward``
can behave in two ways:

1. If ``full_state_update`` is ``True``, the metric requires access to the full metric state during ``update``, and we
   therefore need two calls to ``update`` to ensure that the metric is calculated correctly:

1. Calls ``update()`` to update the global metric state (for accumulation over multiple batches)
2. Caches the global state.
3. Calls ``reset()`` to clear global metric state.
4. Calls ``update()`` to update local metric state.
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.

2. If ``full_state_update`` is ``False`` (default), the metric state of one batch is completely independent of the
   state of other batches, which means that we only need to call ``update`` once:

1. Caches the global state.
2. Calls ``reset()`` to reset the metric to its default state.
3. Calls ``update()`` to update the state with local batch statistics.
4. Calls ``compute()`` to calculate the metric for the current batch.
5. Reduces the global state and the batch state into a single state that becomes the new global state.

When implementing your own metric, we recommend trying out the metric with the ``full_state_update`` class property
set to both ``True`` and ``False``. If the results are equal, setting it to ``False`` will usually give the best performance.
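As a rough, purely illustrative sketch of such a comparison (the ``SumMetric`` and ``SumMetricFull`` classes below are
hypothetical and not part of this changeset), one could check both settings against the same batches:

.. code-block:: python

    import torch
    from torchmetrics import Metric


    class SumMetric(Metric):
        # Hypothetical metric: each batch contributes to the state independently of
        # what has already been accumulated, so one `update` call per `forward` suffices.
        full_state_update: bool = False

        def __init__(self):
            super().__init__()
            self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

        def update(self, x: torch.Tensor) -> None:
            self.total += x.sum()

        def compute(self) -> torch.Tensor:
            return self.total


    class SumMetricFull(SumMetric):
        # Same metric, but forcing the two-pass behaviour of `forward` for comparison.
        full_state_update: bool = True


    fast, full = SumMetric(), SumMetricFull()
    batches = [torch.randn(10) for _ in range(3)]
    for batch in batches:
        # `forward` returns the metric for the current batch while also accumulating globally
        assert torch.allclose(fast(batch), full(batch))
    # `compute` returns the metric accumulated over all batches seen so far
    assert torch.allclose(fast.compute(), full.compute())

If the per-batch values from ``forward`` and the accumulated value from ``compute`` agree for both settings,
``full_state_update = False`` is the faster choice.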

---------

1 change: 0 additions & 1 deletion tests/bases/test_composition.py
@@ -26,7 +26,6 @@ def __init__(self, val_to_return):
super().__init__()
self.add_state("_num_updates", tensor(0), dist_reduce_fx="sum")
self._val_to_return = val_to_return
self._update_called = True

def update(self, *args, **kwargs) -> None:
self._num_updates += 1
4 changes: 2 additions & 2 deletions tests/helpers/testers.py
@@ -567,7 +567,7 @@ class DummyMetric(Metric):

def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("x", tensor(0.0), dist_reduce_fx=None)
self.add_state("x", tensor(0.0), dist_reduce_fx="sum")

def update(self):
pass
@@ -581,7 +581,7 @@ class DummyListMetric(Metric):

def __init__(self):
super().__init__()
self.add_state("x", [], dist_reduce_fx=None)
self.add_state("x", [], dist_reduce_fx="cat")

def update(self):
pass
9 changes: 5 additions & 4 deletions torchmetrics/classification/roc.py
@@ -19,6 +19,7 @@
from torchmetrics.functional.classification.roc import _roc_compute, _roc_update
from torchmetrics.metric import Metric
from torchmetrics.utilities import rank_zero_warn
from torchmetrics.utilities.data import dim_zero_cat


class ROC(Metric):
@@ -114,8 +115,8 @@ def __init__(
self.num_classes = num_classes
self.pos_label = pos_label

self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")

rank_zero_warn(
"Metric `ROC` will save all targets and predictions in buffer."
@@ -148,8 +149,8 @@ def compute(self) -> Union[Tuple[Tensor, Tensor, Tensor], Tuple[List[Tensor], Li
thresholds:
thresholds used for computing false- and true-positive rates
"""
preds = torch.cat(self.preds, dim=0)
target = torch.cat(self.target, dim=0)
preds = dim_zero_cat(self.preds)
target = dim_zero_cat(self.target)
if not self.num_classes:
raise ValueError(f"`num_classes` bas to be positive number, but got {self.num_classes}")
return _roc_compute(preds, target, self.num_classes, self.pos_label)
2 changes: 1 addition & 1 deletion torchmetrics/classification/stat_scores.py
@@ -151,7 +151,7 @@ def __init__(
raise ValueError(f"The `ignore_index` {ignore_index} is not valid for inputs with {num_classes} classes")

default: Callable = lambda: []
reduce_fn: Optional[str] = None
reduce_fn: Optional[str] = "cat"
if mdmc_reduce != "samplewise" and reduce != "samples":
if reduce == "micro":
zeros_shape = []
4 changes: 2 additions & 2 deletions torchmetrics/image/psnr.py
@@ -83,8 +83,8 @@ def __init__(
self.add_state("sum_squared_error", default=tensor(0.0), dist_reduce_fx="sum")
self.add_state("total", default=tensor(0), dist_reduce_fx="sum")
else:
self.add_state("sum_squared_error", default=[])
self.add_state("total", default=[])
self.add_state("sum_squared_error", default=[], dist_reduce_fx="cat")
self.add_state("total", default=[], dist_reduce_fx="cat")

if data_range is None:
if dim is not None:
107 changes: 82 additions & 25 deletions torchmetrics/metric.py
@@ -73,6 +73,7 @@ class Metric(Module, ABC):
__jit_unused_properties__ = ["is_differentiable"]
is_differentiable: Optional[bool] = None
higher_is_better: Optional[bool] = None
full_state_update: bool = True

def __init__(
self,
@@ -112,7 +113,7 @@ def __init__(
self.compute: Callable = self._wrap_compute(self.compute) # type: ignore
self._computed = None
self._forward_cache = None
self._update_called = False
self._update_count = 0
self._to_sync = True
self._should_unsync = True
self._enable_grad = False
@@ -126,6 +127,11 @@
self._is_synced = False
self._cache: Optional[Dict[str, Union[List[Tensor], Tensor]]] = None

@property
def _update_called(self) -> bool:
# Needed for lightning integration
return self._update_count > 0

def add_state(
self,
name: str,
@@ -203,37 +209,59 @@ def forward(self, *args: Any, **kwargs: Any) -> Any:
Input arguments are the exact same as corresponding ``update`` method. The returned output is the exact same as
the output of ``compute``.
"""
# add current step
# check if states are already synced
if self._is_synced:
raise TorchMetricsUserError(
"The Metric shouldn't be synced when performing ``update``. "
"The Metric shouldn't be synced when performing ``forward``. "
"HINT: Did you forget to call ``unsync`` ?."
)

# global accumulation
self.update(*args, **kwargs)
if self.full_state_update or self.dist_sync_on_step:
# global accumulation
self.update(*args, **kwargs)
_update_count = self._update_count

self._to_sync = self.dist_sync_on_step # type: ignore
# skip restore cache operation from compute as cache is stored below.
self._should_unsync = False
# skip computing on cpu for the batch
_temp_compute_on_cpu = self.compute_on_cpu
self.compute_on_cpu = False

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
self.reset()
self.update(*args, **kwargs)
self._forward_cache = self.compute()

# restore context
for attr, val in cache.items():
setattr(self, attr, val)
self._update_count = _update_count
else:
# store global state and reset to default
global_state = {attr: getattr(self, attr) for attr in self._defaults.keys()}
self.reset()

self._to_sync = self.dist_sync_on_step # type: ignore
# skip restore cache operation from compute as cache is stored below.
self._should_unsync = False
# skip computing on cpu for the batch
_temp_compute_on_cpu = self.compute_on_cpu
self.compute_on_cpu = False
# local synchronization settings
self._to_sync = self.dist_sync_on_step
self._should_unsync = False
_temp_compute_on_cpu = self.compute_on_cpu
self.compute_on_cpu = False
self._enable_grad = True # allow grads for batch computation

# save context before switch
cache = {attr: getattr(self, attr) for attr in self._defaults}
# calculate batch state and compute batch value
self.update(*args, **kwargs)
self._forward_cache = self.compute()

# call reset, update, compute, on single batch
self._enable_grad = True # allow grads for batch computation
self.reset()
self.update(*args, **kwargs)
self._forward_cache = self.compute()
# reduce batch and global state
self._reduce_states(global_state)

# restore context
for attr, val in cache.items():
setattr(self, attr, val)
self._is_synced = False

self._should_unsync = True
self._to_sync = True
self._computed = None
@@ -242,6 +270,35 @@

return self._forward_cache

def _reduce_states(self, incoming_state: Dict[str, Any]) -> None:
"""Adds an incoming metric state to the current state of the metric.

Args:
incoming_state: a dict containing a metric state similar to the metric itself
"""
for attr in self._defaults.keys():
local_state = getattr(self, attr)
global_state = incoming_state[attr]
reduce_fn = self._reductions[attr]
if reduce_fn == dim_zero_sum:
reduced = global_state + local_state
elif reduce_fn == dim_zero_mean:
reduced = (self._update_count * global_state + local_state) / (self._update_count + 1)
elif reduce_fn == dim_zero_max:
reduced = torch.max(global_state, local_state)
elif reduce_fn == dim_zero_min:
reduced = torch.min(global_state, local_state)
elif reduce_fn == dim_zero_cat:
reduced = global_state + local_state
elif reduce_fn is None and isinstance(global_state, Tensor):
reduced = torch.stack([global_state, local_state])
elif reduce_fn is None and isinstance(global_state, list):
reduced = _flatten([global_state, local_state])
else:
reduced = reduce_fn(torch.stack([global_state, local_state])) # type: ignore

setattr(self, attr, reduced)

def _sync_dist(self, dist_sync_fn: Callable = gather_all_tensors, process_group: Optional[Any] = None) -> None:
input_dict = {attr: getattr(self, attr) for attr in self._reductions}

@@ -273,7 +330,7 @@ def _wrap_update(self, update: Callable) -> Callable:
@functools.wraps(update)
def wrapped_func(*args: Any, **kwargs: Any) -> None:
self._computed = None
self._update_called = True
self._update_count += 1
with torch.set_grad_enabled(self._enable_grad):
update(*args, **kwargs)
if self.compute_on_cpu:
@@ -383,7 +440,7 @@ def sync_context(
def _wrap_compute(self, compute: Callable) -> Callable:
@functools.wraps(compute)
def wrapped_func(*args: Any, **kwargs: Any) -> Any:
if not self._update_called:
if self._update_count == 0:
rank_zero_warn(
f"The ``compute`` method of metric {self.__class__.__name__}"
" was called before the ``update`` method which may lead to errors,"
@@ -421,7 +478,7 @@ def compute(self) -> Any:

def reset(self) -> None:
"""This method automatically resets the metric state variables to their default value."""
self._update_called = False
self._update_count = 0
self._forward_cache = None
self._computed = None

@@ -452,7 +509,7 @@ def __setstate__(self, state: Dict[str, Any]) -> None:
self.compute: Callable = self._wrap_compute(self.compute) # type: ignore

def __setattr__(self, name: str, value: Any) -> None:
if name in ("higher_is_better", "is_differentiable"):
if name in ("higher_is_better", "is_differentiable", "full_state_update"):
raise RuntimeError(f"Can't change const `{name}`.")
super().__setattr__(name, value)

5 changes: 3 additions & 2 deletions torchmetrics/regression/cosine_similarity.py
@@ -51,8 +51,9 @@ class CosineSimilarity(Metric):
tensor(0.8536)

"""
is_differentiable = True
higher_is_better = True
is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = False
preds: List[Tensor]
target: List[Tensor]

5 changes: 3 additions & 2 deletions torchmetrics/regression/explained_variance.py
@@ -68,8 +68,9 @@ class ExplainedVariance(Metric):
tensor([0.9677, 1.0000])

"""
is_differentiable = True
higher_is_better = True
is_differentiable: bool = True
higher_is_better: bool = True
full_state_update: bool = False
n_obs: Tensor
sum_error: Tensor
sum_squared_error: Tensor
5 changes: 3 additions & 2 deletions torchmetrics/regression/log_mse.py
@@ -42,8 +42,9 @@ class MeanSquaredLogError(Metric):
Half precision is only support on GPU for this metric

"""
is_differentiable = True
higher_is_better = False
is_differentiable: bool = True
higher_is_better: bool = False
full_state_update: bool = False
sum_squared_log_error: Tensor
total: Tensor
