Return the output of the optimizer step #11711

Merged · 7 commits merged on Feb 9, 2022
Changes from 6 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -75,6 +75,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))


- Return the output of the `optimizer.step` call. This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711))


- Teardown the active loop and strategy on exception ([#11620](https://github.com/PyTorchLightning/pytorch-lightning/pull/11620))


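For context, a minimal sketch (not part of this diff; the module, tensor shapes, and logging are illustrative) of how a manual-optimization user can consume the value that `optimizer.step()` now returns. It assumes default 32-bit precision, where plain PyTorch optimizers return the closure's output:

```python
import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            opt.zero_grad()
            output = self.layer(batch)
            loss = torch.nn.functional.mse_loss(output, torch.zeros_like(output))
            self.manual_backward(loss)
            return loss

        # `step` now returns the closure output (here, the loss) instead of discarding it
        loss = opt.step(closure=closure)
        self.log("train_loss", loss)

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)
```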
13 changes: 7 additions & 6 deletions pytorch_lightning/core/lightning.py
@@ -1546,18 +1546,19 @@ def optimizer_step(
using_lbfgs: bool = False,
) -> None:
r"""
Override this method to adjust the default way the
:class:`~pytorch_lightning.trainer.trainer.Trainer` calls each optimizer.
By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example
once per optimizer. This method (and ``zero_grad()``) won't be called during the
accumulation phase when ``Trainer(accumulate_grad_batches != 1)``.
Override this method to adjust the default way the :class:`~pytorch_lightning.trainer.trainer.Trainer` calls
each optimizer.

By default, Lightning calls ``step()`` and ``zero_grad()`` as shown in the example once per optimizer.
This method (and ``zero_grad()``) won't be called during the accumulation phase when
``Trainer(accumulate_grad_batches != 1)``. Overriding this hook has no benefit with manual optimization.

Args:
epoch: Current epoch
batch_idx: Index of current batch
optimizer: A PyTorch optimizer
optimizer_idx: If you used multiple optimizers, this indexes into that list.
optimizer_closure: Closure for all optimizers. This closure must be executed as it includes the
optimizer_closure: The optimizer closure. This closure must be executed as it includes the
calls to ``training_step()``, ``optimizer.zero_grad()``, and ``backward()``.
on_tpu: ``True`` if TPU backward is required
using_native_amp: ``True`` if using native amp
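To illustrate the updated hook, here is a sketch of overriding `LightningModule.optimizer_step` for learning-rate warm-up and reading the loss that `optimizer.step(closure=optimizer_closure)` now returns. The toy module and the hard-coded 500-step warm-up are illustrative; the pattern mirrors the test this PR moves into `tests/core/test_lightning_optimizer.py`:

```python
import torch
from pytorch_lightning import LightningModule


class WarmupModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        return torch.nn.functional.mse_loss(output, torch.zeros_like(output))

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.01)

    def optimizer_step(
        self,
        epoch,
        batch_idx,
        optimizer,
        optimizer_idx,
        optimizer_closure,
        on_tpu=False,
        using_native_amp=False,
        using_lbfgs=False,
    ):
        # linear learning-rate warm-up over the first 500 optimizer steps
        if self.trainer.global_step < 500:
            lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
            for pg in optimizer.param_groups:
                pg["lr"] = lr_scale * 0.01

        # with this change, `step` returns the closure output, i.e. the training-step loss
        loss = optimizer.step(closure=optimizer_closure)
        return loss
```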
9 changes: 6 additions & 3 deletions pytorch_lightning/core/optimizer.py
@@ -94,13 +94,16 @@ def toggle_model(self, sync_grad: bool = True) -> Generator[None, None, None]:
yield
lightning_module.untoggle_optimizer(self._optimizer_idx)

def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> None:
def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
"""Performs a single optimization step (parameter update).

Args:
closure: An optional optimizer_closure.
closure: An optional optimizer closure.
kwargs: Any additional arguments to the ``optimizer.step()`` call.

Returns:
The output from the step call, which is generally the output of the closure execution.

Example::

# Scenario for a GAN using manual optimization
@@ -163,7 +166,7 @@ def closure_dis():
assert self._strategy is not None
assert self._strategy.lightning_module is not None
with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)


def _init_optimizers_and_lr_schedulers(
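The "output of the closure execution" wording above follows plain PyTorch semantics: built-in optimizers return the closure's value from `step`, which is what `LightningOptimizer.step` now forwards. A standalone sanity-check sketch, with no Lightning involved:

```python
import torch

param = torch.nn.Parameter(torch.tensor([1.0]))
optimizer = torch.optim.SGD([param], lr=0.1)


def closure():
    optimizer.zero_grad()
    loss = (param ** 2).sum()
    loss.backward()
    return loss


returned = optimizer.step(closure)
assert torch.is_tensor(returned)  # the loss tensor produced by the closure
```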
2 changes: 1 addition & 1 deletion pytorch_lightning/lite/wrappers.py
@@ -56,7 +56,7 @@ def state_dict(self) -> Dict[str, Tensor]:

def step(self, closure: Optional[Callable] = None) -> None:
closure = closure or _do_nothing_closure
self._strategy.optimizer_step(
return self._strategy.optimizer_step(
self.optimizer,
opt_idx=0,
closure=closure,
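A minimal `LightningLite` sketch (illustrative model and optimizer, default 32-bit precision assumed) showing the forwarded step output being consumed in a Lite training loop, for example with LBFGS where `step(closure)` returns the loss evaluated by the closure:

```python
import torch
from pytorch_lightning.lite import LightningLite


class Lite(LightningLite):
    def run(self):
        model = torch.nn.Linear(32, 2)
        optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
        model, optimizer = self.setup(model, optimizer)
        batch = torch.rand(4, 32, device=self.device)

        def closure():
            optimizer.zero_grad()
            loss = model(batch).sum()
            self.backward(loss)
            return loss

        # with this change, `step` forwards the value returned by the wrapped optimizer
        loss = optimizer.step(closure)
        print("loss after step:", float(loss))


Lite(accelerator="cpu").run()
```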
5 changes: 3 additions & 2 deletions pytorch_lightning/plugins/precision/apex_amp.py
@@ -80,7 +80,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"apex AMP and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
@@ -90,7 +90,8 @@ def optimizer_step(
skipped_backward = closure_result is None
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
optimizer.step(**kwargs)
return optimizer.step(**kwargs)
return closure_result

def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None:
if "amp_scaling_state" in checkpoint:
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/deepspeed.py
@@ -61,7 +61,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
f"DeepSpeed and the LBFGS optimizer are not compatible (optimizer {optimizer_idx})."
@@ -76,7 +76,7 @@ def optimizer_step(
)
# DeepSpeed handles the optimizer step internally
deepspeed_engine = model.trainer.model if isinstance(model, pl.LightningModule) else model
deepspeed_engine.step(**kwargs)
return deepspeed_engine.step(**kwargs)

def clip_gradients(
self,
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/ipu.py
@@ -47,7 +47,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
"""IPUs handle the optimizer step internally."""
if isinstance(optimizer, LBFGS):
raise MisconfigurationException(
@@ -64,6 +64,7 @@ def optimizer_step(
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" requesting this feature."
)
return closure_result

def clip_gradients(
self,
6 changes: 4 additions & 2 deletions pytorch_lightning/plugins/precision/native_amp.py
@@ -74,7 +74,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
if self.scaler is None:
# skip scaler logic, as bfloat16 does not require scaler
return super().optimizer_step(model, optimizer, optimizer_idx, closure, **kwargs)
@@ -90,8 +90,10 @@ def optimizer_step(
# in manual optimization, the closure does not return a value
if not isinstance(model, pl.LightningModule) or not model.automatic_optimization or not skipped_backward:
# note: the scaler will skip the `optimizer.step` if nonfinite gradients are found
self.scaler.step(optimizer, **kwargs)
step_output = self.scaler.step(optimizer, **kwargs)
self.scaler.update()
return step_output
return closure_result

def autocast_context_manager(self) -> Union["old_autocast", "new_autocast"]:
if _TORCH_GREATER_EQUAL_1_10:
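For reference (plain PyTorch, CUDA assumed, not Lightning-specific): `torch.cuda.amp.GradScaler.step` returns whatever `optimizer.step` returns and skips the underlying step when non-finite gradients are found, which is what the `step_output` above captures and now propagates:

```python
import torch

model = torch.nn.Linear(8, 1).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

with torch.cuda.amp.autocast():
    loss = model(torch.rand(4, 8, device="cuda")).sum()

scaler.scale(loss).backward()
# `scaler.step` forwards the return value of `optimizer.step()`
# (None here, since no closure is passed) or skips the step on inf/NaN gradients
step_output = scaler.step(optimizer)
scaler.update()
```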
4 changes: 2 additions & 2 deletions pytorch_lightning/plugins/precision/precision_plugin.py
@@ -146,11 +146,11 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any,
) -> None:
) -> Any:
"""Hook to run the optimizer step."""
if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
optimizer.step(closure=closure, **kwargs)
return optimizer.step(closure=closure, **kwargs)

def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
if trainer.track_grad_norm == -1:
3 changes: 2 additions & 1 deletion pytorch_lightning/plugins/precision/tpu.py
@@ -36,7 +36,7 @@ def optimizer_step(
optimizer_idx: int,
closure: Callable[[], Any],
**kwargs: Any
) -> None:
) -> Any:
if isinstance(model, pl.LightningModule):
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
@@ -49,3 +49,4 @@ def optimizer_step(
" Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`"
" requesting this feature."
)
return closure_result
4 changes: 2 additions & 2 deletions pytorch_lightning/strategies/strategy.py
@@ -178,7 +178,7 @@ def optimizer_step(
closure: Callable[[], Any],
model: Optional[Union["pl.LightningModule", Module]] = None,
**kwargs: Any,
) -> None:
) -> Any:
"""performs the actual optimizer step.

Args:
@@ -189,7 +189,7 @@
**kwargs: Any extra arguments to ``optimizer.step``
"""
model = model or self.lightning_module
self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)

def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
"""Setup a model and multiple optimizers together.
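Taken together, a simplified sketch of how the return value now propagates through the layers touched in this diff (illustrative, not the exact call sites):

```python
# LightningOptimizer.step(closure=...)                        # pytorch_lightning/core/optimizer.py
#   -> Strategy.optimizer_step(optimizer, opt_idx, closure)   # pytorch_lightning/strategies/strategy.py
#     -> PrecisionPlugin.optimizer_step(model, optimizer, opt_idx, closure)
#       -> optimizer.step(closure=closure)                    # returns the closure output
#
# Before this PR each layer discarded the value; now every layer returns it,
# so the caller (manual optimization, Lite, or an overridden hook) can read it.
```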
41 changes: 0 additions & 41 deletions tests/core/test_lightning_module.py
@@ -76,47 +76,6 @@ def test_property_logger(tmpdir):
assert model.logger == logger


def test_params_groups_and_state_are_accessible(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(
self,
epoch,
batch_idx,
optimizer,
optimizer_idx,
optimizer_closure,
on_tpu=False,
using_native_amp=False,
using_lbfgs=False,
):
# warm up lr
if self.trainer.global_step < 500:
lr_scale = min(1.0, float(self.trainer.global_step + 1) / 500.0)
for pg in optimizer.param_groups:
pg["lr"] = lr_scale * 0.01

optimizer.step(closure=optimizer_closure)

model = TestModel()
model.training_epoch_end = None

trainer = Trainer(
max_epochs=1, default_root_dir=tmpdir, limit_train_batches=8, limit_val_batches=1, accumulate_grad_batches=1
)

trainer.fit(model)


def test_toggle_untoggle_2_optimizers_no_shared_parameters(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx=None):
35 changes: 33 additions & 2 deletions tests/core/test_lightning_optimizer.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import DEFAULT, patch
from unittest.mock import DEFAULT, Mock, patch

import pytest
import torch
@@ -95,7 +95,10 @@ def closure(opt):
opt_1.step()

closure(opt_2)
opt_2.step()
step_output = opt_2.step()
# check that the step output is returned with manual optimization
# since the optimizer is mocked, the step output is a Mock
assert isinstance(step_output, Mock)

def configure_optimizers(self):
optimizer_1 = torch.optim.SGD(self.layer.parameters(), lr=0.1)
@@ -314,3 +317,31 @@ def test_lightning_optimizer_keeps_hooks(tmpdir):
assert len(optimizer._fwd_handles) == 1
del lightning_optimizer
assert len(optimizer._fwd_handles) == 1


def test_params_groups_and_state_are_accessible(tmpdir):
class TestModel(BoringModel):
def training_step(self, batch, batch_idx, optimizer_idx):
output = self.layer(batch)
loss = self.loss(batch, output)
self.__loss = loss
return loss

def configure_optimizers(self):
optimizer = SGD(self.layer.parameters(), lr=0.1)
optimizer_2 = Adam(self.layer.parameters(), lr=0.1)
return [optimizer, optimizer_2]

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure, **__):
# check attributes are accessible
assert all("lr" in pg for pg in optimizer.param_groups)
assert optimizer.state is optimizer._optimizer.state
assert optimizer.defaults is optimizer._optimizer.defaults

loss = optimizer.step(closure=optimizer_closure)
# the optimizer step still returns the loss
assert loss == self.__loss

model = TestModel()
trainer = Trainer(max_epochs=1, default_root_dir=tmpdir, limit_train_batches=1, limit_val_batches=0)
trainer.fit(model)
4 changes: 3 additions & 1 deletion tests/lite/test_wrappers.py
@@ -155,7 +155,9 @@ def test_lite_optimizer_steps():
"""Test that the LiteOptimizer forwards the step() and zero_grad() calls to the wrapped optimizer."""
optimizer = Mock()
strategy = Mock()
strategy.optimizer_step.return_value = 123
lite_optimizer = _LiteOptimizer(optimizer=optimizer, strategy=strategy)
lite_optimizer.step()
step_output = lite_optimizer.step()
assert step_output == 123
strategy.optimizer_step.assert_called_once()
strategy.optimizer_step.assert_called_with(optimizer, opt_idx=0, closure=ANY, model=strategy.model)