Support checkpoint save and load with Stochastic Weight Averaging (Lightning-AI#9938)

Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Co-authored-by: Jirka <jirka.borovec@seznam.cz>
Co-authored-by: Rohit Gupta <rohitgr1998@gmail.com>
8 people authored and jessecambon committed Aug 16, 2022
1 parent 76f6aa9 commit 41d0687
Showing 3 changed files with 195 additions and 14 deletions.
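For context, this is what the commit enables: a run that uses the `StochasticWeightAveraging` callback can now be interrupted and resumed from a `Trainer` checkpoint without losing the running average or the SWA scheduler state. A minimal usage sketch, assuming a placeholder `MyModel` LightningModule and a placeholder checkpoint path:

```python
import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

# `MyModel` is a placeholder LightningModule; adapt paths and arguments to your setup.
swa = StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=3)
trainer = pl.Trainer(max_epochs=10, callbacks=[swa])

# Checkpoints saved during this run now include the SWA callback state:
# n_averaged, the SWALR scheduler state, and the average model weights.
trainer.fit(MyModel())

# Resuming restores that state through the callback's load_state_dict, so
# weight averaging continues where it left off instead of starting over.
trainer = pl.Trainer(
    max_epochs=10, callbacks=[StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=3)]
)
trainer.fit(MyModel(), ckpt_path="path/to/last.ckpt")
```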
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))


- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))


- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))


78 changes: 71 additions & 7 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,14 +16,15 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Any, Callable, cast, List, Optional, Union
from typing import Any, Callable, cast, Dict, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -112,15 +113,22 @@ def __init__(
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self.n_averaged: Optional[torch.Tensor] = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._max_epochs: int
self._model_contains_batch_norm: bool
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: "pl.LightningModule"
self._initialized = False
self._swa_scheduler: Optional[_LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
self._init_n_averaged = 0
self._latest_update_epoch = -1
self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None
self._max_epochs: int

@property
def swa_start(self) -> int:
@@ -147,6 +155,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

@@ -158,8 +169,13 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs += 1

if self._scheduler_state is not None:
self._clear_schedulers(trainer)

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.current_epoch == self.swa_start:
if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
self._initialized = True

# move average model to the requested device.
self._average_model = self._average_model.to(self._device or pl_module.device)

@@ -180,6 +196,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
),
)
if self._scheduler_state is not None:
# Restore scheduler state from checkpoint
self._swa_scheduler.load_state_dict(self._scheduler_state)
elif trainer.current_epoch != self.swa_start:
# Log a warning if we're initializing after start without any checkpoint data,
# as behaviour will be different compared to having checkpoint data.
rank_zero_warn(
"SWA is initializing after swa_start without any checkpoint data. "
"This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
)

# We assert that there is only one optimizer on fit start, so we know opt_idx is always 0
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
@@ -196,14 +223,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
else:
trainer.lr_scheduler_configs.append(default_scheduler_cfg)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
if self.n_averaged is None:
self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)

if self.swa_start <= trainer.current_epoch <= self.swa_end:
if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
trainer.current_epoch > self._latest_update_epoch
):
assert self.n_averaged is not None
self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
self._latest_update_epoch = trainer.current_epoch

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:

# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

@@ -265,6 +296,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No

def reset_momenta(self) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
assert self.momenta is not None
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module]

@@ -285,3 +317,35 @@ def update_parameters(
def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
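For reference, `avg_fn` is the incremental form of the arithmetic mean: `avg + (x - avg) / (n + 1)` is algebraically identical to `(n * avg + x) / (n + 1)`. A small illustrative check, not part of the diff:

```python
import torch

# Running mean via the incremental update used by avg_fn.
params = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(9.0)]
avg = params[0].clone()
for n, p in enumerate(params[1:], start=1):
    avg = avg + (p - avg) / (n + 1)

# Matches the plain arithmetic mean (5.0 here).
assert torch.isclose(avg, torch.stack(params).mean())
```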

def state_dict(self) -> Dict[str, Any]:
return {
"n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
"latest_update_epoch": self._latest_update_epoch,
"scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
"average_model_state": None if self._average_model is None else self._average_model.state_dict(),
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._init_n_averaged = state_dict["n_averaged"]
self._latest_update_epoch = state_dict["latest_update_epoch"]
self._scheduler_state = state_dict["scheduler_state"]
self._load_average_model_state(state_dict["average_model_state"])
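As an aside, the persisted callback state can be inspected offline. A minimal sketch, assuming the standard Lightning checkpoint layout in which callback states live under the `"callbacks"` key, keyed by each callback's `state_key`, and a placeholder checkpoint path:

```python
import torch

ckpt = torch.load("path/to/last.ckpt", map_location="cpu")

# Find the SWA entry among the saved callback states.
swa_state = next(v for k, v in ckpt["callbacks"].items() if "StochasticWeightAveraging" in k)
print(swa_state["n_averaged"], swa_state["latest_update_epoch"])
# "scheduler_state" and "average_model_state" are what load_state_dict above restores.
```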

@staticmethod
def _clear_schedulers(trainer: "pl.Trainer") -> None:
# If we have scheduler state saved, clear the scheduler configs so that we don't try to
# load state into the wrong type of schedulers when restoring scheduler checkpoint state.
# We'll configure the scheduler and re-load its state in on_train_epoch_start.
# Note that this relies on the callback state being restored before the scheduler state is
# restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
# writing that is only True for deepspeed which is already not supported by SWA.
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background.
if trainer.lr_scheduler_configs:
assert len(trainer.lr_scheduler_configs) == 1
trainer.lr_scheduler_configs.clear()

def _load_average_model_state(self, model_state: Any) -> None:
if self._average_model is None:
return
self._average_model.load_state_dict(model_state)
128 changes: 121 additions & 7 deletions tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import ContextManager, Optional
from unittest import mock

import pytest
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.swa_utils import SWALR
from torch.utils.data import DataLoader

@@ -30,7 +34,9 @@


class SwaTestModel(BoringModel):
def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
def __init__(
self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
):
super().__init__()
layers = [nn.Linear(32, 32)]
if batchnorm:
@@ -39,17 +45,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
self.layer = nn.Sequential(*layers)
self.interval = interval
self.iterable_dataset = iterable_dataset
self.crash_on_epoch = crash_on_epoch

def training_step(self, batch, batch_idx):
if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
raise Exception("SWA crash test")
output = self.forward(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def train_dataloader(self):

dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
dset = dset_cls(32, 64)

return DataLoader(dset, batch_size=2)

def configure_optimizers(self):
@@ -66,6 +73,8 @@ def configure_optimizers(self):
class SwaTestCallback(StochasticWeightAveraging):
update_parameters_calls: int = 0
transfer_weights_calls: int = 0
# Record the first epoch, since when resuming from a checkpoint it may not be 0
first_epoch: Optional[int] = None

def update_parameters(self, *args, **kwargs):
self.update_parameters_calls += 1
@@ -77,6 +86,11 @@ def transfer_weights(self, *args, **kwargs):

def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
if self.first_epoch is None and not trainer.fit_loop.restarting:
# Since the loaded checkpoint was saved `on_train_epoch_end`, the first `FitLoop` iteration will
# not update the model; it will only run the epoch-level hooks. For that reason, we check that
# we are not restarting before choosing the first epoch.
self.first_epoch = trainer.current_epoch
assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
if self.swa_start <= trainer.current_epoch:
assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)
@@ -88,6 +102,7 @@ def on_train_epoch_end(self, trainer, *args):
if self.swa_start <= trainer.current_epoch <= self.swa_end:
swa_epoch = trainer.current_epoch - self.swa_start
assert self.n_averaged == swa_epoch + 1
assert self._swa_scheduler is not None
# Scheduler is stepped once on initialization and then at the end of each epoch
assert self._swa_scheduler._step_count == swa_epoch + 2
elif trainer.current_epoch > self.swa_end:
@@ -103,10 +118,13 @@ def on_train_end(self, trainer, pl_module):

if not isinstance(trainer.strategy, DDPSpawnStrategy):
# check the backward call count; the batchnorm update epoch should not call backward
assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
assert trainer.strategy.backward.call_count == (
(trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
)

# check call counts
assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
first_swa_epoch = max(self.first_epoch, self.swa_start)
assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
assert self.transfer_weights_calls == 1


@@ -140,7 +158,7 @@ def train_with_swa(
devices=devices,
)

with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward):
with _backward_patch(trainer):
trainer.fit(model)

# check that the model is the expected one
@@ -226,9 +244,10 @@ def test_swa_multiple_lrs(tmpdir):

class TestModel(BoringModel):
def __init__(self):
super(BoringModel, self).__init__()
super().__init__()
self.layer1 = torch.nn.Linear(32, 32)
self.layer2 = torch.nn.Linear(32, 2)
self.on_train_epoch_start_called = False

def forward(self, x):
x = self.layer1(x)
@@ -255,3 +274,98 @@ def on_train_epoch_start(self):
)
trainer.fit(model)
assert model.on_train_epoch_start_called


def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
swa_start = 3
trainer_kwargs = {
"default_root_dir": tmpdir,
"max_epochs": 5,
"accelerator": "cpu",
"strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None,
"devices": 2 if ddp else 1,
"limit_train_batches": 5,
"limit_val_batches": 0,
"accumulate_grad_batches": 2,
"enable_progress_bar": False,
}
trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
trainer.fit(model)

checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
checkpoint_files = os.listdir(checkpoint_dir)
assert len(checkpoint_files) == 1
ckpt_path = str(checkpoint_dir / checkpoint_files[0])

trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

with _backward_patch(trainer):
trainer.fit(resume_model, ckpt_path=ckpt_path)


class CustomSchedulerModel(SwaTestModel):
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)

def lr_lambda(current_step: int):
return 0.1

scheduler = LambdaLR(optimizer, lr_lambda, -1)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": self.interval,
},
}


@pytest.mark.parametrize("crash_on_epoch", [1, 3])
def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
model = SwaTestModel(crash_on_epoch=crash_on_epoch)
resume_model = SwaTestModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model)


@pytest.mark.parametrize("crash_on_epoch", [1, 3])
def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
# Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665
model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
resume_model = CustomSchedulerModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model)


@RunIf(skip_windows=True)
def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
model = SwaTestModel(crash_on_epoch=3)
resume_model = SwaTestModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)


@pytest.mark.parametrize(
"strategy",
[
pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
],
)
def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
model = SwaTestModel()
swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=5,
callbacks=[swa_callback],
strategy=strategy,
accelerator="gpu",
devices=1,
)
with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
trainer.fit(model)


def _backward_patch(trainer: Trainer) -> ContextManager:
return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)
