From a95e1c6dae98b2e6e9de7b610fd57415b8d780e2 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 05:56:29 +0100 Subject: [PATCH 1/7] Return the output of the optimizer step --- pytorch_lightning/core/lightning.py | 20 +++++---- pytorch_lightning/core/optimizer.py | 9 ++-- .../plugins/precision/apex_amp.py | 5 ++- .../plugins/precision/deepspeed.py | 4 +- pytorch_lightning/plugins/precision/ipu.py | 3 +- .../plugins/precision/native_amp.py | 6 ++- .../plugins/precision/precision_plugin.py | 4 +- pytorch_lightning/plugins/precision/tpu.py | 3 +- pytorch_lightning/strategies/strategy.py | 4 +- tests/core/test_lightning_module.py | 41 ------------------- tests/core/test_lightning_optimizer.py | 35 +++++++++++++++- 11 files changed, 68 insertions(+), 66 deletions(-) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 1b280f12d4166..10f083d0eee50 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1544,25 +1544,29 @@ def optimizer_step( on_tpu: bool = False, using_native_amp: bool = False, using_lbfgs: bool = False, - ) -> None: + ) -> Any: r""" - Override this method to adjust the default way the - :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer. - By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example - once per optimizer. This method (and ``zero_grad()``) won't be called during the - accumulation phase when ``Trainer(accumulate_grad_batches != 1)``. + Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls + each optimizer. + + By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example once per optimizer. + This method (and ``zero_grad()``) won't be called during the accumulation phase when + ``Trainer(accumulate_grad_batches != 1)``. Overriding this hook has no benefit with manual optimization. Args: epoch: Current epoch batch_idx: Index of current batch optimizer: A PyTorch optimizer optimizer_idx: If you used multiple optimizers, this indexes into that list. - optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the + optimizer_closure: The optimizer closure. This closure must be executed as it includes the calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``. on_tpu: ``True`` if TPU backward is required using_native_amp: ``True`` if using native amp using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS` + Returns: + The output from the step call, which is generally the output of the closure execution. + Examples:: # DEFAULT @@ -1615,7 +1619,7 @@ def optimizer_step( pg["lr"] = lr_scale * self.learning_rate """ - optimizer.step(closure=optimizer_closure) + return optimizer.step(closure=optimizer_closure) def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): """Override this method to change the default behaviour of ``optimizer.zero_grad()``. 
diff --git a/pytorch_lightning/core/optimizer.py b/pytorch_lightning/core/optimizer.py index 3825382b2fa35..2279d580ebf6e 100644 --- a/pytorch_lightning/core/optimizer.py +++ b/pytorch_lightning/core/optimizer.py @@ -94,13 +94,16 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]: yield lightning_module.untoggle_optimizer(self._optimizer_idx) - def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None: + def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any: """Performs a single optimization step (parameter update). Args: - closure: An optional optimizer_closure. + closure: An optional optimizer closure. kwargs: Any additional arguments to the ``optimizer.step()`` call. + Returns: + The output from the step call, which is generally the output of the closure execution. + Example:: # Scenario for a GAN using manual optimization @@ -163,7 +166,7 @@ def closure_dis(): assert self._strategy is not None assert self._strategy.lightning_module is not None with self._strategy.lightning_module.trainer.profiler.profile(profiler_action): - self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) + return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs) def _init_optimizers_and_lr_schedulers( diff --git a/pytorch_lightning/plugins/precision/apex_amp.py b/pytorch_lightning/plugins/precision/apex_amp.py index 1e86ec2633fe9..9e8b35289bce5 100644 --- a/pytorch_lightning/plugins/precision/apex_amp.py +++ b/pytorch_lightning/plugins/precision/apex_amp.py @@ -80,7 +80,7 @@ def optimizer_step( optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, - ) -> None: + ) -> Any: if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." @@ -90,7 +90,8 @@ def optimizer_step( skipped_backward = closure_result is None # in manual optimization, the closure does not return a value if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: - optimizer.step(**kwargs) + return optimizer.step(**kwargs) + return closure_result def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: if "amp_scaling_state" in checkpoint: diff --git a/pytorch_lightning/plugins/precision/deepspeed.py b/pytorch_lightning/plugins/precision/deepspeed.py index 3a6eb85769559..b629c931d2fe9 100644 --- a/pytorch_lightning/plugins/precision/deepspeed.py +++ b/pytorch_lightning/plugins/precision/deepspeed.py @@ -61,7 +61,7 @@ def optimizer_step( optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, - ) -> None: + ) -> Any: if isinstance(optimizer, LBFGS): raise MisconfigurationException( f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})." 
@@ -76,7 +76,7 @@ def optimizer_step( ) # DeepSpeed handles the optimizer step internally deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model - deepspeed_engine.step(**kwargs) + return deepspeed_engine.step(**kwargs) def clip_gradients( self, diff --git a/pytorch_lightning/plugins/precision/ipu.py b/pytorch_lightning/plugins/precision/ipu.py index 1b321d2c5acd5..cb267000ad551 100644 --- a/pytorch_lightning/plugins/precision/ipu.py +++ b/pytorch_lightning/plugins/precision/ipu.py @@ -47,7 +47,7 @@ def optimizer_step( optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, - ) -> None: + ) -> Any: """IPUs handle the optimizer step internally.""" if isinstance(optimizer, LBFGS): raise MisconfigurationException( @@ -64,6 +64,7 @@ def optimizer_step( " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`" " requesting this feature." ) + return closure_result def clip_gradients( self, diff --git a/pytorch_lightning/plugins/precision/native_amp.py b/pytorch_lightning/plugins/precision/native_amp.py index f6cb28c76c867..2c9174e5517cb 100644 --- a/pytorch_lightning/plugins/precision/native_amp.py +++ b/pytorch_lightning/plugins/precision/native_amp.py @@ -74,7 +74,7 @@ def optimizer_step( optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, - ) -> None: + ) -> Any: if self.scaler is None: # skip scaler logic, as bfloat16 does not require scaler return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs) @@ -90,8 +90,10 @@ def optimizer_step( # in manual optimization, the closure does not return a value if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward: # note: the scaler will skip the `optimizer.step` if nonfinite gradients are found - self.scaler.step(optimizer, **kwargs) + step_output = self.scaler.step(optimizer, **kwargs) self.scaler.update() + return step_output + return closure_result def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]: if _TORCH_GREATER_EQUAL_1_10: diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py index e875fe51f19e7..6bdfe12dc2a2d 100644 --- a/pytorch_lightning/plugins/precision/precision_plugin.py +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -146,11 +146,11 @@ def optimizer_step( optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any, - ) -> None: + ) -> Any: """Hook to run the optimizer step.""" if isinstance(model, pl.LightningModule): closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) - optimizer.step(closure=closure, **kwargs) + return optimizer.step(closure=closure, **kwargs) def _track_grad_norm(self, trainer: "pl.Trainer") -> None: if trainer.track_grad_norm == -1: diff --git a/pytorch_lightning/plugins/precision/tpu.py b/pytorch_lightning/plugins/precision/tpu.py index 8b8871b3006de..1afd34264c60c 100644 --- a/pytorch_lightning/plugins/precision/tpu.py +++ b/pytorch_lightning/plugins/precision/tpu.py @@ -36,7 +36,7 @@ def optimizer_step( optimizer_idx: int, closure: Callable[[], Any], **kwargs: Any - ) -> None: + ) -> Any: if isinstance(model, pl.LightningModule): closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure) closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs}) @@ -49,3 +49,4 @@ def optimizer_step( " Please, open an issue in 
`https://github.com/PyTorchLightning/pytorch-lightning/issues`" " requesting this feature." ) + return closure_result diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py index 4dd4ca135bcf1..629911911b780 100644 --- a/pytorch_lightning/strategies/strategy.py +++ b/pytorch_lightning/strategies/strategy.py @@ -178,7 +178,7 @@ def optimizer_step( closure: Callable[[], Any], model: Optional[Union["pl.LightningModule", Module]] = None, **kwargs: Any, - ) -> None: + ) -> Any: """performs the actual optimizer step. Args: @@ -189,7 +189,7 @@ def optimizer_step( **kwargs: Any extra arguments to ``optimizer.step`` """ model = model or self.lightning_module - self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) + return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs) def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]: """Setup a model and multiple optimizers together. diff --git a/tests/core/test_lightning_module.py b/tests/core/test_lightning_module.py index 63d5ecf670f91..70abaed4492d7 100644 --- a/tests/core/test_lightning_module.py +++ b/tests/core/test_lightning_module.py @@ -76,47 +76,6 @@ def test_property_logger(tmpdir): assert model.logger == logger -def test_params_groups_and_state_are_accessible(tmpdir): - class TestModel(BoringModel): - def training_step(self, batch, batch_idx, optimizer_idx): - output = self.layer(batch) - loss = self.loss(batch, output) - return {"loss": loss} - - def configure_optimizers(self): - optimizer = SGD(self.layer.parameters(), lr=0.1) - optimizer_2 = Adam(self.layer.parameters(), lr=0.1) - return [optimizer, optimizer_2] - - def optimizer_step( - self, - epoch, - batch_idx, - optimizer, - optimizer_idx, - optimizer_closure, - on_tpu=False, - using_native_amp=False, - using_lbfgs=False, - ): - # warm up lr - if self.trainer.global_step < 500: - lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0) - for pg in optimizer.param_groups: - pg["lr"] = lr_scale * 0.01 - - optimizer.step(closure=optimizer_closure) - - model = TestModel() - model.training_epoch_end = None - - trainer = Trainer( - max_epochs=1, default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=1, accumulate_grad_batches=1 - ) - - trainer.fit(model) - - def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir): class TestModel(BoringModel): def training_step(self, batch, batch_idx, optimizer_idx=None): diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index 970acafa37a41..f37b5c6d39129 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from unittest.mock import DEFAULT, patch +from unittest.mock import DEFAULT, Mock, patch import pytest import torch @@ -95,7 +95,10 @@ def closure(opt): opt_1.step() closure(opt_2) - opt_2.step() + step_output = opt_2.step() + # check that the step output is returned with manual optimization + # since the optimizer is mocked, the step output is a Mock + assert isinstance(step_output, Mock) def configure_optimizers(self): optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1) @@ -314,3 +317,31 @@ def test_lightning_optimizer_keeps_hooks(tmpdir): assert len(optimizer._fwd_handles) == 1 del lightning_optimizer assert len(optimizer._fwd_handles) == 1 + + +def test_params_groups_and_state_are_accessible(tmpdir): + class TestModel(BoringModel): + def training_step(self, batch, batch_idx, optimizer_idx): + output = self.layer(batch) + loss = self.loss(batch, output) + self.__loss = loss + return loss + + def configure_optimizers(self): + optimizer = SGD(self.layer.parameters(), lr=0.1) + optimizer_2 = Adam(self.layer.parameters(), lr=0.1) + return [optimizer, optimizer_2] + + def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **__): + # check attributes are accessible + assert all("lr" in pg for pg in optimizer.param_groups) + assert optimizer.state is optimizer._optimizer.state + assert optimizer.defaults is optimizer._optimizer.defaults + + loss = optimizer.step(closure=optimizer_closure) + # the optimizer step still returns the loss + assert loss == self.__loss + + model = TestModel() + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1) + trainer.fit(model) From 73c8e010673b2d8f847edd49ddf706717bcce849 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 06:00:50 +0100 Subject: [PATCH 2/7] Also for Lite --- pytorch_lightning/lite/wrappers.py | 2 +- tests/lite/test_wrappers.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index b4a1bed0aa17d..340f5d35045ff 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -56,7 +56,7 @@ def state_dict(self) -> Dict[str, Tensor]: def step(self, closure: Optional[Callable] = None) -> None: closure = closure or _do_nothing_closure - self._strategy.optimizer_step( + return self._strategy.optimizer_step( self.optimizer, opt_idx=0, closure=closure, diff --git a/tests/lite/test_wrappers.py b/tests/lite/test_wrappers.py index 65522b74936c8..f70a6a9e38230 100644 --- a/tests/lite/test_wrappers.py +++ b/tests/lite/test_wrappers.py @@ -155,7 +155,9 @@ def test_lite_optimizer_steps(): """Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer.""" optimizer = Mock() strategy = Mock() + strategy.optimizer_step.return_value = 123 lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy) - lite_optimizer.step() + step_output = lite_optimizer.step() + assert step_output == 123 strategy.optimizer_step.assert_called_once() strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model) From 77b3ef8c202d71641a0c19b0c19fe9b7894a5db1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 06:02:35 +0100 Subject: [PATCH 3/7] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c9d862bfe65c1..5949e2c3d288c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,6 +75,9 @@ The 
format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) +- Return the output from the `optimizer.step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) + + - Teardown the active loop and strategy on exception ([#11620](https://github.com/PyTorchLightning/pytorch-lightning/pull/11620)) From ee0a7e5f997ffa1c1e66134047d86cf0bef27500 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 06:06:37 +0100 Subject: [PATCH 4/7] Clarify --- CHANGELOG.md | 2 +- pytorch_lightning/core/lightning.py | 7 ++----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5949e2c3d288c..4accf1a92f029 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,7 +75,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) -- Return the output from the `optimizer.step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) +- Return the output of the `optimizer.step`, useful for LightningLite users, manual optimization users or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) - Teardown the active loop and strategy on exception ([#11620](https://github.com/PyTorchLightning/pytorch-lightning/pull/11620)) diff --git a/pytorch_lightning/core/lightning.py b/pytorch_lightning/core/lightning.py index 10f083d0eee50..29386d7475468 100644 --- a/pytorch_lightning/core/lightning.py +++ b/pytorch_lightning/core/lightning.py @@ -1544,7 +1544,7 @@ def optimizer_step( on_tpu: bool = False, using_native_amp: bool = False, using_lbfgs: bool = False, - ) -> Any: + ) -> None: r""" Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer. @@ -1564,9 +1564,6 @@ def optimizer_step( using_native_amp: ``True`` if using native amp using_lbfgs: True if the matching optimizer is :class:`torch.optim.LBFGS` - Returns: - The output from the step call, which is generally the output of the closure execution. - Examples:: # DEFAULT @@ -1619,7 +1616,7 @@ def optimizer_step( pg["lr"] = lr_scale * self.learning_rate """ - return optimizer.step(closure=optimizer_closure) + optimizer.step(closure=optimizer_closure) def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): """Override this method to change the default behaviour of ``optimizer.zero_grad()``. From 12bad0881badb9e619a684cce3e677696fac26be Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 06:07:54 +0100 Subject: [PATCH 5/7] Clarify --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4accf1a92f029..ac98ac375c4f4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -75,7 +75,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247)) -- Return the output of the `optimizer.step`, useful for LightningLite users, manual optimization users or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) +- Return the output of the `optimizer.step`. This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711)) - Teardown the active loop and strategy on exception ([#11620](https://github.com/PyTorchLightning/pytorch-lightning/pull/11620)) From 8300d72b5ffb3020fbb1f09f896dec870ea3a890 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 3 Feb 2022 06:16:52 +0100 Subject: [PATCH 6/7] Skip validaiton --- tests/core/test_lightning_optimizer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/core/test_lightning_optimizer.py b/tests/core/test_lightning_optimizer.py index f37b5c6d39129..3b2fe7bfde8cc 100644 --- a/tests/core/test_lightning_optimizer.py +++ b/tests/core/test_lightning_optimizer.py @@ -343,5 +343,5 @@ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_c assert loss == self.__loss model = TestModel() - trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=1) + trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=0) trainer.fit(model) From ac87d9967a333228d8fc1ace1015d473c78410e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 3 Feb 2022 15:45:03 +0100 Subject: [PATCH 7/7] Update pytorch_lightning/lite/wrappers.py Co-authored-by: Rohit Gupta --- pytorch_lightning/lite/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 340f5d35045ff..7e869cbd4cc12 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -54,7 +54,7 @@ def optimizer(self) -> Optimizer: def state_dict(self) -> Dict[str, Tensor]: return self._strategy.optimizer_state(self.optimizer) - def step(self, closure: Optional[Callable] = None) -> None: + def step(self, closure: Optional[Callable] = None) -> Any: closure = closure or _do_nothing_closure return self._strategy.optimizer_step( self.optimizer,
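For reference, a minimal sketch (not part of the patch series) of what the returned step output enables under manual optimization, the use case called out in the CHANGELOG entry; the tiny model and the stand-in loss below are illustrative only:

    import torch
    import pytorch_lightning as pl

    class ManualOptModel(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.automatic_optimization = False
            self.layer = torch.nn.Linear(32, 2)

        def training_step(self, batch, batch_idx):
            opt = self.optimizers()
            opt.zero_grad()

            def closure():
                loss = self.layer(batch).sum()  # stand-in loss for illustration
                self.manual_backward(loss)
                return loss

            # with this change, the closure's return value propagates back through
            # LightningOptimizer.step -> Strategy.optimizer_step -> PrecisionPlugin.optimizer_step
            loss = opt.step(closure=closure)
            self.log("train_loss", loss)

        def configure_optimizers(self):
            return torch.optim.SGD(self.layer.parameters(), lr=0.1)

The updated tests above exercise the same behaviour for `_LiteOptimizer.step` and for the `optimizer.step` call inside a user override of `LightningModule.optimizer_step`.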