Support checkpoint save and load with Stochastic Weight Averaging #9938

Merged (94 commits) on Aug 9, 2022
The diff below shows changes from 38 of the 94 commits.
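
This PR makes the StochasticWeightAveraging callback save its state (the n_averaged counter, the SWA scheduler step count, and the averaged model parameters) into checkpoints and restore it when training resumes, so a run that uses SWA can be interrupted and continued without losing the running average. Below is a minimal usage sketch, not taken from the PR itself: TinyModel and the checkpoint path are illustrative placeholders, and the callback arguments mirror the tests added further down.

import torch
from torch.utils.data import DataLoader, TensorDataset

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging


class TinyModel(pl.LightningModule):
    """Illustrative stand-in for the BoringModel used in the tests below."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        return self.layer(x).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


train_data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)

# First run: SWA starts at epoch 3 and Lightning writes checkpoints as usual.
trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[StochasticWeightAveraging(swa_epoch_start=3, swa_lrs=0.1)],
)
trainer.fit(TinyModel(), train_data)

# Resuming from a saved checkpoint: with this change the SWA callback state
# (n_averaged, scheduler step count, averaged parameters) is restored as well.
trainer = pl.Trainer(
    max_epochs=5,
    callbacks=[StochasticWeightAveraging(swa_epoch_start=3, swa_lrs=0.1)],
)
trainer.fit(TinyModel(), train_data, ckpt_path="path/to/saved.ckpt")
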
Commits
72d0433
Save StochasticWeightAveraging callback data in checkpoints
adamreeve Oct 14, 2021
3d2bf65
Add option to use SWA parameters during validation
adamreeve Oct 14, 2021
1696273
Allow restoring SWA parameters to a model from a checkpoint
adamreeve Oct 14, 2021
c8db9d8
Refactor SWA batch norm moment update to work with validation
adamreeve Oct 18, 2021
004959b
Add test for loading a model from a checkpoint with SWA parameters
adamreeve Oct 19, 2021
d76528b
Recompute batch norm moments when updating parameters from a checkpoint
adamreeve Oct 19, 2021
0ea22e0
Handle when data batch is a list or tuple
adamreeve Oct 20, 2021
01ca2a7
Save SWA scheduler step count in checkpoints
adamreeve Oct 27, 2021
08d655b
Update SWA documentation and changelog
adamreeve Oct 27, 2021
91ab357
Fix DeepSource code style issues
adamreeve Oct 27, 2021
22e5d51
Revert SWA validation changes
adamreeve Nov 9, 2021
ed0a7f8
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 9, 2021
11963f6
Fix resuming from epoch before SWA start and add extra test
adamreeve Nov 9, 2021
226d8aa
Don't save state derived from constructor parameters into checkpoints
adamreeve Nov 9, 2021
9ecc417
Merge branch 'master' into swa_checkpoint
tchaton Nov 15, 2021
5d03d96
Tidy ups from code review
adamreeve Nov 15, 2021
02a04da
Fix handling of n_averaged checkpoint data with multiple processes
adamreeve Nov 15, 2021
8af5b56
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 16, 2021
db9590c
Merge branch 'master' into swa_checkpoint
tchaton Nov 29, 2021
5763e05
Fix deprecation warning in test
adamreeve Nov 29, 2021
d46be83
Remove check for non-empty callback state in checkpoint
adamreeve Nov 29, 2021
e0fd0cb
Raise MisconfigurationException when using SWA with sharded models
adamreeve Nov 29, 2021
2a83f05
Fix test failure with torch 1.7
adamreeve Nov 29, 2021
4a8d81c
Fix crash when fairscale isn't installed
adamreeve Nov 29, 2021
dab0ef4
Skip segfaulting test under pytorch < 1.8
adamreeve Nov 30, 2021
a0d52c8
Changelog merge fix
adamreeve Nov 30, 2021
cdf4734
Remove unnecessary intermediate variable
adamreeve Nov 30, 2021
ba5b8ab
Fix checking for sharded plugins
adamreeve Nov 30, 2021
d2bb0ad
Don't raise an error for DDPSharded and DDPSpawnSharded with SWA
adamreeve Nov 30, 2021
2c35328
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 30, 2021
ffcf011
Fix incorrect multiple context manager syntax for Python < 3.9
adamreeve Dec 5, 2021
c278034
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Dec 6, 2021
d2fbe04
Merge branch 'master' into swa_checkpoint
adamreeve Dec 14, 2021
f13abf9
Merge branch 'master' into swa_checkpoint
adamreeve Dec 14, 2021
8e848dc
Code review tidy up and fix CHANGELOG merge error
adamreeve Dec 15, 2021
11757d5
Add a warning with initializing SWA after start but without checkpoin…
adamreeve Dec 16, 2021
50d525f
Merge branch 'master' into swa_checkpoint
adamreeve Dec 16, 2021
119f9b9
Merge branch 'master' into swa_checkpoint
adamreeve Dec 16, 2021
e332a42
Merge branch 'master' into swa_checkpoint
adamreeve Dec 21, 2021
fd59c41
Fixes to account for changes merged from master
adamreeve Dec 21, 2021
fe62b55
Merge branch 'master' into swa_checkpoint
adamreeve Dec 22, 2021
440c4b6
Merge branch 'master' into swa_checkpoint
adamreeve Jan 12, 2022
b10261e
Fix SWA scheduler not being stepped
adamreeve Jan 12, 2022
5bc9bee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2022
8b9c624
Merge branch 'master' into swa_checkpoint
adamreeve Jan 23, 2022
72c0242
Merge branch 'master' into swa_checkpoint
adamreeve Jan 31, 2022
4dfb0df
Merge branch 'master' into swa_checkpoint
awaelchli Feb 5, 2022
9b5fbfc
mark test helper protected
awaelchli Feb 5, 2022
8e0c255
avoid warning for find_unused_parameters
awaelchli Feb 5, 2022
6bb52ba
Merge branch 'master' into swa_checkpoint
adamreeve Feb 9, 2022
c44279f
Use _LRScheduler.state_dict/load_state_dict instead of accessing priv…
adamreeve Feb 10, 2022
b3eee59
Add test to reproduce crash when resuming with SWA and a custom sched…
adamreeve Feb 9, 2022
0107ff1
Prevent trying to restore scheduler state into the wrong type of sche…
adamreeve Feb 10, 2022
8067144
Merge branch 'master' into swa_checkpoint
adamreeve Feb 11, 2022
81ac195
Add test case where trainer.strategy.restore_checkpoint_after_setup i…
adamreeve Feb 14, 2022
20393b1
Minor test refactoring
carmocca Feb 15, 2022
14f9f20
Fix test_swa_resume_training_from_checkpoint[2]
carmocca Feb 15, 2022
c677141
Did not mean to remove this
carmocca Feb 15, 2022
5cf5e1b
Test tidy up from PR review comments
adamreeve Feb 15, 2022
fe79d6c
Store most recent update epoch in the SWA checkpoint data
adamreeve Feb 15, 2022
d799a62
Merge branch 'master' into swa_checkpoint
adamreeve Feb 27, 2022
c7c2818
Fix for master change that broke resuming without validation dataloaders
adamreeve Feb 27, 2022
d2ed468
Adjust SWA tests to account for current checkpoint resume behaviour
adamreeve Feb 27, 2022
a2143a8
Merge branch 'master' into swa_checkpoint
adamreeve Mar 14, 2022
00328e8
Merge branch 'master' into swa_checkpoint
adamreeve Mar 24, 2022
5dbfc2d
Merge branch 'master' into swa_checkpoint
adamreeve Mar 28, 2022
b71b690
Revert workarounds for first epoch after resume having no batches
adamreeve Mar 28, 2022
15e6334
Use state_dict/load_state_dict instead of on_save/load_checkpoint in SWA
adamreeve Mar 28, 2022
e3104bc
Remove unnecessary workaround for handling restore_checkpoint_after_s…
adamreeve Apr 20, 2022
6e9fbba
Merge branch 'master' into swa_checkpoint
adamreeve Apr 20, 2022
08eecbb
Merge branch 'master' into swa_checkpoint
krshrimali Apr 25, 2022
1e9dc33
Merge branch 'master' into swa_checkpoint
adamreeve May 17, 2022
f509178
Fix deprecation warning in tests
adamreeve May 17, 2022
0388aea
Merge branch 'master' into swa_checkpoint
adamreeve May 30, 2022
f7594d6
Merge branch 'master' into swa_checkpoint
Borda Jun 21, 2022
cb6ce90
Merge branch 'master' into swa_checkpoint
Borda Jun 27, 2022
ddcb607
Merge branch 'master' into swa_checkpoint
awaelchli Jul 25, 2022
77f137c
update runif
awaelchli Jul 25, 2022
324499e
Remove no-longer required minimum torch version from test
adamreeve Aug 2, 2022
ab8aca0
Remove redundant None check that could hide a bug
adamreeve Aug 2, 2022
7d6e7a8
Don't save scheduler configs as they will only be overridden
adamreeve Aug 2, 2022
9bf237e
Use state_dict/load_state_dict to save and load average model state
adamreeve Aug 2, 2022
a9b6334
Parametrize misconfiguration error tests
adamreeve Aug 2, 2022
c24522b
Remove DummyError and match exception message
adamreeve Aug 2, 2022
b6b7db9
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Aug 2, 2022
ba7cb5e
Fix state dict key
adamreeve Aug 2, 2022
8bde4f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2022
3ed8ea4
Type checking fixes
adamreeve Aug 2, 2022
085bb4a
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Aug 2, 2022
afba59d
Merge branch 'master' into swa_checkpoint
carmocca Aug 3, 2022
807fadf
Merge branch 'master' into swa_checkpoint
awaelchli Aug 3, 2022
15fe88e
fix changelog conflicts
awaelchli Aug 3, 2022
dcf5fea
Merge branch 'master' into swa_checkpoint
rohitgr7 Aug 9, 2022
ce9bcea
Merge branch 'master' into swa_checkpoint
awaelchli Aug 9, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -280,6 +280,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed support for logging within callbacks returned from `LightningModule` ([#10991](https://github.com/PyTorchLightning/pytorch-lightning/pull/10991))


- Fixed resuming from a checkpoint when using Stochastic Weight Averaging (SWA) ([#9938](https://github.com/PyTorchLightning/pytorch-lightning/pull/9938))


- The TQDM progress bar now correctly shows the `on_epoch` logged values on train epoch end ([#11069](https://github.com/PyTorchLightning/pytorch-lightning/pull/11069))


66 changes: 60 additions & 6 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,14 +16,15 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from torch import nn
from torch.optim.swa_utils import SWALR

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.plugins.training_type import DDPFullyShardedPlugin, DeepSpeedPlugin
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -115,14 +116,20 @@ def __init__(
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self.n_averaged: Optional[torch.Tensor] = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._device = device
self._model_contains_batch_norm = None
self._average_model = None
self._model_contains_batch_norm: Optional[bool] = None
self._average_model: Optional[pl.LightningModule] = None
self._initialized = False
self._swa_scheduler: Optional[SWALR] = None
self._scheduler_step_count: Optional[int] = None
self._init_n_averaged = 0
self.momenta: Optional[Dict[nn.modules.batchnorm._BatchNorm, float]] = None

@property
def swa_start(self) -> int:
@@ -145,6 +152,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
optimizers = trainer.optimizers
lr_schedulers = trainer.lr_schedulers

if isinstance(trainer.training_type_plugin, (DDPFullyShardedPlugin, DeepSpeedPlugin)):
raise MisconfigurationException("SWA does not currently support sharded models.")

if len(optimizers) != 1:
raise MisconfigurationException("SWA currently works with 1 `optimizer`.")

@@ -162,7 +172,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
trainer.fit_loop.max_epochs += 1

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if trainer.current_epoch == self.swa_start:
if (not self._initialized) and (self.swa_start <= trainer.current_epoch <= self.swa_end):
self._initialized = True

# move average model to requested device.
self._average_model = self._average_model.to(self._device or pl_module.device)

@@ -182,6 +194,17 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
)
if self._scheduler_step_count is not None:
# Restore scheduler step count from checkpoint
self._swa_scheduler._step_count = self._scheduler_step_count
elif trainer.current_epoch != self.swa_start:
# Log a warning if we're initializing after start without any checkpoint data,
# as behaviour will be different compared to having checkpoint data.
rank_zero_warn(
"SWA is initializing after swa_start without any checkpoint data. "
"This may be caused by loading a checkpoint from an older version of PyTorch Lightning."
)

default_scheduler_cfg = _get_default_scheduler_config()
assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
default_scheduler_cfg["scheduler"] = self._swa_scheduler
@@ -198,14 +221,14 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
else:
trainer.lr_schedulers.append(default_scheduler_cfg)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
if self.n_averaged is None:
self.n_averaged = torch.tensor(self._init_n_averaged, dtype=torch.long, device=pl_module.device)

if self.swa_start <= trainer.current_epoch <= self.swa_end:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:

# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

@@ -280,3 +303,34 @@ def avg_fn(
) -> torch.FloatTensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)

def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> dict:
return {
"n_averaged": 0 if self.n_averaged is None else self.n_averaged.item(),
"scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count,
"average_model_parameters": self._get_average_model_parameters(trainer),
}

def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
) -> None:
self._init_n_averaged = callback_state["n_averaged"]
self._scheduler_step_count = callback_state["scheduler_step_count"]
self._load_average_model_parameters(callback_state["average_model_parameters"])

def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Optional[List[nn.Parameter]]:
if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end):
# If we're not within the SWA epochs then when loading checkpoint data we would want
# to use parameters from the underlying model rather than the SWA parameters.
return
return list(self._average_model.parameters())

def _load_average_model_parameters(self, parameter_state: Any) -> None:
if self._average_model is None or parameter_state is None:
return
for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state):
device = p_swa.device
p_swa_ = p_swa.detach()
p_swa_.copy_(p_checkpoint.to(device))
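
To make the new hooks concrete, here is a small sketch (not part of the diff) that inspects the SWA entry written by on_save_checkpoint into a saved checkpoint. The path is a placeholder, and the exact key used under checkpoint["callbacks"] varies between Lightning versions, so the string match is an assumption.

import torch

# Placeholder path: any checkpoint written after the SWA phase has started will do.
checkpoint = torch.load("path/to/epoch=3-step=20.ckpt", map_location="cpu")

# Lightning keeps callback state under the "callbacks" key; the SWA entry holds the
# three fields returned by on_save_checkpoint above.
for key, state in checkpoint.get("callbacks", {}).items():
    if "StochasticWeightAveraging" in str(key):
        print("n_averaged:", state["n_averaged"])
        print("scheduler_step_count:", state["scheduler_step_count"])
        print("has averaged parameters:", state["average_model_parameters"] is not None)
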
120 changes: 116 additions & 4 deletions tests/callbacks/test_stochastic_weight_avg.py
@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from pathlib import Path
from typing import Optional
from unittest import mock

import pytest
@@ -30,7 +33,9 @@


class SwaTestModel(BoringModel):
def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False):
def __init__(
self, batchnorm: bool = True, interval: str = "epoch", iterable_dataset: bool = False, crash_after_epoch=None
):
super().__init__()
layers = [nn.Linear(32, 32)]
if batchnorm:
@@ -39,6 +44,9 @@ def __init__(self, batchnorm: bool = True, interval: str = "epoch", iterable_dat
self.layer = nn.Sequential(*layers)
self.interval = interval
self.iterable_dataset = iterable_dataset
self.crash_after_epoch = crash_after_epoch
self._epoch_count = 0
self.save_hyperparameters()

def training_step(self, batch, batch_idx):
output = self.forward(batch)
@@ -62,10 +70,19 @@ def configure_optimizers(self):
},
}

def training_epoch_end(self, _):
if not self.crash_after_epoch:
return
self._epoch_count += 1
if self._epoch_count >= self.crash_after_epoch:
raise RuntimeError("Crash test")


class SwaTestCallback(StochasticWeightAveraging):
update_parameters_calls: int = 0
transfer_weights_calls: int = 0
# Record the first epoch, as if we are resuming from a checkpoint this may not be equal to 0
first_epoch: Optional[int] = None

def update_parameters(self, *args, **kwargs):
self.update_parameters_calls += 1
@@ -77,6 +94,8 @@ def transfer_weights(self, *args, **kwargs):

def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
if self.first_epoch is None:
self.first_epoch = trainer.current_epoch
assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
if self.swa_start <= trainer.current_epoch:
assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
@@ -88,6 +107,9 @@ def on_train_epoch_end(self, trainer, *args):
if self.swa_start <= trainer.current_epoch <= self.swa_end:
swa_epoch = trainer.current_epoch - self.swa_start
assert self.n_averaged == swa_epoch + 1
assert self._swa_scheduler is not None
# Scheduler is stepped once on initialization and then at the end of each epoch
assert self._swa_scheduler._step_count == swa_epoch + 2
elif trainer.current_epoch > self.swa_end:
assert self.n_averaged == self._max_epochs - self.swa_start

@@ -101,10 +123,13 @@ def on_train_end(self, trainer, pl_module):

if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin):
# check backward call count. the batchnorm update epoch should not backward
assert trainer.training_type_plugin.backward.call_count == trainer.max_epochs * trainer.limit_train_batches
assert trainer.training_type_plugin.backward.call_count == (
(trainer.max_epochs - self.first_epoch) * trainer.limit_train_batches
)

# check call counts
assert self.update_parameters_calls == trainer.max_epochs - (self._swa_epoch_start - 1)
first_swa_epoch = max(self.first_epoch, self.swa_start)
assert self.update_parameters_calls == trainer.max_epochs - first_swa_epoch
assert self.transfer_weights_calls == 1


@@ -247,9 +272,10 @@ def test_swa_multiple_lrs(tmpdir):

class TestModel(BoringModel):
def __init__(self):
super(BoringModel, self).__init__()
super().__init__()
self.layer1 = torch.nn.Linear(32, 32)
self.layer2 = torch.nn.Linear(32, 2)
self.on_train_epoch_start_called = False

def forward(self, x):
x = self.layer1(x)
@@ -276,3 +302,89 @@ def on_train_epoch_start(self):
)
trainer.fit(model)
assert model.on_train_epoch_start_called


def swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=4, ddp=False):
model = SwaTestModel(crash_after_epoch=crash_after_epoch)
swa_start = 3
max_epochs = 5
swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)

num_processes = 2 if ddp else 1
strategy = "ddp_spawn" if ddp else None

trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=max_epochs,
limit_train_batches=5,
limit_val_batches=0,
callbacks=[swa_callback],
accumulate_grad_batches=2,
num_processes=num_processes,
strategy=strategy,
)

exception_type = Exception if ddp else RuntimeError
backward_patch = mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward)
with backward_patch, pytest.raises(exception_type):
trainer.fit(model)

checkpoint_dir = Path(tmpdir) / "lightning_logs" / "version_0" / "checkpoints"
checkpoint_files = os.listdir(checkpoint_dir)
assert len(checkpoint_files) == 1
checkpoint_path = checkpoint_dir / checkpoint_files[0]

model = SwaTestModel()
swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=max_epochs,
limit_train_batches=5,
limit_val_batches=0,
callbacks=[swa_callback],
accumulate_grad_batches=2,
num_processes=num_processes,
strategy=strategy,
)

with mock.patch.object(TrainingTypePlugin, "backward", wraps=trainer.training_type_plugin.backward):
trainer.fit(model, ckpt_path=checkpoint_path.as_posix())


@pytest.mark.parametrize("crash_after_epoch", [2, 4])
def test_swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch):
swa_resume_training_from_checkpoint(tmpdir, crash_after_epoch=crash_after_epoch)


@RunIf(skip_windows=True, min_torch="1.8")
def test_swa_resume_training_from_checkpoint_ddp(tmpdir):
# Requires PyTorch >= 1.8 to include this segfault fix:
# https://github.com/pytorch/pytorch/pull/50998
swa_resume_training_from_checkpoint(tmpdir, ddp=True)


def _test_misconfiguration_error_with_sharded_model(tmpdir, strategy, gpus=None):
model = SwaTestModel()
swa_callback = SwaTestCallback(swa_epoch_start=2, swa_lrs=0.1)
trainer = Trainer(
default_root_dir=tmpdir,
enable_progress_bar=False,
max_epochs=5,
callbacks=[swa_callback],
strategy=strategy,
gpus=gpus,
)
with pytest.raises(MisconfigurationException, match="SWA does not currently support sharded models"):
trainer.fit(model)


@RunIf(fairscale_fully_sharded=True, min_gpus=1)
def test_misconfiguration_error_with_ddp_fully_sharded(tmpdir):
_test_misconfiguration_error_with_sharded_model(tmpdir, "fsdp", 1)


@RunIf(deepspeed=True)
def test_misconfiguration_error_with_deep_speed(tmpdir):
_test_misconfiguration_error_with_sharded_model(tmpdir, "deepspeed")
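
Finally, a short sketch (again not from the PR) of the end state: because transfer_weights copies the averaged parameters back into the LightningModule once the SWA phase ends, a checkpoint written after that point already stores the SWA weights in the ordinary model state_dict, so no special handling is needed at inference time. The path below is a placeholder.

import torch

# Placeholder path to a checkpoint saved after training completed.
final_checkpoint = torch.load("path/to/final.ckpt", map_location="cpu")

# Standard Lightning checkpoint layout: the model's own state_dict, which at this
# point already holds the SWA-averaged parameters.
swa_state_dict = final_checkpoint["state_dict"]
print(list(swa_state_dict)[:3])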