From 88a3acc4d80618b7014a243948bf7d2daaa7d60d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 8 Feb 2022 04:50:32 +0100 Subject: [PATCH 1/5] Support optimizer step progress tracking with manual optimization --- pytorch_lightning/core/optimizer.py | 7 +++ .../loops/optimization/manual_loop.py | 17 ++++++ .../loops/optimization/optimizer_loop.py | 4 +- pytorch_lightning/trainer/trainer.py | 1 + tests/loops/test_loop_state_dict.py | 4 ++ tests/loops/test_loops.py | 8 +++ .../optimization/test_manual_optimization.py | 61 +++++-------------- 7 files changed, 53 insertions(+), 49 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index f059739ca5665..8690c2b1a055b 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -53,6 +53,9 @@ def __init__(self, optimizer: Optimizer): self._optimizer = optimizer self._strategy: Optional[pl.strategies.Strategy] = None self._optimizer_idx = 0 + # to inject logic around the optimizer step, particularly useful with manual optimization + self.on_before_step = do_nothing_closure + self.on_after_step = do_nothing_closure @property def optimizer(self) -> Optimizer: @@ -154,6 +157,8 @@ def closure_dis(): with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): opt_dis.step(closure=closure_dis) """ + self.on_before_step() + if closure is None: closure = do_nothing_closure profiler_action = "optimizer_step_without_closure" @@ -168,6 +173,8 @@ def closure_dis(): with self._strategy.lightning_module.trainer.profiler.profile(profiler_action): return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + self.on_after_step() + def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 490ad6ce630ac..6150deb76dd5c 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -19,6 +19,7 @@ from pytorch_lightning.loops import Loop from pytorch_lightning.loops.optimization.closure import OutputResult from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens +from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.types import STEP_OUTPUT @@ -74,6 +75,10 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): def __init__(self) -> None: super().__init__() + # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than + # `OptimizerProgress` + self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker) + self._done: bool = False self._hiddens: Optional[Any] = None self._output: _OUTPUTS_TYPE = {} @@ -85,6 +90,12 @@ def done(self) -> bool: def reset(self) -> None: self._done = False + def on_run_start(self, *_: Any, **__: Any) -> None: + # inject logic around the optimizer step + for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items(): + lightning_optimizer.on_before_step = self._on_before_step + lightning_optimizer.on_after_step = self._on_after_step + def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] """Performs the training step for manual optimization. 
@@ -127,3 +138,9 @@ def on_run_end(self) -> _OUTPUTS_TYPE: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, {} # free memory return output + + def _on_before_step(self) -> None: + self.optim_step_progress.increment_ready() + + def _on_after_step(self) -> None: + self.optim_step_progress.increment_completed() diff --git a/pytorch_lightning/loops/optimization/optimizer_loop.py b/pytorch_lightning/loops/optimization/optimizer_loop.py index b23ef65f88c22..b8a8341c0d863 100644 --- a/pytorch_lightning/loops/optimization/optimizer_loop.py +++ b/pytorch_lightning/loops/optimization/optimizer_loop.py @@ -317,7 +317,7 @@ def backward_fn(loss: Tensor) -> None: return backward_fn def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None: - """Toggles the optimizer to ensure the correct one is used and prevend dangling grads. + """Toggles the optimizer to ensure the correct one is used and prevent dangling grads. Args: opt_idx: the index of the optimizer to use @@ -348,7 +348,7 @@ def _optimizer_step( opt_idx: the index of the current :param:`optimizer` batch_idx: the index of the current batch train_step_and_backward_closure: the closure function performing the train step and computing the - gradients. By default called by the optimizer (if possible) + gradients. By default, called by the optimizer (if possible) """ is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 4b68adbdf31b7..b6a0d7fa452e0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -2010,6 +2010,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]: @property def lightning_module(self) -> "pl.LightningModule": + # TODO: this is actually an optional return return self.strategy.lightning_module @property diff --git a/tests/loops/test_loop_state_dict.py b/tests/loops/test_loop_state_dict.py index 28d96b324435f..1e67fcc0ed8db 100644 --- a/tests/loops/test_loop_state_dict.py +++ b/tests/loops/test_loop_state_dict.py @@ -59,6 +59,10 @@ def test_loops_state_dict_structure(): }, "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.manual_loop.state_dict": {}, + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "total": {"ready": 0, "completed": 0}, + "current": {"ready": 0, "completed": 0}, + }, "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer": { diff --git a/tests/loops/test_loops.py b/tests/loops/test_loops.py index f6e75b3b7c345..3bc26818d0de1 100644 --- a/tests/loops/test_loops.py +++ b/tests/loops/test_loops.py @@ -512,6 +512,10 @@ def configure_optimizers_multiple(self): }, "epoch_loop.batch_loop.state_dict": ANY, "epoch_loop.batch_loop.manual_loop.state_dict": ANY, + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "total": {"ready": 0, "completed": 0}, + "current": {"ready": 0, "completed": 0}, + }, "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer_position": stop_optimizer, @@ -681,6 +685,10 @@ def train_dataloader(self): }, "epoch_loop.batch_loop.state_dict": ANY, "epoch_loop.batch_loop.manual_loop.state_dict": ANY, + "epoch_loop.batch_loop.manual_loop.optim_step_progress": { + "total": {"ready": 0, "completed": 0}, + "current": {"ready": 0, "completed": 0}, + }, 
"epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer_position": n_optimizers, diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index 17539d1e4d71d..f74966e3e0279 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -167,6 +167,7 @@ def training_epoch_end(self, outputs) -> None: with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 3 + assert trainer.global_step == limit_train_batches * 2 def test_multiple_optimizers_manual_log(tmpdir): @@ -530,18 +531,14 @@ def optimizer_closure(): weight_after = self.layer.weight.clone() assert not torch.equal(weight_before, weight_after) - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) - model = TestModel() - model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=limit_train_batches, - limit_val_batches=2, + limit_val_batches=0, max_epochs=1, log_every_n_steps=1, ) @@ -553,51 +550,37 @@ def configure_optimizers(self): assert trainer.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean() -def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir): - """Tests that `step` works with optimizer_closure and accumulated_grad.""" - +def test_step_with_optimizer_closure_2(tmpdir): class TestModel(BoringModel): def __init__(self): super().__init__() self.automatic_optimization = False def training_step(self, batch, batch_idx): - # manual opt = self.optimizers() x = batch[0] - - loss_1 = self(x) - loss_1 = self.loss(loss_1, loss_1) + loss = self(x).sum() def optimizer_closure(): # emulate bayesian optimization. num_backward = 1 for backward_idx in range(num_backward + 1): retain_graph = num_backward != backward_idx - self.manual_backward(loss_1, retain_graph=retain_graph) + self.manual_backward(loss, retain_graph=retain_graph) weight_before = self.layer.weight.clone() - opt.step(closure=optimizer_closure) - weight_after = self.layer.weight.clone() - if not self.trainer.fit_loop._should_accumulate(): - assert not torch.equal(weight_before, weight_after) - else: - assert self.layer.weight.grad is not None - - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) + assert not torch.equal(weight_before, weight_after) model = TestModel() - model.val_dataloader = None model.training_epoch_end = None limit_train_batches = 4 trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=limit_train_batches, - limit_val_batches=2, + limit_val_batches=0, max_epochs=1, log_every_n_steps=1, ) @@ -605,6 +588,7 @@ def configure_optimizers(self): with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock: trainer.fit(model) assert bwd_mock.call_count == limit_train_batches * 2 + assert trainer.global_step == limit_train_batches @patch("torch.optim.SGD.step") @@ -620,41 +604,23 @@ def on_train_start(self) -> None: step_mock.reset_mock() def training_step(self, batch, batch_idx): - # manual opt = self.optimizers() - x = batch[0] - - loss_1 = self(x) - loss_1 = self.loss(loss_1, loss_1) - - def optimizer_closure(): - # emulate bayesian optimization. 
- num_backward = 1 - for backward_idx in range(num_backward + 1): - retain_graph = num_backward != backward_idx - self.manual_backward(loss_1, retain_graph=retain_graph) - - opt.step(closure=optimizer_closure) - opt.zero_grad() - - def configure_optimizers(self): - return torch.optim.SGD(self.layer.parameters(), lr=0.1) + opt.step(closure=lambda: ..., foo=123) model = TestModel() - model.val_dataloader = None model.training_epoch_end = None - limit_train_batches = 4 + limit_train_batches = 2 trainer = Trainer( default_root_dir=tmpdir, limit_train_batches=limit_train_batches, - limit_val_batches=2, + limit_val_batches=0, max_epochs=1, - log_every_n_steps=1, ) trainer.fit(model) - assert step_mock.mock_calls == [call(closure=ANY) for _ in range(limit_train_batches)] + assert step_mock.mock_calls == [call(closure=ANY, foo=123) for _ in range(limit_train_batches)] + assert trainer.global_step == limit_train_batches @patch("torch.optim.Adam.step") @@ -730,6 +696,7 @@ def configure_optimizers(self): trainer.fit(model) assert mock_sgd_step.mock_calls == [call(closure=ANY, optim="sgd") for _ in range(4)] assert mock_adam_step.mock_calls == [call(closure=ANY) for _ in range(2)] + assert trainer.global_step == 4 + 2 class TesManualOptimizationDDPModel(BoringModel): From 47487df06164aa28e05870ebb012585580ea3d1b Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 17:46:26 +0100 Subject: [PATCH 2/5] Fix and CHANGELOG --- CHANGELOG.md | 3 +++ pytorch_lightning/loops/optimization/manual_loop.py | 7 ++++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2782a4cb1d9f1..e7b7622ed716b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) +- Added support for optimizer step progress tracking with manual optimization ([#11848](https://github.com/PyTorchLightning/pytorch-lightning/pull/11848)) + + - Return the output of the `optimizer.step`. 
This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index 6150deb76dd5c..eebfe405e7c5f 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -16,6 +16,7 @@ from torch import Tensor +from pytorch_lightning.core.optimizer import do_nothing_closure from pytorch_lightning.loops import Loop from pytorch_lightning.loops.optimization.closure import OutputResult from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens @@ -76,7 +77,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]): def __init__(self) -> None: super().__init__() # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than - # `OptimizerProgress` + # `OptimizationProgress` self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker) self._done: bool = False @@ -137,6 +138,10 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] def on_run_end(self) -> _OUTPUTS_TYPE: """Returns the result of this loop, i.e., the post-processed outputs from the training step.""" output, self._output = self._output, {} # free memory + # reset logic around the optimizer step + for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items(): + lightning_optimizer.on_before_step = do_nothing_closure + lightning_optimizer.on_after_step = do_nothing_closure return output def _on_before_step(self) -> None: From ed2f6c7b63c7395b77ed8e86374bf95aebf31138 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 17:50:26 +0100 Subject: [PATCH 3/5] Remove dupe test --- CHANGELOG.md | 2 +- .../optimization/test_manual_optimization.py | 35 +------------------ 2 files changed, 2 insertions(+), 35 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e7b7622ed716b..e9b5e2fcc3b8a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -77,8 +77,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) -- Added support for optimizer step progress tracking with manual optimization ([#11848](https://github.com/PyTorchLightning/pytorch-lightning/pull/11848)) +- Added support for optimizer step progress tracking with manual optimization ([#11848](https://github.com/PyTorchLightning/pytorch-lightning/pull/11848)) - Return the output of the `optimizer.step`. 
This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) diff --git a/tests/trainer/optimization/test_manual_optimization.py b/tests/trainer/optimization/test_manual_optimization.py index f74966e3e0279..ec7fc05d4b67b 100644 --- a/tests/trainer/optimization/test_manual_optimization.py +++ b/tests/trainer/optimization/test_manual_optimization.py @@ -591,43 +591,9 @@ def optimizer_closure(): assert trainer.global_step == limit_train_batches -@patch("torch.optim.SGD.step") -def test_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir): - """Tests that `step` works with optimizer_closure and extra arguments.""" - - class TestModel(BoringModel): - def __init__(self): - super().__init__() - self.automatic_optimization = False - - def on_train_start(self) -> None: - step_mock.reset_mock() - - def training_step(self, batch, batch_idx): - opt = self.optimizers() - opt.step(closure=lambda: ..., foo=123) - - model = TestModel() - model.training_epoch_end = None - - limit_train_batches = 2 - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=limit_train_batches, - limit_val_batches=0, - max_epochs=1, - ) - - trainer.fit(model) - assert step_mock.mock_calls == [call(closure=ANY, foo=123) for _ in range(limit_train_batches)] - assert trainer.global_step == limit_train_batches - - @patch("torch.optim.Adam.step") @patch("torch.optim.SGD.step") def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir): - """Tests that `step` works with optimizer_closure and different accumulated_gradient frequency.""" - class TestModel(BoringModel): def __init__(self): super().__init__() @@ -666,6 +632,7 @@ def dis_closure(): # this will accumulate gradients for 2 batches and then call opt_gen.step() gen_closure() if batch_idx % 2 == 0: + # passing a custom kwarg opt_gen.step(closure=gen_closure, optim="sgd") opt_gen.zero_grad() From 72cc31886ca5e9910516f0a15340f26a6ba5e595 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 18:17:03 +0100 Subject: [PATCH 4/5] Fixes --- pytorch_lightning/core/optimizer.py | 4 +++- pytorch_lightning/loops/fit_loop.py | 7 +++++-- 2 files changed, 8 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 8690c2b1a055b..1f58027de54ef 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -171,10 +171,12 @@ def closure_dis(): assert self._strategy is not None assert self._strategy.lightning_module is not None with self._strategy.lightning_module.trainer.profiler.profile(profiler_action): - return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) self.on_after_step() + return step_output + def _init_optimizers_and_lr_schedulers( model: "pl.LightningModule", diff --git a/pytorch_lightning/loops/fit_loop.py b/pytorch_lightning/loops/fit_loop.py index 8cbe4c167a29d..fc85e05b78299 100644 --- a/pytorch_lightning/loops/fit_loop.py +++ b/pytorch_lightning/loops/fit_loop.py @@ -71,7 +71,10 @@ def __init__( @property def global_step(self) -> int: """Returns the global step.""" - return self.epoch_loop.global_step + lightning_module = self.trainer.lightning_module + if lightning_module is None or 
lightning_module.automatic_optimization: + return self.epoch_loop.global_step + return self.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed @global_step.setter def global_step(self, value: int) -> None: @@ -96,7 +99,7 @@ def split_idx(self) -> int: @property def min_steps(self) -> Optional[int]: # TODO(@justusschock): Why aren't we using the attribute in this class? - """Returns the minimum numnber of steps to run.""" + """Returns the minimum number of steps to run.""" return self.epoch_loop.min_steps @min_steps.setter From 124e9ce6195378b00977f93e02a503fb5a559eb4 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 18:43:28 +0100 Subject: [PATCH 5/5] Fixes --- pytorch_lightning/core/optimizer.py | 8 ++++---- pytorch_lightning/loops/optimization/manual_loop.py | 8 ++++---- .../trainer/connectors/checkpoint_connector.py | 2 +- tests/core/test_lightning_optimizer.py | 3 ++- 4 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 1f58027de54ef..5c6e0137fe06e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -54,8 +54,8 @@ def __init__(self, optimizer: Optimizer): self._strategy: Optional[pl.strategies.Strategy] = None self._optimizer_idx = 0 # to inject logic around the optimizer step, particularly useful with manual optimization - self.on_before_step = do_nothing_closure - self.on_after_step = do_nothing_closure + self._on_before_step = do_nothing_closure + self._on_after_step = do_nothing_closure @property def optimizer(self) -> Optimizer: @@ -157,7 +157,7 @@ def closure_dis(): with opt_dis.toggle_model(sync_grad=accumulated_grad_batches): opt_dis.step(closure=closure_dis) """ - self.on_before_step() + self._on_before_step() if closure is None: closure = do_nothing_closure @@ -173,7 +173,7 @@ def closure_dis(): with self._strategy.lightning_module.trainer.profiler.profile(profiler_action): step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) - self.on_after_step() + self._on_after_step() return step_output diff --git a/pytorch_lightning/loops/optimization/manual_loop.py b/pytorch_lightning/loops/optimization/manual_loop.py index eebfe405e7c5f..dc92fde5a1ac1 100644 --- a/pytorch_lightning/loops/optimization/manual_loop.py +++ b/pytorch_lightning/loops/optimization/manual_loop.py @@ -94,8 +94,8 @@ def reset(self) -> None: def on_run_start(self, *_: Any, **__: Any) -> None: # inject logic around the optimizer step for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items(): - lightning_optimizer.on_before_step = self._on_before_step - lightning_optimizer.on_after_step = self._on_after_step + lightning_optimizer._on_before_step = self._on_before_step + lightning_optimizer._on_after_step = self._on_after_step def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override] """Performs the training step for manual optimization. 
@@ -140,8 +140,8 @@ def on_run_end(self) -> _OUTPUTS_TYPE: output, self._output = self._output, {} # free memory # reset logic around the optimizer step for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items(): - lightning_optimizer.on_before_step = do_nothing_closure - lightning_optimizer.on_after_step = do_nothing_closure + lightning_optimizer._on_before_step = do_nothing_closure + lightning_optimizer._on_after_step = do_nothing_closure return output def _on_before_step(self) -> None: diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index bb6889d797f5c..0712f5df85f51 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -333,7 +333,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict: checkpoint = { # the epoch is saved for compatibility but it's not relevant for restoration "epoch": self.trainer.current_epoch, - "global_step": self.trainer.global_step + 1, + "global_step": self.trainer.global_step + model.automatic_optimization, "pytorch-lightning_version": pl.__version__, "state_dict": self._get_lightning_module_state_dict(), "loops": self._get_loops_state_dict(), diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index 3b2fe7bfde8cc..0ee82fcc2e4ec 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -152,7 +152,8 @@ def test_state(tmpdir): lightning_dict = { k: v for k, v in lightning_optimizer.__dict__.items() - if k not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module"} + if k + not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"} } assert lightning_dict == optimizer.__dict__
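
Reviewer note (not part of the patch series): below is a minimal sketch of the behaviour these patches enable, mirroring the assertions added in tests/trainer/optimization/test_manual_optimization.py — under manual optimization, each `LightningOptimizer.step()` call now passes through `_on_before_step` / `_on_after_step` and increments the manual loop's `optim_step_progress`, from which `trainer.global_step` is derived. The module, dataset, and hyperparameters are illustrative and not taken from the diff; only the Lightning APIs already exercised by the patched tests (`automatic_optimization = False`, `self.optimizers()`, `manual_backward`, `opt.step()`) are assumed.

import torch
from torch.utils.data import DataLoader, TensorDataset

from pytorch_lightning import LightningModule, Trainer


class ManualOptModel(LightningModule):
    """Illustrative module using manual optimization (name and sizes are made up)."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # opt into manual optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()  # the LightningOptimizer wrapper patched by this series
        loss = self.layer(batch[0]).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        # each `step()` call is counted by `manual_loop.optim_step_progress`
        opt.step()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


if __name__ == "__main__":
    train_data = DataLoader(TensorDataset(torch.randn(8, 32)), batch_size=2)
    trainer = Trainer(max_epochs=1, limit_train_batches=4, limit_val_batches=0)
    trainer.fit(ManualOptModel(), train_dataloaders=train_data)
    # mirrors `assert trainer.global_step == limit_train_batches` from the patched tests:
    # `global_step` now reflects the number of completed manual optimizer steps (one per batch here)
    assert trainer.global_step == 4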