Support optimizer step progress tracking with manual optimization #11848

Merged
merged 6 commits on Feb 16, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -78,6 +78,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))


- Added support for optimizer step progress tracking with manual optimization ([#11848](https://github.com/PyTorchLightning/pytorch-lightning/pull/11848))


- Return the output of the `optimizer.step`. This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711))
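
Taken together, these two entries mean a manual-optimization `training_step` can both inspect what the wrapped `optimizer.step` returns and rely on the trainer counting each completed step. A minimal sketch, assuming an illustrative module that is not part of this PR:

import torch
import pytorch_lightning as pl


class ManualModule(pl.LightningModule):
    """Illustrative only: a LightningModule using manual optimization."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)
        self.automatic_optimization = False  # opt in to manual optimization

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self.layer(batch).sum()
        opt.zero_grad()
        self.manual_backward(loss)
        # PR 11711: the return value of the underlying `Optimizer.step` is forwarded
        # PR 11848: each completed call increments the tracked optimizer step progress
        step_output = opt.step()
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)

During `trainer.fit`, `trainer.global_step` then advances once per completed `opt.step()` call rather than once per optimizer-loop iteration.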


11 changes: 10 additions & 1 deletion pytorch_lightning/core/optimizer.py
@@ -53,6 +53,9 @@ def __init__(self, optimizer: Optimizer):
self._optimizer = optimizer
self._strategy: Optional[pl.strategies.Strategy] = None
self._optimizer_idx = 0
# to inject logic around the optimizer step, particularly useful with manual optimization
self._on_before_step = do_nothing_closure
self._on_after_step = do_nothing_closure

@property
def optimizer(self) -> Optimizer:
@@ -154,6 +157,8 @@ def closure_dis():
with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
opt_dis.step(closure=closure_dis)
"""
self._on_before_step()

if closure is None:
closure = do_nothing_closure
profiler_action = "optimizer_step_without_closure"
@@ -166,7 +171,11 @@ def closure_dis():
assert self._strategy is not None
assert self._strategy.lightning_module is not None
with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
return self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

self._on_after_step()

return step_output


def _init_optimizers_and_lr_schedulers(
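The `_on_before_step` / `_on_after_step` hooks above are plain callables defaulting to `do_nothing_closure`, so the loop that owns the optimizer can swap in callbacks to observe every step. A standalone sketch of the wrapping pattern, with illustrative names that are not the PR's API (the real `LightningOptimizer.step` additionally profiles the call and delegates to the strategy):

from typing import Any, Callable, Optional

import torch


def _do_nothing() -> None:
    return None


class WrappedOptimizer:
    """Illustrative only: wraps an optimizer and fires hooks around `step`."""

    def __init__(self, optimizer: torch.optim.Optimizer) -> None:
        self._optimizer = optimizer
        # callbacks injected by whoever wants to track step progress
        self._on_before_step: Callable[[], None] = _do_nothing
        self._on_after_step: Callable[[], None] = _do_nothing

    def step(self, closure: Optional[Callable[[], Any]] = None, **kwargs: Any) -> Any:
        self._on_before_step()  # e.g. progress.increment_ready()
        step_output = self._optimizer.step(closure=closure, **kwargs)
        self._on_after_step()  # e.g. progress.increment_completed()
        return step_output

`ManualOptimization.on_run_start` (below) assigns exactly this kind of callback pair to each `LightningOptimizer`.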
7 changes: 5 additions & 2 deletions pytorch_lightning/loops/fit_loop.py
@@ -71,7 +71,10 @@ def __init__(
@property
def global_step(self) -> int:
"""Returns the global step."""
return self.epoch_loop.global_step
lightning_module = self.trainer.lightning_module
if lightning_module is None or lightning_module.automatic_optimization:
return self.epoch_loop.global_step
return self.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed

@global_step.setter
def global_step(self, value: int) -> None:
@@ -96,7 +99,7 @@ def split_idx(self) -> int:
@property
def min_steps(self) -> Optional[int]:
# TODO(@justusschock): Why aren't we using the attribute in this class?
"""Returns the minimum numnber of steps to run."""
"""Returns the minimum number of steps to run."""
return self.epoch_loop.min_steps

@min_steps.setter
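In other words, `trainer.global_step` now dispatches on `automatic_optimization`: under manual optimization it reports the total number of completed optimizer steps tracked by the manual loop. A hedged, self-contained sketch of the observable effect, using an illustrative module that steps two optimizers per batch (mirroring the `global_step == limit_train_batches * 2` assertion added to `test_manual_optimization.py` at the bottom of this diff):

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer


class RandomDataset(Dataset):
    """Illustrative 8-sample dataset of random 32-dim vectors."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return torch.randn(32)


class TwoOptimizerManualModule(LightningModule):
    """Illustrative only: steps two optimizers once each per batch."""

    def __init__(self):
        super().__init__()
        self.gen = torch.nn.Linear(32, 2)
        self.dis = torch.nn.Linear(32, 2)
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt_gen, opt_dis = self.optimizers()
        loss = self.gen(batch).sum() + self.dis(batch).sum()
        opt_gen.zero_grad()
        opt_dis.zero_grad()
        self.manual_backward(loss)
        opt_gen.step()  # each completed step increments the tracked progress
        opt_dis.step()

    def train_dataloader(self):
        return DataLoader(RandomDataset(), batch_size=2)

    def configure_optimizers(self):
        return (
            torch.optim.SGD(self.gen.parameters(), lr=0.1),
            torch.optim.SGD(self.dis.parameters(), lr=0.1),
        )


trainer = Trainer(limit_train_batches=4, limit_val_batches=0, max_epochs=1)
trainer.fit(TwoOptimizerManualModule())
assert trainer.global_step == 4 * 2  # two completed optimizer steps for each of the 4 batches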
22 changes: 22 additions & 0 deletions pytorch_lightning/loops/optimization/manual_loop.py
@@ -16,9 +16,11 @@

from torch import Tensor

from pytorch_lightning.core.optimizer import do_nothing_closure
from pytorch_lightning.loops import Loop
from pytorch_lightning.loops.optimization.closure import OutputResult
from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -74,6 +76,10 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):

def __init__(self) -> None:
super().__init__()
# since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
# `OptimizationProgress`
self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)

self._done: bool = False
self._hiddens: Optional[Any] = None
self._output: _OUTPUTS_TYPE = {}
@@ -85,6 +91,12 @@ def done(self) -> bool:
def reset(self) -> None:
self._done = False

def on_run_start(self, *_: Any, **__: Any) -> None:
# inject logic around the optimizer step
for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
lightning_optimizer._on_before_step = self._on_before_step
lightning_optimizer._on_after_step = self._on_after_step

def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
"""Performs the training step for manual optimization.

@@ -126,4 +138,14 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
def on_run_end(self) -> _OUTPUTS_TYPE:
"""Returns the result of this loop, i.e., the post-processed outputs from the training step."""
output, self._output = self._output, {} # free memory
# reset logic around the optimizer step
for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
lightning_optimizer._on_before_step = do_nothing_closure
lightning_optimizer._on_after_step = do_nothing_closure
return output

def _on_before_step(self) -> None:
self.optim_step_progress.increment_ready()

def _on_after_step(self) -> None:
self.optim_step_progress.increment_completed()
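
`ReadyCompletedTracker` only records `ready` and `completed` counts, which is all manual optimization needs: the hooks above bump `ready` immediately before the wrapped `optimizer.step` runs and `completed` right after it returns. A small illustration of the tracker in isolation, assuming the `Progress` API from `pytorch_lightning.trainer.progress` behaves as it is used above:

from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker

progress = Progress.from_defaults(ReadyCompletedTracker)
progress.increment_ready()      # what `_on_before_step` does
progress.increment_completed()  # what `_on_after_step` does
assert progress.total.completed == 1
assert progress.current.completed == 1
# under manual optimization, `FitLoop.global_step` reads `total.completed` (see fit_loop.py above)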
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/optimization/optimizer_loop.py
@@ -317,7 +317,7 @@ def backward_fn(loss: Tensor) -> None:
return backward_fn

def _run_optimization_start(self, opt_idx: int, optimizer: torch.optim.Optimizer) -> None:
"""Toggles the optimizer to ensure the correct one is used and prevend dangling grads.
"""Toggles the optimizer to ensure the correct one is used and prevent dangling grads.

Args:
opt_idx: the index of the optimizer to use
@@ -348,7 +348,7 @@ def _optimizer_step(
opt_idx: the index of the current :param:`optimizer`
batch_idx: the index of the current batch
train_step_and_backward_closure: the closure function performing the train step and computing the
gradients. By default called by the optimizer (if possible)
gradients. By default, called by the optimizer (if possible)
"""
is_lbfgs = isinstance(optimizer, torch.optim.LBFGS)

@@ -348,7 +348,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
checkpoint = {
# the epoch is saved for compatibility but it's not relevant for restoration
"epoch": self.trainer.current_epoch,
"global_step": self.trainer.global_step + 1,
"global_step": self.trainer.global_step + model.automatic_optimization,
"pytorch-lightning_version": pl.__version__,
"state_dict": self._get_lightning_module_state_dict(),
"loops": self._get_loops_state_dict(),
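Because `automatic_optimization` is a plain bool (and Python bools are ints), the expression above keeps the previous `+ 1` behaviour for automatic optimization while storing the manual-optimization step count unchanged, since that counter already reflects completed steps. A trivial arithmetic sketch:

# Illustrative only: how the bool offsets the stored "global_step" value.
global_step = 10
assert global_step + True == 11   # automatic optimization: resume at the next step
assert global_step + False == 10  # manual optimization: completed-step count stored as-is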
1 change: 1 addition & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -2010,6 +2010,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:

@property
def lightning_module(self) -> "pl.LightningModule":
# TODO: this is actually an optional return
return self.strategy.lightning_module

@property
3 changes: 2 additions & 1 deletion tests/core/test_lightning_optimizer.py
@@ -152,7 +152,8 @@ def test_state(tmpdir):
lightning_dict = {
k: v
for k, v in lightning_optimizer.__dict__.items()
if k not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module"}
if k
not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
}

assert lightning_dict == optimizer.__dict__
4 changes: 4 additions & 0 deletions tests/loops/test_loop_state_dict.py
@@ -59,6 +59,10 @@ def test_loops_state_dict_structure():
},
"epoch_loop.batch_loop.state_dict": {},
"epoch_loop.batch_loop.manual_loop.state_dict": {},
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer": {
8 changes: 8 additions & 0 deletions tests/loops/test_loops.py
@@ -512,6 +512,10 @@ def configure_optimizers_multiple(self):
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_position": stop_optimizer,
@@ -681,6 +685,10 @@ def train_dataloader(self):
},
"epoch_loop.batch_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.state_dict": ANY,
"epoch_loop.batch_loop.manual_loop.optim_step_progress": {
"total": {"ready": 0, "completed": 0},
"current": {"ready": 0, "completed": 0},
},
"epoch_loop.batch_loop.optimizer_loop.state_dict": {},
"epoch_loop.batch_loop.optimizer_loop.optim_progress": {
"optimizer_position": n_optimizers,
86 changes: 10 additions & 76 deletions tests/trainer/optimization/test_manual_optimization.py
@@ -167,6 +167,7 @@ def training_epoch_end(self, outputs) -> None:
with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
trainer.fit(model)
assert bwd_mock.call_count == limit_train_batches * 3
assert trainer.global_step == limit_train_batches * 2


def test_multiple_optimizers_manual_log(tmpdir):
@@ -530,18 +531,14 @@ def optimizer_closure():
weight_after = self.layer.weight.clone()
assert not torch.equal(weight_before, weight_after)

def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None

limit_train_batches = 2
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
limit_val_batches=0,
max_epochs=1,
log_every_n_steps=1,
)
@@ -553,115 +550,50 @@ def configure_optimizers(self):
assert trainer.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()


def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
"""Tests that `step` works with optimizer_closure and accumulated_grad."""

def test_step_with_optimizer_closure_2(tmpdir):
class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.automatic_optimization = False

def training_step(self, batch, batch_idx):
# manual
opt = self.optimizers()
x = batch[0]

loss_1 = self(x)
loss_1 = self.loss(loss_1, loss_1)
loss = self(x).sum()

def optimizer_closure():
# emulate bayesian optimization.
num_backward = 1
for backward_idx in range(num_backward + 1):
retain_graph = num_backward != backward_idx
self.manual_backward(loss_1, retain_graph=retain_graph)
self.manual_backward(loss, retain_graph=retain_graph)

weight_before = self.layer.weight.clone()

opt.step(closure=optimizer_closure)

weight_after = self.layer.weight.clone()
if not self.trainer.fit_loop._should_accumulate():
assert not torch.equal(weight_before, weight_after)
else:
assert self.layer.weight.grad is not None

def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)
assert not torch.equal(weight_before, weight_after)

model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None

limit_train_batches = 4
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
limit_val_batches=0,
max_epochs=1,
log_every_n_steps=1,
)

with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
trainer.fit(model)
assert bwd_mock.call_count == limit_train_batches * 2


@patch("torch.optim.SGD.step")
def test_step_with_optimizer_closure_and_extra_arguments(step_mock, tmpdir):
"""Tests that `step` works with optimizer_closure and extra arguments."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
self.automatic_optimization = False

def on_train_start(self) -> None:
step_mock.reset_mock()

def training_step(self, batch, batch_idx):
# manual
opt = self.optimizers()
x = batch[0]

loss_1 = self(x)
loss_1 = self.loss(loss_1, loss_1)

def optimizer_closure():
# emulate bayesian optimization.
num_backward = 1
for backward_idx in range(num_backward + 1):
retain_graph = num_backward != backward_idx
self.manual_backward(loss_1, retain_graph=retain_graph)

opt.step(closure=optimizer_closure)
opt.zero_grad()

def configure_optimizers(self):
return torch.optim.SGD(self.layer.parameters(), lr=0.1)

model = TestModel()
model.val_dataloader = None
model.training_epoch_end = None

limit_train_batches = 4
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=limit_train_batches,
limit_val_batches=2,
max_epochs=1,
log_every_n_steps=1,
)

trainer.fit(model)
assert step_mock.mock_calls == [call(closure=ANY) for _ in range(limit_train_batches)]
assert trainer.global_step == limit_train_batches


@patch("torch.optim.Adam.step")
@patch("torch.optim.SGD.step")
def test_step_with_optimizer_closure_with_different_frequencies(mock_sgd_step, mock_adam_step, tmpdir):
"""Tests that `step` works with optimizer_closure and different accumulated_gradient frequency."""

class TestModel(BoringModel):
def __init__(self):
super().__init__()
@@ -700,6 +632,7 @@ def dis_closure():
# this will accumulate gradients for 2 batches and then call opt_gen.step()
gen_closure()
if batch_idx % 2 == 0:
# passing a custom kwarg
opt_gen.step(closure=gen_closure, optim="sgd")
opt_gen.zero_grad()

@@ -730,6 +663,7 @@ def configure_optimizers(self):
trainer.fit(model)
assert mock_sgd_step.mock_calls == [call(closure=ANY, optim="sgd") for _ in range(4)]
assert mock_adam_step.mock_calls == [call(closure=ANY) for _ in range(2)]
assert trainer.global_step == 4 + 2


class TesManualOptimizationDDPModel(BoringModel):