From 141bafc6d574b817aabeac45d8687b81b94ab42a Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Tue, 7 Jun 2022 16:24:38 +0200 Subject: [PATCH 01/10] update --- torchmetrics/collections.py | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index aabbfc0a007..0566a1f63e6 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -216,6 +216,9 @@ def _merge_compute_groups(self) -> None: for idx, values in enumerate(temp.values()): self._groups[idx] = values + # create reference between states + self._compute_groups_create_state_ref() + @staticmethod def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: """Check if the metric state of two metrics are the same.""" @@ -241,16 +244,17 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool: return True + def _compute_groups_create_state_ref(self) -> None: + for _, cg in self._groups.items(): + m0 = getattr(self, cg[0]) + for i in range(1, len(cg)): + mi = getattr(self, cg[i]) + for state in m0._defaults: + m0_state = getattr(m0, state) + setattr(mi, state, m0_state) + def compute(self) -> Dict[str, Any]: """Compute the result for each metric in the collection.""" - if self._enable_compute_groups and self._groups_checked: - for _, cg in self._groups.items(): - m0 = getattr(self, cg[0]) - # copy the state to the remaining metrics in the compute group - for i in range(1, len(cg)): - mi = getattr(self, cg[i]) - for state in m0._defaults: - setattr(mi, state, getattr(m0, state)) res = {k: m.compute() for k, m in self.items(keep_base=True)} res = _flatten_dict(res) return {self._set_name(k): v for k, v in res.items()} @@ -259,6 +263,9 @@ def reset(self) -> None: """Iteratively call reset for each metric.""" for _, m in self.items(keep_base=True): m.reset() + if self._enable_compute_groups and self._groups_checked: + # reset state reference + self._compute_groups_create_state_ref() def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection": """Make a copy of the metric collection From 9c50b82589815489df509643ec280f3d6d931478 Mon Sep 17 00:00:00 2001 From: Skaftenicki Date: Wed, 8 Jun 2022 08:11:30 +0200 Subject: [PATCH 02/10] fix states --- tests/bases/test_collections.py | 124 ++++++++++++++++++++------------ torchmetrics/collections.py | 45 ++++++++---- 2 files changed, 110 insertions(+), 59 deletions(-) diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 3b84ae24e8c..4c5df4d1959 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -322,53 +322,87 @@ def compute(self): ), ], ) -@pytest.mark.parametrize( - "prefix, postfix", - [ - [None, None], - ["prefix_", None], - [None, "_postfix"], - ["prefix_", "_postfix"], - ], -) -def test_check_compute_groups(metrics, expected, prefix, postfix): - """Check that compute groups are formed after initialization.""" - m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True) - # Construct without for comparison - m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False) - - assert len(m.compute_groups) == len(m) - assert m2.compute_groups == {} - - for _ in range(2): # repeat to emulate effect of multiple epochs - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - m.update(preds, target) - m2.update(preds, target) - - for _, member in m.items(): - assert 
member._update_called - - assert m.compute_groups == expected +class TestComputeGroups: + @pytest.mark.parametrize( + "prefix, postfix", + [ + [None, None], + ["prefix_", None], + [None, "_postfix"], + ["prefix_", "_postfix"], + ], + ) + def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix): + """Check that compute groups are formed after initialization and + that metrics are correctly computed. + """ + m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True) + # Construct without for comparison + m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False) + + assert len(m.compute_groups) == len(m) assert m2.compute_groups == {} - preds = torch.randn(10, 3).softmax(dim=-1) - target = torch.randint(3, (10,)) - # compute groups should kick in here - m.update(preds, target) - m2.update(preds, target) - - for _, member in m.items(): - assert member._update_called - - # compare results for correctness - res_cg = m.compute() - res_without_cg = m2.compute() - for key in res_cg.keys(): - assert torch.allclose(res_cg[key], res_without_cg[key]) - - m.reset() - m2.reset() + for _ in range(2): # repeat to emulate effect of multiple epochs + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + m.update(preds, target) + m2.update(preds, target) + + for _, member in m.items(): + assert member._update_called + + assert m.compute_groups == expected + assert m2.compute_groups == {} + + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + # compute groups should kick in here + m.update(preds, target) + m2.update(preds, target) + + for _, member in m.items(): + assert member._update_called + + # compare results for correctness + res_cg = m.compute() + res_without_cg = m2.compute() + for key in res_cg.keys(): + assert torch.allclose(res_cg[key], res_without_cg[key]) + + m.reset() + m2.reset() + + @pytest.mark.parametrize("method", ['items', 'values', 'keys']) + def test_check_compute_groups_items_and_values(self, metrics, expected, method): + m = MetricCollection(deepcopy(metrics), compute_groups=True) + m2 = MetricCollection(deepcopy(metrics), compute_groups=False) + + for _ in range(2): # repeat to emulate effect of multiple epochs + for _ in range(2): # repeat to emulate effect of multiple batches + preds = torch.randn(10, 3).softmax(dim=-1) + target = torch.randint(3, (10,)) + m.update(preds, target) + m2.update(preds, target) + + def _compare(m1, m2): + for state in m1._defaults: + assert torch.allclose(getattr(m1, state), getattr(m2, state)) + # if states are still by reference the reset will make following metrics fail + m1.reset() + m2.reset() + + if method == 'items': + for (name_cg, metric_cg), (name_no_cg, metric_no_cg) in zip(m.items(), m2.items()): + assert name_cg == name_no_cg + _compare(metric_cg, metric_no_cg) + if method == 'values': + for metric_cg, metric_no_cg in zip(m.values(), m2.values()): + _compare(metric_cg, metric_no_cg) + if method == "keys": + for key in m.keys(): + metric_cg, metric_no_cg = m[key], m2[key] + _compare(metric_cg, metric_no_cg) @pytest.mark.parametrize( diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index 0566a1f63e6..82fea5fd1d4 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -153,7 +153,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]: Positional arguments (args) will be passed to every metric in the collection, while keyword arguments 
(kwargs) will be filtered based on the signature of the individual metric.
         """
-        res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True)}
+        res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)}
         res = _flatten_dict(res)
         return {self._set_name(k): v for k, v in res.items()}
 
@@ -173,7 +173,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
                 mi = getattr(self, cg[i])
                 mi._update_count = m0._update_count
         else:  # the first update is always done per metric to form compute groups
-            for _, m in self.items(keep_base=True):
+            for _, m in self.items(keep_base=True, copy_state=False):
                 m_kwargs = m._filter_kwargs(**kwargs)
                 m.update(*args, **m_kwargs)
 
@@ -244,24 +244,29 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
         return True
 
-    def _compute_groups_create_state_ref(self) -> None:
-        for _, cg in self._groups.items():
-            m0 = getattr(self, cg[0])
-            for i in range(1, len(cg)):
-                mi = getattr(self, cg[i])
-                for state in m0._defaults:
-                    m0_state = getattr(m0, state)
-                    setattr(mi, state, m0_state)
+    def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
+        """Create reference between metrics in the same compute group."""
+        if self._groups_checked:
+            for _, cg in self._groups.items():
+                m0 = getattr(self, cg[0])
+                for i in range(1, len(cg)):
+                    mi = getattr(self, cg[i])
+                    for state in m0._defaults:
+                        m0_state = getattr(m0, state)
+                        # Determine whether we should just set a reference or make a full copy
+                        setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
+        if copy:
+            self._groups_checked = False
 
     def compute(self) -> Dict[str, Any]:
         """Compute the result for each metric in the collection."""
-        res = {k: m.compute() for k, m in self.items(keep_base=True)}
+        res = {k: m.compute() for k, m in self.items(keep_base=True, copy_state=False)}
         res = _flatten_dict(res)
         return {self._set_name(k): v for k, v in res.items()}
 
     def reset(self) -> None:
         """Iteratively call reset for each metric."""
-        for _, m in self.items(keep_base=True):
+        for _, m in self.items(keep_base=True, copy_state=False):
             m.reset()
         if self._enable_compute_groups and self._groups_checked:
             # reset state reference
@@ -283,7 +288,7 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) ->
 
     def persistent(self, mode: bool = True) -> None:
         """Method for post-init to change if metric states should be saved to its state_dict."""
-        for _, m in self.items(keep_base=True):
+        for _, m in self.items(keep_base=True, copy_state=False):
             m.persistent(mode)
 
     def add_metrics(
@@ -395,14 +400,26 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
             return self._modules.keys()
         return self._to_renamed_ordered_dict().keys()
 
-    def items(self, keep_base: bool = False) -> Iterable[Tuple[str, Module]]:
+    def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Module]]:
         r"""Return an iterable of the ModuleDict key/value pairs.
         Args:
             keep_base: Whether to add prefix/postfix on the items collection.
""" + if copy_state: + self._compute_groups_create_state_ref(copy_state) if keep_base: return self._modules.items() return self._to_renamed_ordered_dict().items() + + def values(self, copy_state: bool = True) -> Iterable[Module]: + if copy_state: + self._compute_groups_create_state_ref(copy_state) + return self._modules.values() + + def __getitem__(self, key: str, copy_state: bool = True) -> Module: + if copy_state: + self._compute_groups_create_state_ref(copy_state) + return self._modules[key] @staticmethod def _check_arg(arg: Optional[str], name: str) -> Optional[str]: From 8f965479e2b354f5224e569d17e369bd508e00b4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Jun 2022 06:40:42 +0000 Subject: [PATCH 03/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/bases/test_collections.py | 16 +++++++--------- torchmetrics/collections.py | 2 +- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py index 4c5df4d1959..51a91cb8af8 100644 --- a/tests/bases/test_collections.py +++ b/tests/bases/test_collections.py @@ -333,9 +333,7 @@ class TestComputeGroups: ], ) def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix): - """Check that compute groups are formed after initialization and - that metrics are correctly computed. - """ + """Check that compute groups are formed after initialization and that metrics are correctly computed.""" m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True) # Construct without for comparison m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False) @@ -373,30 +371,30 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf m.reset() m2.reset() - @pytest.mark.parametrize("method", ['items', 'values', 'keys']) + @pytest.mark.parametrize("method", ["items", "values", "keys"]) def test_check_compute_groups_items_and_values(self, metrics, expected, method): m = MetricCollection(deepcopy(metrics), compute_groups=True) m2 = MetricCollection(deepcopy(metrics), compute_groups=False) - + for _ in range(2): # repeat to emulate effect of multiple epochs for _ in range(2): # repeat to emulate effect of multiple batches preds = torch.randn(10, 3).softmax(dim=-1) target = torch.randint(3, (10,)) m.update(preds, target) m2.update(preds, target) - + def _compare(m1, m2): for state in m1._defaults: assert torch.allclose(getattr(m1, state), getattr(m2, state)) # if states are still by reference the reset will make following metrics fail m1.reset() m2.reset() - - if method == 'items': + + if method == "items": for (name_cg, metric_cg), (name_no_cg, metric_no_cg) in zip(m.items(), m2.items()): assert name_cg == name_no_cg _compare(metric_cg, metric_no_cg) - if method == 'values': + if method == "values": for metric_cg, metric_no_cg in zip(m.values(), m2.values()): _compare(metric_cg, metric_no_cg) if method == "keys": diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py index 82fea5fd1d4..2a1362cf87d 100644 --- a/torchmetrics/collections.py +++ b/torchmetrics/collections.py @@ -410,7 +410,7 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu if keep_base: return self._modules.items() return self._to_renamed_ordered_dict().items() - + def values(self, copy_state: bool = True) -> Iterable[Module]: if 
copy_state:
             self._compute_groups_create_state_ref(copy_state)
         return self._modules.values()

From f72664469df0acb7835d50eada5e7d0a70865ab4 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 8 Jun 2022 18:54:52 +0200
Subject: [PATCH 04/10] changelog

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 6e1d302ab52..e94fce5c328 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 - Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069))
 
 
+- Fixed bug related to state reference in metric collection when using compute groups ([#1076](https://github.com/PyTorchLightning/metrics/pull/1076))
+
+
 ## [0.9.0] - 2022-05-30
 
 ### Added

From 4c0f2b21cc98521897be3fc0a7ea0cb58ab62409 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 8 Jun 2022 19:03:53 +0200
Subject: [PATCH 05/10] docstring

---
 tests/bases/test_collections.py |  2 ++
 torchmetrics/collections.py     | 22 +++++++++++++++++++++-
 2 files changed, 23 insertions(+), 1 deletion(-)

diff --git a/tests/bases/test_collections.py b/tests/bases/test_collections.py
index 51a91cb8af8..c3b5bb888f0 100644
--- a/tests/bases/test_collections.py
+++ b/tests/bases/test_collections.py
@@ -373,6 +373,8 @@ def test_check_compute_groups_correctness(self, metrics, expected, prefix, postf
 
     @pytest.mark.parametrize("method", ["items", "values", "keys"])
     def test_check_compute_groups_items_and_values(self, metrics, expected, method):
+        """Check that whenever the user calls a method that gives access to the individual metrics, the states are
+        copied instead of just passed by reference."""
         m = MetricCollection(deepcopy(metrics), compute_groups=True)
         m2 = MetricCollection(deepcopy(metrics), compute_groups=False)
 
diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index 2a1362cf87d..d70db4248cc 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -245,7 +245,12 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
         return True
 
     def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
-        """Create reference between metrics in the same compute group."""
+        """Create reference between metrics in the same compute group.
+
+        Args:
+            copy: If `True` the metric state between members will be copied instead
+                of just passed by reference
+        """
         if self._groups_checked:
             for _, cg in self._groups.items():
                 m0 = getattr(self, cg[0])
@@ -404,6 +409,8 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu
         r"""Return an iterable of the ModuleDict key/value pairs.
         Args:
             keep_base: Whether to add prefix/postfix on the items collection.
+            copy_state: If metric states should be copied between metrics in
+                the same compute group or just passed by reference
         """
         if copy_state:
             self._compute_groups_create_state_ref(copy_state)
@@ -412,11 +419,24 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu
         return self._to_renamed_ordered_dict().items()
 
     def values(self, copy_state: bool = True) -> Iterable[Module]:
+        """Return an iterable of the ModuleDict values.
+
+        Args:
+            copy_state: If metric states should be copied between metrics in
+                the same compute group or just passed by reference
+        """
         if copy_state:
             self._compute_groups_create_state_ref(copy_state)
         return self._modules.values()
 
     def __getitem__(self, key: str, copy_state: bool = True) -> Module:
+        """Retrieve a single metric from the collection.
+
+        Args:
+            key: name of metric to retrieve
+            copy_state: If metric states should be copied between metrics in
+                the same compute group or just passed by reference
+        """
         if copy_state:
             self._compute_groups_create_state_ref(copy_state)
         return self._modules[key]

From 253a02ffa586275c272ca02d3343a5fbc4ba3ec3 Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 8 Jun 2022 19:09:49 +0200
Subject: [PATCH 06/10] integration testing

---
 integrations/test_lightning.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/integrations/test_lightning.py b/integrations/test_lightning.py
index af24cd80c11..a5371117724 100644
--- a/integrations/test_lightning.py
+++ b/integrations/test_lightning.py
@@ -19,6 +19,7 @@
 from torch.utils.data import DataLoader
 
 from integrations.lightning.boring_model import BoringModel, RandomDataset
+from tests.helpers.utilities import no_warning_call
 from torchmetrics import Accuracy, AveragePrecision, MetricCollection, SumMetric
 
 
@@ -210,7 +211,11 @@ def training_epoch_end(self, outs):
         max_epochs=2,
         log_every_n_steps=1,
     )
-    trainer.fit(model)
+    with no_warning_call(
+        UserWarning,
+        match="Torchmetrics v0.9 introduced a new argument class property called.*",
+    ):
+        trainer.fit(model)
 
     logged = trainer.logged_metrics
     assert torch.allclose(tensor(logged["sum_step"]), model.sum, atol=2e-4)
@@ -249,7 +254,11 @@ def training_epoch_end(self, outputs):
         log_every_n_steps=1,
         weights_summary=None,
     )
-    trainer.fit(model)
+    with no_warning_call(
+        UserWarning,
+        match="Torchmetrics v0.9 introduced a new argument class property called.*",
+    ):
+        trainer.fit(model)
 
     logged = trainer.logged_metrics
     assert torch.allclose(tensor(logged["SumMetric_epoch"]), model.sum, atol=2e-4)

From 4730bda6b672a2415fc4c83aad71cc72843533fc Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 8 Jun 2022 19:36:21 +0200
Subject: [PATCH 07/10] fix logic

---
 torchmetrics/collections.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index d70db4248cc..c767ab78103 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -143,6 +143,7 @@ def __init__(
         self.postfix = self._check_arg(postfix, "postfix")
         self._enable_compute_groups = compute_groups
         self._groups_checked: bool = False
+        self._state_is_copy: bool = False
 
         self.add_metrics(metrics, *additional_metrics)
 
@@ -172,6 +173,10 @@ def update(self, *args: Any, **kwargs: Any) -> None:
                 mi = getattr(self, cg[i])
                 mi._update_count = m0._update_count
+            if self._state_is_copy:
+                # If we have deep-copied state in between updates, re-establish the link
+                self._compute_groups_create_state_ref()
+                self._state_is_copy = False
         else:  # the first update is always done per metric to form compute groups
             for _, m in self.items(keep_base=True, copy_state=False):
                 m_kwargs = m._filter_kwargs(**kwargs)
                 m.update(*args, **m_kwargs)
 
             if self._enable_compute_groups:
                 self._merge_compute_groups()
+                # create reference between states
+                self._compute_groups_create_state_ref()
                 self._groups_checked = True
 
     def _merge_compute_groups(self) -> None:
@@ -216,9 +223,6 @@ def _merge_compute_groups(self) -> None:
             for idx, values in enumerate(temp.values()):
                 self._groups[idx] = values
 
-        # create reference between states
-        self._compute_groups_create_state_ref()
-
     @staticmethod
     def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
         """Check if the metric state of two metrics are the same."""
@@ -251,7 +255,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
             copy: If `True` the metric state between members will be copied instead
                 of just passed by reference
         """
-        if self._groups_checked:
+        if not self._state_is_copy:
             for _, cg in self._groups.items():
                 m0 = getattr(self, cg[0])
                 for i in range(1, len(cg)):
@@ -260,8 +264,7 @@ def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
                     m0_state = getattr(m0, state)
                     # Determine whether we should just set a reference or make a full copy
                     setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
-        if copy:
-            self._groups_checked = False
+        self._state_is_copy = copy
 
     def compute(self) -> Dict[str, Any]:
         """Compute the result for each metric in the collection."""
@@ -412,8 +415,7 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu
             copy_state: If metric states should be copied between metrics in
                 the same compute group or just passed by reference
         """
-        if copy_state:
-            self._compute_groups_create_state_ref(copy_state)
+        self._compute_groups_create_state_ref(copy_state)
         if keep_base:
             return self._modules.items()
         return self._to_renamed_ordered_dict().items()
@@ -425,8 +427,8 @@ def values(self, copy_state: bool = True) -> Iterable[Module]:
             copy_state: If metric states should be copied between metrics in
                 the same compute group or just passed by reference
         """
-        if copy_state:
-            self._compute_groups_create_state_ref(copy_state)
+        self._compute_groups_create_state_ref(copy_state)
         return self._modules.values()
 
     def __getitem__(self, key: str, copy_state: bool = True) -> Module:
@@ -437,8 +438,7 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Module:
         Args:
             key: name of metric to retrieve
             copy_state: If metric states should be copied between metrics in
                 the same compute group or just passed by reference
         """
-        if copy_state:
-            self._compute_groups_create_state_ref(copy_state)
+        self._compute_groups_create_state_ref(copy_state)
         return self._modules[key]
 
     @staticmethod

From b43e0e7753aabc38209fba71e282b69a0bd71ad5 Mon Sep 17 00:00:00 2001
From: Jirka
Date: Wed, 8 Jun 2022 19:36:25 +0200
Subject: [PATCH 08/10] docs

---
 torchmetrics/collections.py | 15 ++++++++-------
 1 file changed, 8 insertions(+), 7 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index d70db4248cc..bce5ec59d7e 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -52,11 +52,11 @@ class name as key for the output dict.
         metric state and are therefore only different in their compute step e.g. accuracy, precision and recall
         can all be computed from the true positives/negatives and false positives/negatives. By default,
         this argument is ``True`` which enables this feature. Set this argument to `False` for disabling
-        this behaviour. Can also be set to a list of list of metrics for setting the compute groups yourself.
+        this behaviour. Can also be set to a list of lists of metrics for setting the compute groups yourself.
 
     .. note:: Metric collections can be nested at initialization (see last example) but the output of the collection will
-        still be a single flattened dictionary combining the prefix and postfix arguments from the nested collection.
+        still be a single flatten dictionary combining the prefix and postfix arguments from the nested collection.
 
     Raises:
         ValueError:
@@ -407,8 +407,9 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
 
     def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Module]]:
         r"""Return an iterable of the ModuleDict key/value pairs.
+
         Args:
-            keep_base: Whether to add prefix/postfix on the items collection.
+            keep_base: Whether to add prefix/postfix on the collection.
             copy_state: If metric states should be copied between metrics in
                 the same compute group or just passed by reference
         """
         if copy_state:
             self._compute_groups_create_state_ref(copy_state)
@@ -422,8 +423,8 @@ def values(self, copy_state: bool = True) -> Iterable[Module]:
         """Return an iterable of the ModuleDict values.
 
         Args:
-            copy_state: If metric states should be copied between metrics in
-                the same compute group or just passed by reference
+            copy_state: If metric states should be copied between metrics in the same compute group
+                or just passed by reference
         """
         if copy_state:
             self._compute_groups_create_state_ref(copy_state)
         return self._modules.values()
 
     def __getitem__(self, key: str, copy_state: bool = True) -> Module:
@@ -434,8 +435,8 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Module:
 
         Args:
             key: name of metric to retrieve
-            copy_state: If metric states should be copied between metrics in
-                the same compute group or just passed by reference
+            copy_state: If metric states should be copied between metrics in the same compute group
+                or just passed by reference
         """
         if copy_state:
             self._compute_groups_create_state_ref(copy_state)
         return self._modules[key]

From 42faae757cddd74e0741e360576083395560c4ef Mon Sep 17 00:00:00 2001
From: SkafteNicki
Date: Wed, 8 Jun 2022 19:38:28 +0200
Subject: [PATCH 09/10] fix docs

---
 torchmetrics/collections.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/torchmetrics/collections.py b/torchmetrics/collections.py
index c767ab78103..13c74b8ac38 100644
--- a/torchmetrics/collections.py
+++ b/torchmetrics/collections.py
@@ -412,8 +412,8 @@ def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tu
         r"""Return an iterable of the ModuleDict key/value pairs.
         Args:
             keep_base: Whether to add prefix/postfix on the items collection.
-            copy_state: If metric states should be copied between metrics in
-                the same compute group or just passed by reference
+            copy_state:
+                If metric states should be copied between metrics in the same compute group or just passed by reference
         """
         self._compute_groups_create_state_ref(copy_state)
         if keep_base:
             return self._modules.items()
         return self._to_renamed_ordered_dict().items()
 
     def values(self, copy_state: bool = True) -> Iterable[Module]:
         """Return an iterable of the ModuleDict values.
 
Args: - copy_state: If metric states should be copied between metrics in - the same compute group or just passed by reference + copy_state: + If metric states should be copied between metrics in the same compute group or just passed by reference """ self._compute_groups_create_state_ref(copy_state) return self._modules.values() @@ -435,8 +435,8 @@ def __getitem__(self, key: str, copy_state: bool = True) -> Module: Args: key: name of metric to retrieve - copy_state: If metric states should be copied between metrics in - the same compute group or just passed by reference + copy_state: + If metric states should be copied between metrics in the same compute group or just passed by reference """ self._compute_groups_create_state_ref(copy_state) return self._modules[key] From 5abcedd04c9216d433a22de860ee211c1d5dce07 Mon Sep 17 00:00:00 2001 From: SkafteNicki Date: Wed, 8 Jun 2022 19:48:25 +0200 Subject: [PATCH 10/10] fix mistake --- torchmetrics/metric.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmetrics/metric.py b/torchmetrics/metric.py index edecf0fa72e..2fa4eacb218 100644 --- a/torchmetrics/metric.py +++ b/torchmetrics/metric.py @@ -303,7 +303,8 @@ def _forward_reduce_state_update(self, *args: Any, **kwargs: Any) -> Any: # reduce batch and global state self._update_count = _update_count + 1 - self._reduce_states(global_state) + with torch.no_grad(): + self._reduce_states(global_state) # restore context self._is_synced = False
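
A minimal usage sketch of the behaviour this series fixes (not part of the patches): it assumes the v0.9-era classification API, where `Accuracy`, `Precision` and `Recall` take `num_classes` directly, and that these three metrics land in a single compute group.

```python
import torch
from torchmetrics import Accuracy, MetricCollection, Precision, Recall

# Accuracy, Precision and Recall are computed from the same underlying
# statistics, so the collection merges them into a single compute group:
# only the first member accumulates state, the others hold references to it.
collection = MetricCollection(
    [Accuracy(num_classes=3), Precision(num_classes=3), Recall(num_classes=3)],
    compute_groups=True,
)

for _ in range(2):  # compute groups are only established after the first update
    preds = torch.randn(10, 3).softmax(dim=-1)
    target = torch.randint(3, (10,))
    collection.update(preds, target)
print(collection.compute_groups)  # e.g. {0: ['Accuracy', 'Precision', 'Recall']}

res = collection.compute()

# After this series, items()/values()/__getitem__ deep-copy the shared state
# (copy_state=True by default) before handing a metric out, so resetting one
# member no longer corrupts the states of the other metrics in its group:
collection["Accuracy"].reset()
assert torch.allclose(collection["Precision"].compute(), res["Precision"])

# Internally the collection keeps the cheap by-reference sharing by passing
# copy_state=False, as compute() does via self.items(keep_base=True, copy_state=False).
```

Before this series the shared references leaked out of the collection, so calling `reset()` on one returned metric could silently clear the accumulated state of every other metric in its compute group.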