Fix state reference in MetricCollection #1076

Merged
13 commits merged on Jun 8, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -49,6 +49,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Fixed `TypeError` when providing superclass arguments as kwargs ([#1069](https://github.com/PyTorchLightning/metrics/pull/1069))


- Fixed bug related to state reference in metric collection when using compute groups ([#1076](https://github.com/PyTorchLightning/metrics/pull/1076))


## [0.9.0] - 2022-05-30

### Added
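For context before the code diffs, here is a minimal sketch of the failure mode this PR addresses. The metric classes and call pattern are illustrative, assuming the 0.9-era torchmetrics API that the tests below also use:

import torch
from torchmetrics import Accuracy, MetricCollection, Precision

# with compute_groups=True, metrics with identical states end up sharing one
# accumulated state behind the scenes
metrics = MetricCollection([Accuracy(), Precision()], compute_groups=True)
metrics.update(torch.randn(10, 3).softmax(dim=-1), torch.randint(3, (10,)))

# before this fix, iterating handed out metrics whose states were references
# into the shared compute-group state, so resetting (or mutating) one metric
# silently corrupted its siblings
for name, metric in metrics.items():
    metric.reset()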
120 changes: 76 additions & 44 deletions tests/bases/test_collections.py
@@ -322,53 +322,85 @@ def compute(self):
),
],
)
@pytest.mark.parametrize(
"prefix, postfix",
[
[None, None],
["prefix_", None],
[None, "_postfix"],
["prefix_", "_postfix"],
],
)
def test_check_compute_groups(metrics, expected, prefix, postfix):
"""Check that compute groups are formed after initialization."""
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct a collection without compute groups for comparison
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)

assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}

for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called
class TestComputeGroups:
@pytest.mark.parametrize(
"prefix, postfix",
[
[None, None],
["prefix_", None],
[None, "_postfix"],
["prefix_", "_postfix"],
],
)
def test_check_compute_groups_correctness(self, metrics, expected, prefix, postfix):
"""Check that compute groups are formed after initialization and that metrics are correctly computed."""
m = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=True)
# Construct a collection without compute groups for comparison
m2 = MetricCollection(deepcopy(metrics), prefix=prefix, postfix=postfix, compute_groups=False)

assert m.compute_groups == expected
assert len(m.compute_groups) == len(m)
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

# compare results for correctness
res_cg = m.compute()
res_without_cg = m2.compute()
for key in res_cg.keys():
assert torch.allclose(res_cg[key], res_without_cg[key])

m.reset()
m2.reset()
for _ in range(2): # repeat to emulate effect of multiple epochs
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

assert m.compute_groups == expected
assert m2.compute_groups == {}

preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
# compute groups should kick in here
m.update(preds, target)
m2.update(preds, target)

for _, member in m.items():
assert member._update_called

# compare results for correctness
res_cg = m.compute()
res_without_cg = m2.compute()
for key in res_cg.keys():
assert torch.allclose(res_cg[key], res_without_cg[key])

m.reset()
m2.reset()

@pytest.mark.parametrize("method", ["items", "values", "keys"])
def test_check_compute_groups_items_and_values(self, metrics, expected, method):
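"""Check that metrics returned by items, values and keys hold copies of their states rather than compute-group references."""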
m = MetricCollection(deepcopy(metrics), compute_groups=True)
m2 = MetricCollection(deepcopy(metrics), compute_groups=False)

for _ in range(2): # repeat to emulate effect of multiple epochs
for _ in range(2): # repeat to emulate effect of multiple batches
preds = torch.randn(10, 3).softmax(dim=-1)
target = torch.randint(3, (10,))
m.update(preds, target)
m2.update(preds, target)

def _compare(m1, m2):
for state in m1._defaults:
assert torch.allclose(getattr(m1, state), getattr(m2, state))
# if the states were still shared by reference, this reset would make the comparisons for the following metrics fail
m1.reset()
m2.reset()

if method == "items":
for (name_cg, metric_cg), (name_no_cg, metric_no_cg) in zip(m.items(), m2.items()):
assert name_cg == name_no_cg
_compare(metric_cg, metric_no_cg)
if method == "values":
for metric_cg, metric_no_cg in zip(m.values(), m2.values()):
_compare(metric_cg, metric_no_cg)
if method == "keys":
for key in m.keys():
metric_cg, metric_no_cg = m[key], m2[key]
_compare(metric_cg, metric_no_cg)


@pytest.mark.parametrize(
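The `_compare` helper above only passes if each metric handed back by the collection owns its own state. The underlying aliasing hazard can be reproduced without torchmetrics at all; a minimal self-contained sketch (`TinyMetric` is hypothetical, and its in-place `reset` is illustrative rather than the library's actual implementation):

import torch

class TinyMetric:
    """Hypothetical stand-in for a metric with a single tensor state."""

    def __init__(self):
        self.total = torch.tensor(0.0)

    def update(self, x):
        self.total += x.sum()  # in-place accumulation

    def reset(self):
        self.total.zero_()  # in-place reset clears any aliased state too

a, b = TinyMetric(), TinyMetric()
a.update(torch.ones(3))
b.total = a.total  # emulate a compute group: b references a's state

a.reset()        # resets a ...
print(b.total)   # ... and silently resets b as well: tensor(0.)

b.total = a.total.clone()  # the fix in spirit: hand out a copy, not the reference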
46 changes: 35 additions & 11 deletions torchmetrics/collections.py
@@ -153,7 +153,7 @@ def forward(self, *args: Any, **kwargs: Any) -> Dict[str, Any]:
Positional arguments (args) will be passed to every metric in the collection, while keyword arguments (kwargs)
will be filtered based on the signature of the individual metric.
"""
res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True)}
res = {k: m(*args, **m._filter_kwargs(**kwargs)) for k, m in self.items(keep_base=True, copy_state=False)}
res = _flatten_dict(res)
return {self._set_name(k): v for k, v in res.items()}
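# Hypothetical illustration of the filtering described above: for a call like
#     collection(preds, target=target, sample_weight=weight)
# a metric whose update() only accepts (preds, target) never sees
# `sample_weight`; _filter_kwargs drops kwargs absent from each metric's
# update() signature.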

@@ -173,7 +173,7 @@ def update(self, *args: Any, **kwargs: Any) -> None:
mi = getattr(self, cg[i])
mi._update_count = m0._update_count
else: # the first update is always done per metric to form the compute groups
for _, m in self.items(keep_base=True):
for _, m in self.items(keep_base=True, copy_state=False):
m_kwargs = m._filter_kwargs(**kwargs)
m.update(*args, **m_kwargs)

@@ -216,6 +216,9 @@ def _merge_compute_groups(self) -> None:
for idx, values in enumerate(temp.values()):
self._groups[idx] = values

# create reference between states
self._compute_groups_create_state_ref()

@staticmethod
def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:
"""Check if the metric state of two metrics are the same."""
@@ -241,24 +241,33 @@ def _equal_metric_states(metric1: Metric, metric2: Metric) -> bool:

return True

def compute(self) -> Dict[str, Any]:
"""Compute the result for each metric in the collection."""
if self._enable_compute_groups and self._groups_checked:
def _compute_groups_create_state_ref(self, copy: bool = False) -> None:
"""Create reference between metrics in the same compute group."""
if self._groups_checked:
for _, cg in self._groups.items():
m0 = getattr(self, cg[0])
# propagate the first metric's state to the remaining metrics in the compute group
for i in range(1, len(cg)):
mi = getattr(self, cg[i])
for state in m0._defaults:
setattr(mi, state, getattr(m0, state))
res = {k: m.compute() for k, m in self.items(keep_base=True)}
m0_state = getattr(m0, state)
# determine whether to set a reference or a full copy
setattr(mi, state, deepcopy(m0_state) if copy else m0_state)
if copy:
self._groups_checked = False

def compute(self) -> Dict[str, Any]:
"""Compute the result for each metric in the collection."""
res = {k: m.compute() for k, m in self.items(keep_base=True, copy_state=False)}
res = _flatten_dict(res)
return {self._set_name(k): v for k, v in res.items()}

def reset(self) -> None:
"""Iteratively call reset for each metric."""
for _, m in self.items(keep_base=True):
for _, m in self.items(keep_base=True, copy_state=False):
m.reset()
if self._enable_compute_groups and self._groups_checked:
# reset state reference
self._compute_groups_create_state_ref()

def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) -> "MetricCollection":
"""Make a copy of the metric collection
@@ -276,7 +288,7 @@ def clone(self, prefix: Optional[str] = None, postfix: Optional[str] = None) ->

def persistent(self, mode: bool = True) -> None:
"""Method for post-init to change if metric states should be saved to its state_dict."""
for _, m in self.items(keep_base=True):
for _, m in self.items(keep_base=True, copy_state=False):
m.persistent(mode)

def add_metrics(
@@ -388,15 +400,27 @@ def keys(self, keep_base: bool = False) -> Iterable[Hashable]:
return self._modules.keys()
return self._to_renamed_ordered_dict().keys()

def items(self, keep_base: bool = False) -> Iterable[Tuple[str, Module]]:
def items(self, keep_base: bool = False, copy_state: bool = True) -> Iterable[Tuple[str, Module]]:
r"""Return an iterable of the ModuleDict key/value pairs.
Args:
keep_base: Whether to add prefix/postfix to the items collection.
copy_state: If metric states should be copied between metrics in the same compute group or just passed by reference.
"""
if copy_state:
self._compute_groups_create_state_ref(copy_state)
if keep_base:
return self._modules.items()
return self._to_renamed_ordered_dict().items()

def values(self, copy_state: bool = True) -> Iterable[Module]:
r"""Return an iterable of the ModuleDict values.
Args:
copy_state: If metric states should be copied between metrics in the same compute group or just passed by reference.
"""
if copy_state:
self._compute_groups_create_state_ref(copy_state)
return self._modules.values()

def __getitem__(self, key: str, copy_state: bool = True) -> Module:
r"""Retrieve a single metric by its key.
Args:
key: name of the metric to retrieve.
copy_state: If metric states should be copied between metrics in the same compute group or just passed by reference.
"""
if copy_state:
self._compute_groups_create_state_ref(copy_state)
return self._modules[key]

@staticmethod
def _check_arg(arg: Optional[str], name: str) -> Optional[str]:
if arg is None or isinstance(arg, str):
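Taken together, the new `copy_state` flag separates two access modes: internal call sites (`forward`, `update`, `compute`, `reset`, `persistent`) pass `copy_state=False` to keep the cheap shared references, while the public `items`, `values` and `__getitem__` default to copying, so user-facing metrics own their state. A hedged usage sketch, with illustrative collection contents and the 0.9-era API assumed:

import torch
from torchmetrics import Accuracy, MetricCollection, Precision

collection = MetricCollection([Accuracy(), Precision()], compute_groups=True)
collection.update(torch.randn(10, 3).softmax(dim=-1), torch.randint(3, (10,)))

# public access (copy_state=True by default): shared group states are
# deep-copied onto each metric, so this reset cannot leak to siblings
for name, metric in collection.items():
    metric.reset()

# internal-style access (copy_state=False): keep the shared references,
# as forward/compute/reset do in the diff above
for name, metric in collection.items(keep_base=True, copy_state=False):
    pass

Note that copying breaks the state sharing, which is why `_compute_groups_create_state_ref(copy=True)` sets `_groups_checked = False`: the next `update` call re-establishes the compute groups.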