Skip to content

Commit

Permalink
fix: move results' keys to device (Lightning-AI#19813)
Browse files Browse the repository at this point in the history
  • Loading branch information
azzhipa committed Apr 25, 2024
1 parent b9680a3 commit 8085e74
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/lightning/pytorch/CHANGELOG.md
Expand Up @@ -16,6 +16,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `on_exception` hook to `LightningDataModule` ([#19601](https://github.com/Lightning-AI/pytorch-lightning/pull/19601))

- Fixed `_LoggerConnector`'s `_ResultMetric` to move all registered keys to the current device ([#19813](https://github.com/Lightning-AI/pytorch-lightning/issues/19813))

-

### Changed
Expand Down
Expand Up @@ -403,26 +403,19 @@ def log(

# register logged value if it doesn't exist
if key not in self:
self.register_key(key, meta, value)
metric = _ResultMetric(meta, isinstance(value, Tensor))
self[key] = metric

# check the stored metadata and the current one match
elif meta != self[key].meta:
raise MisconfigurationException(
f"You called `self.log({name}, ...)` twice in `{fx}` with different arguments. This is not allowed"
)
self[key].to(value.device)

batch_size = self._extract_batch_size(self[key], batch_size, meta)
self.update_metrics(key, value, batch_size)

def register_key(self, key: str, meta: _Metadata, value: _VALUE) -> None:
    """Create and store one ``_ResultMetric`` for a newly logged value.

    The value may be provided as a nested collection. The metric is moved to
    the device of ``value`` before being registered under ``key``.
    """
    # Track whether the raw value is already a tensor so the metric knows
    # how to wrap/reduce it later.
    is_tensor = isinstance(value, Tensor)
    self[key] = _ResultMetric(meta, is_tensor).to(value.device)

def update_metrics(self, key: str, value: _VALUE, batch_size: int) -> None:
result_metric = self[key]
# performance: avoid calling `__call__` to avoid the checks in `torch.nn.Module._call_impl`
Expand Down
20 changes: 20 additions & 0 deletions tests/tests_pytorch/trainer/logging_/test_logger_connector.py
Expand Up @@ -639,3 +639,23 @@ def test_result_collection_no_batch_size_extraction():
assert results["training_step.epoch_log_val"].value == log_val * batch_size
assert results["training_step.epoch_log_val"].cumulated_batch_size == batch_size
assert results["training_step.epoch_sum_log_val"].value == log_val


def test_result_collection_changes_device():
    """Logged metrics must follow the device of the incoming value.

    Regression test for Lightning-AI#19813: ``_ResultCollection.log`` should
    move an already-registered metric (including its ``cumulated_batch_size``
    state) to the device of the newly logged tensor.
    """
    results = _ResultCollection(training=True)
    fx_name = "training_step"
    log_val = torch.tensor(7.0)

    # freshly registered metric lives on the same device as the original tensor
    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False, reduce_fx="mean")
    assert results["training_step.step_log_val"].cumulated_batch_size.device == log_val.device

    # replace the state with a Mock so we can observe `.to(...)` calls, then
    # move the collection to CPU: the state must be moved along with it
    cumulated_batch_size = results["training_step.step_log_val"].cumulated_batch_size = Mock(spec=torch.Tensor)
    cumulated_batch_size.to.return_value = Mock(spec=torch.Tensor)
    results.cpu()
    cumulated_batch_size.to.assert_called_once_with(log_val.device)

    # re-logging must move the (mocked, already-moved) state to the new
    # tensor's device again
    results.log(fx_name, "step_log_val", log_val, on_step=True, on_epoch=False, reduce_fx="mean")
    cumulated_batch_size.to.return_value.to.assert_called_once_with(log_val.device)

0 comments on commit 8085e74

Please sign in to comment.