diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e1d302ab52..e94fce5c328 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069)) +- Fixed bug related to state reference in metric collection when using compute groups ([#1076](https://github.com/PyTorchLightning/metrics/pull/1076)) + + ## [0.9.0] - 2022-05-30 ### Added diff --git a/integrations/test_lightning.py b/integrations/test_lightning.py index af24cd80c11..a5371117724 100644 --- a/integrations/test_lightning.py +++ b/integrations/test_lightning.py @@ -19,6 +19,7 @@ from torch.utils.data import DataLoader from integrations.lightning.boring_model import BoringModel, RandomDataset +from tests.helpers.utilities import no_warning_call from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric @@ -210,7 +211,11 @@ def training_epoch_end(self, outs): max_epochs=2, log_every_n_steps=1, ) - trainer.fit(model) + with no_warning_call( + UserWarning, + match="Torchmetrics v0.9 introduced a new argument class property called.*", + ): + trainer.fit(model) logged = trainer.logged_metrics assert torch.allclose(tensor(logged["sum_step"]), model.sum, atol=2e-4) @@ -249,7 +254,11 @@ def training_epoch_end(self, outputs): log_every_n_steps=1, weights_summary=None, ) - trainer.fit(model) + with no_warning_call( + UserWarning, + match="Torchmetrics v0.9 introduced a new argument class property called.*", + ): + trainer.fit(model) logged = trainer.logged_metrics assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum, atol=2e-4) diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 3b84ae24e8c..c3b5bb888f0 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -322,53 +322,87 @@ def compute(self): ), ], ) -@pytest.mark.parametrize( - "prefix, postfix", - [ - [None, None], - ["prefix_", None], - [None, "_postfix"], - ["prefix_", "_postfix"], - ], -) -def test_check_compute_groups(metrics, expected, prefix, postfix): - """Check that compute groups are formed after initialization.""" - m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True) - # Construct without for comparison - m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False) - - assert len(m.compute_groups) == len(m) - assert m2.compute_groups == {} - - for _ in range(2): # repeat to emulate effect of multiple epochs - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - m.update(preds, target) - m2.update(preds, target) - - for _, member in m.items(): - assert member._update_called +class TestComputeGroups: + @pytest.mark.parametrize( + "prefix, postfix", + [ + [None, None], + ["prefix_", None], + [None, "_postfix"], + ["prefix_", "_postfix"], + ], + ) + def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix): + """Check that compute groups are formed after initialization and that metrics are correctly computed.""" + m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True) + # Construct without for comparison + m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False) - assert m.compute_groups == expected + assert len(m.compute_groups) == len(m) assert m2.compute_groups == {} - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - # compute groups should kick in here - m.update(preds, target) - m2.update(preds, target) - - for _, member in m.items(): - assert member._update_called - - # compare results for correctness - res_cg = m.compute() - res_without_cg = m2.compute() - for key in res_cg.keys(): - assert torch.allclose(res_cg[key], res_without_cg[key]) - - m.reset() - m2.reset() + for _ in range(2): # repeat to emulate effect of multiple epochs + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + m.update(preds, target) + m2.update(preds, target) + + for _, member in m.items(): + assert member._update_called + + assert m.compute_groups == expected + assert m2.compute_groups == {} + + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + # compute groups should kick in here + m.update(preds, target) + m2.update(preds, target) + + for _, member in m.items(): + assert member._update_called + + # compare results for correctness + res_cg = m.compute() + res_without_cg = m2.compute() + for key in res_cg.keys(): + assert torch.allclose(res_cg[key], res_without_cg[key]) + + m.reset() + m2.reset() + + @pytest.mark.parametrize("method", ["items", "values", "keys"]) + def test_check_compute_groups_items_and_values(self, metrics, expected, method): + """Check that whenever user call a methods that give access to the indivitual metric that state are copied + instead of just passed by reference.""" + m = MetricCollection(deepcopy(metrics), compute_groups=True) + m2 = MetricCollection(deepcopy(metrics), compute_groups=False) + + for _ in range(2): # repeat to emulate effect of multiple epochs + for _ in range(2): # repeat to emulate effect of multiple batches + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + m.update(preds, target) + m2.update(preds, target) + + def _compare(m1, m2): + for state in m1._defaults: + assert torch.allclose(getattr(m1, state), getattr(m2, state)) + # if states are still by reference the reset will make following metrics fail + m1.reset() + m2.reset() + + if method == "items": + for (name_cg, metric_cg), (name_no_cg, metric_no_cg) in zip(m.items(), m2.items()): + assert name_cg == name_no_cg + _compare(metric_cg, metric_no_cg) + if method == "values": + for metric_cg, metric_no_cg in zip(m.values(), m2.values()): + _compare(metric_cg, metric_no_cg) + if method == "keys": + for key in m.keys(): + metric_cg, metric_no_cg = m[key], m2[key] + _compare(metric_cg, metric_no_cg) @pytest.mark.parametrize( diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index aabbfc0a007..88002ec2fa3 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -52,11 +52,11 @@ class name as key for the output dict. metric state and are therefore only different in their compute step e.g. accuracy, precision and recall can all be computed from the true positives/negatives and false positives/negatives. By default, this argument is ``True`` which enables this feature. Set this argument to `False` for disabling - this behaviour. Can also be set to a list of list of metrics for setting the compute groups yourself. + this behaviour. Can also be set to a list of lists of metrics for setting the compute groups yourself. .. note:: Metric collections can be nested at initilization (see last example) but the output of the collection will - still be a single flattened dictionary combining the prefix and postfix arguments from the nested collection. + still be a single flatten dictionary combining the prefix and postfix arguments from the nested collection. Raises: ValueError: @@ -143,6 +143,7 @@ def __init__( self.postfix = self._check_arg(postfix, "postfix") self._enable_compute_groups = compute_groups self._groups_checked: bool = False + self._state_is_copy: bool = False self.add_metrics(metrics, *additional_metrics) @@ -153,7 +154,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs) will be filtered based on the signature of the individual metric. """ - res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True)} + res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)} res = _flatten_dict(res) return {self._set_name(k): v for k, v in res.items()} @@ -172,13 +173,19 @@ def update(self, *args: Any, **kwargs: Any) -> None: for i in range(1, len(cg)): # copy over the update count mi = getattr(self, cg[i]) mi._update_count = m0._update_count + if self._state_is_copy: + # If we have deep copied state inbetween updates, reestablish link + self._compute_groups_create_state_ref() + self._state_is_copy = False else: # the first update always do per metric to form compute groups - for _, m in self.items(keep_base=True): + for _, m in self.items(keep_base=True, copy_state=False): m_kwargs = m._filter_kwargs(**kwargs) m.update(*args, **m_kwargs) if self._enable_compute_groups: self._merge_compute_groups() + # create reference between states + self._compute_groups_create_state_ref() self._groups_checked = True def _merge_compute_groups(self) -> None: @@ -241,24 +248,37 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: return True - def compute(self) -> Dict[str, Any]: - """Compute the result for each metric in the collection.""" - if self._enable_compute_groups and self._groups_checked: + def _compute_groups_create_state_ref(self, copy: bool = False) -> None: + """Create reference between metrics in the same compute group. + + Args: + copy: If `True` the metric state will between members will be copied instead + of just passed by reference + """ + if not self._state_is_copy: for _, cg in self._groups.items(): m0 = getattr(self, cg[0]) - # copy the state to the remaining metrics in the compute group for i in range(1, len(cg)): mi = getattr(self, cg[i]) for state in m0._defaults: - setattr(mi, state, getattr(m0, state)) - res = {k: m.compute() for k, m in self.items(keep_base=True)} + m0_state = getattr(m0, state) + # Determine if we just should set a reference or a full copy + setattr(mi, state, deepcopy(m0_state) if copy else m0_state) + self._state_is_copy = copy + + def compute(self) -> Dict[str, Any]: + """Compute the result for each metric in the collection.""" + res = {k: m.compute() for k, m in self.items(keep_base=True, copy_state=False)} res = _flatten_dict(res) return {self._set_name(k): v for k, v in res.items()} def reset(self) -> None: """Iteratively call reset for each metric.""" - for _, m in self.items(keep_base=True): + for _, m in self.items(keep_base=True, copy_state=False): m.reset() + if self._enable_compute_groups and self._groups_checked: + # reset state reference + self._compute_groups_create_state_ref() def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": """Make a copy of the metric collection @@ -276,7 +296,7 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> def persistent(self, mode: bool = True) -> None: """Method for post-init to change if metric states should be saved to its state_dict.""" - for _, m in self.items(keep_base=True): + for _, m in self.items(keep_base=True, copy_state=False): m.persistent(mode) def add_metrics( @@ -388,15 +408,40 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]: return self._modules.keys() return self._to_renamed_ordered_dict().keys() - def items(self, keep_base: bool = False) -> Iterable[Tuple[str, Module]]: + def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Module]]: r"""Return an iterable of the ModuleDict key/value pairs. + Args: - keep_base: Whether to add prefix/postfix on the items collection. + keep_base: Whether to add prefix/postfix on the collection. + copy_state: + If metric states should be copied between metrics in the same compute group or just passed by reference """ + self._compute_groups_create_state_ref(copy_state) if keep_base: return self._modules.items() return self._to_renamed_ordered_dict().items() + def values(self, copy_state: bool = True) -> Iterable[Module]: + """Return an iterable of the ModuleDict values. + + Args: + copy_state: + If metric states should be copied between metrics in the same compute group or just passed by reference + """ + self._compute_groups_create_state_ref(copy_state) + return self._modules.values() + + def __getitem__(self, key: str, copy_state: bool = True) -> Module: + """Retrieve a single metric from the collection. + + Args: + key: name of metric to retrieve + copy_state: + If metric states should be copied between metrics in the same compute group or just passed by reference + """ + self._compute_groups_create_state_ref(copy_state) + return self._modules[key] + @staticmethod def _check_arg(arg: Optional[str], name: str) -> Optional[str]: if arg is None or isinstance(arg, str): diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index edecf0fa72e..2fa4eacb218 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -303,7 +303,8 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: # reduce batch and global state self._update_count = _update_count + 1 - self._reduce_states(global_state) + with torch.no_grad(): + self._reduce_states(global_state) # restore context self._is_synced = False