Support checkpoint save and load with Stochastic Weight Averaging #9938

Merged · 94 commits · Aug 9, 2022

Commits
72d0433
Save StochasticWeightAveraging callback data in checkpoints
adamreeve Oct 14, 2021
3d2bf65
Add option to use SWA parameters during validation
adamreeve Oct 14, 2021
1696273
Allow restoring SWA parameters to a model from a checkpoint
adamreeve Oct 14, 2021
c8db9d8
Refactor SWA batch norm moment update to work with validation
adamreeve Oct 18, 2021
004959b
Add test for loading a model from a checkpoint with SWA parameters
adamreeve Oct 19, 2021
d76528b
Recompute batch norm moments when updating parameters from a checkpoint
adamreeve Oct 19, 2021
0ea22e0
Handle when data batch is a list or tuple
adamreeve Oct 20, 2021
01ca2a7
Save SWA scheduler step count in checkpoints
adamreeve Oct 27, 2021
08d655b
Update SWA documentation and changelog
adamreeve Oct 27, 2021
91ab357
Fix DeepSource code style issues
adamreeve Oct 27, 2021
22e5d51
Revert SWA validation changes
adamreeve Nov 9, 2021
ed0a7f8
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 9, 2021
11963f6
Fix resuming from epoch before SWA start and add extra test
adamreeve Nov 9, 2021
226d8aa
Don't save state derived from constructor parameters into checkpoints
adamreeve Nov 9, 2021
9ecc417
Merge branch 'master' into swa_checkpoint
tchaton Nov 15, 2021
5d03d96
Tidy ups from code review
adamreeve Nov 15, 2021
02a04da
Fix handling of n_averaged checkpoint data with multiple processes
adamreeve Nov 15, 2021
8af5b56
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 16, 2021
db9590c
Merge branch 'master' into swa_checkpoint
tchaton Nov 29, 2021
5763e05
Fix deprecation warning in test
adamreeve Nov 29, 2021
d46be83
Remove check for non-empty callback state in checkpoint
adamreeve Nov 29, 2021
e0fd0cb
Raise MisconfigurationException when using SWA with sharded models
adamreeve Nov 29, 2021
2a83f05
Fix test failure with torch 1.7
adamreeve Nov 29, 2021
4a8d81c
Fix crash when fairscale isn't installed
adamreeve Nov 29, 2021
dab0ef4
Skip segfaulting test under pytorch < 1.8
adamreeve Nov 30, 2021
a0d52c8
Changelog merge fix
adamreeve Nov 30, 2021
cdf4734
Remove unnecessary intermediate variable
adamreeve Nov 30, 2021
ba5b8ab
Fix checking for sharded plugins
adamreeve Nov 30, 2021
d2bb0ad
Don't raise an error for DDPSharded and DDPSpawnSharded with SWA
adamreeve Nov 30, 2021
2c35328
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 30, 2021
ffcf011
Fix incorrect multiple context manager syntax for Python < 3.9
adamreeve Dec 5, 2021
c278034
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Dec 6, 2021
d2fbe04
Merge branch 'master' into swa_checkpoint
adamreeve Dec 14, 2021
f13abf9
Merge branch 'master' into swa_checkpoint
adamreeve Dec 14, 2021
8e848dc
Code review tidy up and fix CHANGELOG merge error
adamreeve Dec 15, 2021
11757d5
Add a warning with initializing SWA after start but without checkpoin…
adamreeve Dec 16, 2021
50d525f
Merge branch 'master' into swa_checkpoint
adamreeve Dec 16, 2021
119f9b9
Merge branch 'master' into swa_checkpoint
adamreeve Dec 16, 2021
e332a42
Merge branch 'master' into swa_checkpoint
adamreeve Dec 21, 2021
fd59c41
Fixes to account for changes merged from master
adamreeve Dec 21, 2021
fe62b55
Merge branch 'master' into swa_checkpoint
adamreeve Dec 22, 2021
440c4b6
Merge branch 'master' into swa_checkpoint
adamreeve Jan 12, 2022
b10261e
Fix SWA scheduler not being stepped
adamreeve Jan 12, 2022
5bc9bee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2022
8b9c624
Merge branch 'master' into swa_checkpoint
adamreeve Jan 23, 2022
72c0242
Merge branch 'master' into swa_checkpoint
adamreeve Jan 31, 2022
4dfb0df
Merge branch 'master' into swa_checkpoint
awaelchli Feb 5, 2022
9b5fbfc
mark test helper protected
awaelchli Feb 5, 2022
8e0c255
avoid warning for find_unused_parameters
awaelchli Feb 5, 2022
6bb52ba
Merge branch 'master' into swa_checkpoint
adamreeve Feb 9, 2022
c44279f
Use _LRScheduler.state_dict/load_state_dict instead of accessing priv…
adamreeve Feb 10, 2022
b3eee59
Add test to reproduce crash when resuming with SWA and a custom sched…
adamreeve Feb 9, 2022
0107ff1
Prevent trying to restore scheduler state into the wrong type of sche…
adamreeve Feb 10, 2022
8067144
Merge branch 'master' into swa_checkpoint
adamreeve Feb 11, 2022
81ac195
Add test case where trainer.strategy.restore_checkpoint_after_setup i…
adamreeve Feb 14, 2022
20393b1
Minor test refactoring
carmocca Feb 15, 2022
14f9f20
Fix test_swa_resume_training_from_checkpoint[2]
carmocca Feb 15, 2022
c677141
Did not mean to remove this
carmocca Feb 15, 2022
5cf5e1b
Test tidy up from PR review comments
adamreeve Feb 15, 2022
fe79d6c
Store most recent update epoch in the SWA checkpoint data
adamreeve Feb 15, 2022
d799a62
Merge branch 'master' into swa_checkpoint
adamreeve Feb 27, 2022
c7c2818
Fix for master change that broke resuming without validation dataloaders
adamreeve Feb 27, 2022
d2ed468
Adjust SWA tests to account for current checkpoint resume behaviour
adamreeve Feb 27, 2022
a2143a8
Merge branch 'master' into swa_checkpoint
adamreeve Mar 14, 2022
00328e8
Merge branch 'master' into swa_checkpoint
adamreeve Mar 24, 2022
5dbfc2d
Merge branch 'master' into swa_checkpoint
adamreeve Mar 28, 2022
b71b690
Revert workarounds for first epoch after resume having no batches
adamreeve Mar 28, 2022
15e6334
Use state_dict/load_state_dict instead of on_save/load_checkpoint in SWA
adamreeve Mar 28, 2022
e3104bc
Remove unnecessary workaround for handling restore_checkpoint_after_s…
adamreeve Apr 20, 2022
6e9fbba
Merge branch 'master' into swa_checkpoint
adamreeve Apr 20, 2022
08eecbb
Merge branch 'master' into swa_checkpoint
krshrimali Apr 25, 2022
1e9dc33
Merge branch 'master' into swa_checkpoint
adamreeve May 17, 2022
f509178
Fix deprecation warning in tests
adamreeve May 17, 2022
0388aea
Merge branch 'master' into swa_checkpoint
adamreeve May 30, 2022
f7594d6
Merge branch 'master' into swa_checkpoint
Borda Jun 21, 2022
cb6ce90
Merge branch 'master' into swa_checkpoint
Borda Jun 27, 2022
ddcb607
Merge branch 'master' into swa_checkpoint
awaelchli Jul 25, 2022
77f137c
update runif
awaelchli Jul 25, 2022
324499e
Remove no-longer required minimum torch version from test
adamreeve Aug 2, 2022
ab8aca0
Remove redundant None check that could hide a bug
adamreeve Aug 2, 2022
7d6e7a8
Don't save scheduler configs as they will only be overridden
adamreeve Aug 2, 2022
9bf237e
Use state_dict/load_state_dict to save and load average model state
adamreeve Aug 2, 2022
a9b6334
Parametrize misconfiguration error tests
adamreeve Aug 2, 2022
c24522b
Remove DummyError and match exception message
adamreeve Aug 2, 2022
b6b7db9
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Aug 2, 2022
ba7cb5e
Fix state dict key
adamreeve Aug 2, 2022
8bde4f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2022
3ed8ea4
Type checking fixes
adamreeve Aug 2, 2022
085bb4a
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Aug 2, 2022
afba59d
Merge branch 'master' into swa_checkpoint
carmocca Aug 3, 2022
807fadf
Merge branch 'master' into swa_checkpoint
awaelchli Aug 3, 2022
15fe88e
fix changelog conflicts
awaelchli Aug 3, 2022
dcf5fea
Merge branch 'master' into swa_checkpoint
rohitgr7 Aug 9, 2022
ce9bcea
Merge branch 'master' into swa_checkpoint
awaelchli Aug 9, 2022
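
For context, a minimal sketch of the workflow this PR enables: training with the StochasticWeightAveraging callback, then resuming from a saved checkpoint so that the averaged weights, the n_averaged count, and the SWA scheduler state are restored rather than reset. The model class, hyperparameters, and checkpoint path below are illustrative and are not taken from the PR.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

# First run: with this change, checkpoints saved during training also carry the
# SWA callback state (averaged model weights, n_averaged, SWA scheduler state).
model = MyLightningModule()  # hypothetical LightningModule
swa = StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=3)
trainer = pl.Trainer(max_epochs=10, callbacks=[swa])
trainer.fit(model)

# Resumed run: loading the checkpoint restores the SWA state, so weight
# averaging continues from where it stopped instead of starting over.
swa = StochasticWeightAveraging(swa_lrs=0.05, swa_epoch_start=3)
trainer = pl.Trainer(max_epochs=10, callbacks=[swa])
trainer.fit(MyLightningModule(), ckpt_path="lightning_logs/version_0/checkpoints/last.ckpt")  # hypothetical path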
3 changes: 3 additions & 0 deletions src/pytorch_lightning/CHANGELOG.md
@@ -73,6 +73,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed incorrect `precision="mixed"` being used with `DeepSpeedStrategy` and `IPUStrategy` ([#14041](https://github.com/Lightning-AI/lightning/pull/14041))


- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/Lightning-AI/lightning/pull/9938))


- Fixed dtype inference during gradient norm computation ([#14051](https://github.com/Lightning-AI/lightning/pull/14051))


78 changes: 71 additions & 7 deletions src/pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,14 +16,15 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Any, Callable, cast, List, Optional, Union
from typing import Any, Callable, cast, Dict, List, Optional, Union

import torch
from torch import nn, Tensor
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.callback import Callback
from pytorch_lightning.strategies import DDPFullyShardedStrategy, DeepSpeedStrategy
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.types import _LRScheduler, LRSchedulerConfig
@@ -112,15 +113,22 @@ def __init__(
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self.n_averaged: Optional[torch.Tensor] = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._max_epochs: int
self._model_contains_batch_norm: bool
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: "pl.LightningModule"
self._initialized = False
self._swa_scheduler: Optional[_LRScheduler] = None
self._scheduler_state: Optional[Dict] = None
self._init_n_averaged = 0
self._latest_update_epoch = -1
self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None
self._max_epochs: int

@property
def swa_start(self) -> int:
@@ -147,6 +155,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
if len(trainer.lr_scheduler_configs) > 1:
raise MisconfigurationException("SWA currently not supported for more than 1 `lr_scheduler`.")

if isinstance(trainer.strategy, (DDPFullyShardedStrategy, DeepSpeedStrategy)):
raise MisconfigurationException("SWA does not currently support sharded models.")

if isinstance(self._swa_epoch_start, float):
self._swa_epoch_start = int(trainer.max_epochs * self._swa_epoch_start)

@@ -158,8 +169,13 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
assert trainer.fit_loop.max_epochs is not None
trainer.fit_loop.max_epochs += 1

if self._scheduler_state is not None:
self._clear_schedulers(trainer)

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if trainer.current_epoch == self.swa_start:
if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
self._initialized = True

# move average model to requested device.
self._average_model = self._average_model.to(self._device or pl_module.device)

@@ -180,6 +196,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
),
)
if self._scheduler_state is not None:
# Restore scheduler state from checkpoint
self._swa_scheduler.load_state_dict(self._scheduler_state)
elif trainer.current_epoch != self.swa_start:
# Log a warning if we're initializing after start without any checkpoint data,
# as behaviour will be different compared to having checkpoint data.
rank_zero_warn(
"SWA is initializing after swa_start without any checkpoint data. "
"This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
)

# We assert that there is only one optimizer on fit start, so we know opt_idx is always 0
default_scheduler_cfg = LRSchedulerConfig(self._swa_scheduler, opt_idx=0)
assert default_scheduler_cfg.interval == "epoch" and default_scheduler_cfg.frequency == 1
@@ -196,14 +223,18 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
else:
trainer.lr_scheduler_configs.append(default_scheduler_cfg)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
if self.n_averaged is None:
self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)

if self.swa_start <= trainer.current_epoch <= self.swa_end:
if (self.swa_start <= trainer.current_epoch <= self.swa_end) and (
trainer.current_epoch > self._latest_update_epoch
):
assert self.n_averaged is not None
self.update_parameters(self._average_model, pl_module, self.n_averaged, self._avg_fn)
self._latest_update_epoch = trainer.current_epoch

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:

# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

@@ -265,6 +296,7 @@ def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule") -> No

def reset_momenta(self) -> None:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
assert self.momenta is not None
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module]

@@ -285,3 +317,35 @@ def update_parameters(
def avg_fn(averaged_model_parameter: Tensor, model_parameter: Tensor, num_averaged: Tensor) -> Tensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
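# Illustrative note (not part of the diff): avg_fn is the standard running-mean update.
# With n = num_averaged models already folded in, the expression above is equal to
# (n * averaged_model_parameter + model_parameter) / (n + 1).
# Worked example: averaged = 2.0, new parameter = 5.0, n = 2 gives
# 2.0 + (5.0 - 2.0) / 3 = 3.0, which matches (2 * 2.0 + 5.0) / 3.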

def state_dict(self) -> Dict[str, Any]:
return {
"n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
"latest_update_epoch": self._latest_update_epoch,
"scheduler_state": None if self._swa_scheduler is None else self._swa_scheduler.state_dict(),
"average_model_state": None if self._average_model is None else self._average_model.state_dict(),
}

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
self._init_n_averaged = state_dict["n_averaged"]
self._latest_update_epoch = state_dict["latest_update_epoch"]
self._scheduler_state = state_dict["scheduler_state"]
self._load_average_model_state(state_dict["average_model_state"])

@staticmethod
def _clear_schedulers(trainer: "pl.Trainer") -> None:
# If we have scheduler state saved, clear the scheduler configs so that we don't try to
# load state into the wrong type of schedulers when restoring scheduler checkpoint state.
# We'll configure the scheduler and re-load its state in on_train_epoch_start.
# Note that this relies on the callback state being restored before the scheduler state is
# restored, and doesn't work if restore_checkpoint_after_setup is True, but at the time of
# writing that is only True for deepspeed which is already not supported by SWA.
# See https://github.com/PyTorchLightning/pytorch-lightning/issues/11665 for background.
if trainer.lr_scheduler_configs:
assert len(trainer.lr_scheduler_configs) == 1
trainer.lr_scheduler_configs.clear()

def _load_average_model_state(self, model_state: Any) -> None:
if self._average_model is None:
return
self._average_model.load_state_dict(model_state)
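
With the state_dict/load_state_dict hooks above, the SWA state travels inside the regular Lightning checkpoint. A minimal sketch of inspecting it, assuming Lightning's usual layout in which callback states are stored under the checkpoint's "callbacks" entry keyed by the callback's state key; the checkpoint path is hypothetical.

import torch

ckpt = torch.load("swa_run.ckpt", map_location="cpu")  # hypothetical checkpoint path
swa_state = ckpt["callbacks"]["StochasticWeightAveraging"]  # key assumed to be the default state_key
print(swa_state["n_averaged"])           # number of models averaged so far (0 before the first update)
print(swa_state["latest_update_epoch"])  # last epoch on which update_parameters ran (-1 if never)
print(swa_state["scheduler_state"])      # SWALR state dict, or None before the SWA scheduler is created
# swa_state["average_model_state"] is the averaged model's state_dict (or None)
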
128 changes: 121 additions & 7 deletions tests/tests_pytorch/callbacks/test_stochastic_weight_avg.py
@@ -12,11 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import ContextManager, Optional
from unittest import mock

import pytest
import torch
from torch import nn
from torch.optim.lr_scheduler import LambdaLR
from torch.optim.swa_utils import SWALR
from torch.utils.data import DataLoader

@@ -30,7 +34,9 @@


class SwaTestModel(BoringModel):
def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
def __init__(
self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_on_epoch=None
):
super().__init__()
layers = [nn.Linear(32, 32)]
if batchnorm:
@@ -39,17 +45,18 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
self.layer = nn.Sequential(*layers)
self.interval = interval
self.iterable_dataset = iterable_dataset
self.crash_on_epoch = crash_on_epoch

def training_step(self, batch, batch_idx):
if self.crash_on_epoch and self.trainer.current_epoch >= self.crash_on_epoch:
raise Exception("SWA crash test")
output = self.forward(batch)
loss = self.loss(batch, output)
return {"loss": loss}

def train_dataloader(self):

dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
dset = dset_cls(32, 64)

return DataLoader(dset, batch_size=2)

def configure_optimizers(self):
@@ -66,6 +73,8 @@ def configure_optimizers(self):
class SwaTestCallback(StochasticWeightAveraging):
update_parameters_calls: int = 0
transfer_weights_calls: int = 0
# Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0
first_epoch: Optional[int] = None

def update_parameters(self, *args, **kwargs):
self.update_parameters_calls += 1
@@ -77,6 +86,11 @@ def transfer_weights(self, *args, **kwargs):

def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
if self.first_epoch is None and not trainer.fit_loop.restarting:
# since the checkpoint loaded was saved `on_train_epoch_end`, the first `FitLoop` iteration will
# not update the model and just call the epoch-level hooks, for that reason, we check that we are not
# restarting before choosing the first epoch
self.first_epoch = trainer.current_epoch
assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
if self.swa_start <= trainer.current_epoch:
assert isinstance(trainer.lr_scheduler_configs[0].scheduler, SWALR)
@@ -88,6 +102,7 @@ def on_train_epoch_end(self, trainer, *args):
if self.swa_start <= trainer.current_epoch <= self.swa_end:
swa_epoch = trainer.current_epoch - self.swa_start
assert self.n_averaged == swa_epoch + 1
assert self._swa_scheduler is not None
# Scheduler is stepped once on initialization and then at the end of each epoch
assert self._swa_scheduler._step_count == swa_epoch + 2
elif trainer.current_epoch > self.swa_end:
@@ -103,10 +118,13 @@ def on_train_end(self, trainer, pl_module):

if not isinstance(trainer.strategy, DDPSpawnStrategy):
# check backward call count. the batchnorm update epoch should not backward
assert trainer.strategy.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
assert trainer.strategy.backward.call_count == (
(trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
)

# check call counts
assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
first_swa_epoch = max(self.first_epoch, self.swa_start)
assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
assert self.transfer_weights_calls == 1


@@ -140,7 +158,7 @@ def train_with_swa(
devices=devices,
)

with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward):
with _backward_patch(trainer):
trainer.fit(model)

# check the model is the expected
@@ -226,9 +244,10 @@ def test_swa_multiple_lrs(tmpdir):

class TestModel(BoringModel):
def __init__(self):
super(BoringModel, self).__init__()
super().__init__()
self.layer1 = torch.nn.Linear(32, 32)
self.layer2 = torch.nn.Linear(32, 2)
self.on_train_epoch_start_called = False

def forward(self, x):
x = self.layer1(x)
@@ -255,3 +274,98 @@ def on_train_epoch_start(self):
)
trainer.fit(model)
assert model.on_train_epoch_start_called


def _swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=False):
swa_start = 3
trainer_kwargs = {
"default_root_dir": tmpdir,
"max_epochs": 5,
"accelerator": "cpu",
"strategy": "ddp_spawn_find_unused_parameters_false" if ddp else None,
"devices": 2 if ddp else 1,
"limit_train_batches": 5,
"limit_val_batches": 0,
"accumulate_grad_batches": 2,
"enable_progress_bar": False,
}
trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

with _backward_patch(trainer), pytest.raises(Exception, match="SWA crash test"):
trainer.fit(model)

checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
checkpoint_files = os.listdir(checkpoint_dir)
assert len(checkpoint_files) == 1
ckpt_path = str(checkpoint_dir / checkpoint_files[0])

trainer = Trainer(callbacks=SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1), **trainer_kwargs)

with _backward_patch(trainer):
trainer.fit(resume_model, ckpt_path=ckpt_path)


class CustomSchedulerModel(SwaTestModel):
def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)

def lr_lambda(current_step: int):
return 0.1

scheduler = LambdaLR(optimizer, lr_lambda, -1)
return {
"optimizer": optimizer,
"lr_scheduler": {
"scheduler": scheduler,
"interval": self.interval,
},
}


@pytest.mark.parametrize("crash_on_epoch", [1, 3])
def test_swa_resume_training_from_checkpoint(tmpdir, crash_on_epoch):
model = SwaTestModel(crash_on_epoch=crash_on_epoch)
resume_model = SwaTestModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model)


@pytest.mark.parametrize("crash_on_epoch", [1, 3])
def test_swa_resume_training_from_checkpoint_custom_scheduler(tmpdir, crash_on_epoch):
# Reproduces the bug reported in https://github.com/PyTorchLightning/pytorch-lightning/issues/11665
model = CustomSchedulerModel(crash_on_epoch=crash_on_epoch)
resume_model = CustomSchedulerModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model)


@RunIf(skip_windows=True)
def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
model = SwaTestModel(crash_on_epoch=3)
resume_model = SwaTestModel()
_swa_resume_training_from_checkpoint(tmpdir, model, resume_model, ddp=True)


@pytest.mark.parametrize(
"strategy",
[
pytest.param("fsdp", marks=RunIf(fairscale_fully_sharded=True, min_cuda_gpus=1)),
pytest.param("deepspeed", marks=RunIf(deepspeed=True, min_cuda_gpus=1)),
],
)
def test_misconfiguration_error_with_sharded_model(tmpdir, strategy: str):
model = SwaTestModel()
swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=5,
callbacks=[swa_callback],
strategy=strategy,
accelerator="gpu",
devices=1,
)
with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
trainer.fit(model)


def _backward_patch(trainer: Trainer) -> ContextManager:
return mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward)