Skip to content

Commit

Permalink
fix: move results' keys to device (#19813)
Browse files Browse the repository at this point in the history
  • Loading branch information
azzhipa committed Apr 25, 2024
1 parent b9680a3 commit 99dcf93
Showing 1 changed file with 3 additions and 10 deletions.
Expand Up @@ -403,26 +403,19 @@ def log(

# register logged value if it doesn't exist
if key not in self:
self.register_key(key, meta, value)
metric = _ResultMetric(meta, isinstance(value, Tensor))
self[key] = metric

# check the stored metadata and the current one match
elif meta != self[key].meta:
raise MisconfigurationException(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)
self[key].to(value.device)

batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
"""Create one _ResultMetric object per value.
Value can be provided as a nested collection
"""
metric = _ResultMetric(meta, isinstance(value, Tensor)).to(value.device)
self[key] = metric

def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
Expand Down

0 comments on commit 99dcf93

Please sign in to comment.