Support checkpoint save and load with Stochastic Weight Averaging #9938

Merged 94 commits on Aug 9, 2022

Changes from 8 commits

Commits
72d0433
Save StochasticWeightAveraging callback data in checkpoints
adamreeve Oct 14, 2021
3d2bf65
Add option to use SWA parameters during validation
adamreeve Oct 14, 2021
1696273
Allow restoring SWA parameters to a model from a checkpoint
adamreeve Oct 14, 2021
c8db9d8
Refactor SWA batch norm moment update to work with validation
adamreeve Oct 18, 2021
004959b
Add test for loading a model from a checkpoint with SWA parameters
adamreeve Oct 19, 2021
d76528b
Recompute batch norm moments when updating parameters from a checkpoint
adamreeve Oct 19, 2021
0ea22e0
Handle when data batch is a list or tuple
adamreeve Oct 20, 2021
01ca2a7
Save SWA scheduler step count in checkpoints
adamreeve Oct 27, 2021
08d655b
Update SWA documentation and changelog
adamreeve Oct 27, 2021
91ab357
Fix DeepSource code style issues
adamreeve Oct 27, 2021
22e5d51
Revert SWA validation changes
adamreeve Nov 9, 2021
ed0a7f8
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 9, 2021
11963f6
Fix resuming from epoch before SWA start and add extra test
adamreeve Nov 9, 2021
226d8aa
Don't save state derived from constructor parameters into checkpoints
adamreeve Nov 9, 2021
9ecc417
Merge branch 'master' into swa_checkpoint
tchaton Nov 15, 2021
5d03d96
Tidy ups from code review
adamreeve Nov 15, 2021
02a04da
Fix handling of n_averaged checkpoint data with multiple processes
adamreeve Nov 15, 2021
8af5b56
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 16, 2021
db9590c
Merge branch 'master' into swa_checkpoint
tchaton Nov 29, 2021
5763e05
Fix deprecation warning in test
adamreeve Nov 29, 2021
d46be83
Remove check for non-empty callback state in checkpoint
adamreeve Nov 29, 2021
e0fd0cb
Raise MisconfigurationException when using SWA with sharded models
adamreeve Nov 29, 2021
2a83f05
Fix test failure with torch 1.7
adamreeve Nov 29, 2021
4a8d81c
Fix crash when fairscale isn't installed
adamreeve Nov 29, 2021
dab0ef4
Skip segfaulting test under pytorch < 1.8
adamreeve Nov 30, 2021
a0d52c8
Changelog merge fix
adamreeve Nov 30, 2021
cdf4734
Remove unnecessary intermediate variable
adamreeve Nov 30, 2021
ba5b8ab
Fix checking for sharded plugins
adamreeve Nov 30, 2021
d2bb0ad
Don't raise an error for DDPSharded and DDPSpawnSharded with SWA
adamreeve Nov 30, 2021
2c35328
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Nov 30, 2021
ffcf011
Fix incorrect multiple context manager syntax for Python < 3.9
adamreeve Dec 5, 2021
c278034
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Dec 6, 2021
d2fbe04
Merge branch 'master' into swa_checkpoint
adamreeve Dec 14, 2021
f13abf9
Merge branch 'master' into swa_checkpoint
adamreeve Dec 14, 2021
8e848dc
Code review tidy up and fix CHANGELOG merge error
adamreeve Dec 15, 2021
11757d5
Add a warning with initializing SWA after start but without checkpoin…
adamreeve Dec 16, 2021
50d525f
Merge branch 'master' into swa_checkpoint
adamreeve Dec 16, 2021
119f9b9
Merge branch 'master' into swa_checkpoint
adamreeve Dec 16, 2021
e332a42
Merge branch 'master' into swa_checkpoint
adamreeve Dec 21, 2021
fd59c41
Fixes to account for changes merged from master
adamreeve Dec 21, 2021
fe62b55
Merge branch 'master' into swa_checkpoint
adamreeve Dec 22, 2021
440c4b6
Merge branch 'master' into swa_checkpoint
adamreeve Jan 12, 2022
b10261e
Fix SWA scheduler not being stepped
adamreeve Jan 12, 2022
5bc9bee
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2022
8b9c624
Merge branch 'master' into swa_checkpoint
adamreeve Jan 23, 2022
72c0242
Merge branch 'master' into swa_checkpoint
adamreeve Jan 31, 2022
4dfb0df
Merge branch 'master' into swa_checkpoint
awaelchli Feb 5, 2022
9b5fbfc
mark test helper protected
awaelchli Feb 5, 2022
8e0c255
avoid warning for find_unused_parameters
awaelchli Feb 5, 2022
6bb52ba
Merge branch 'master' into swa_checkpoint
adamreeve Feb 9, 2022
c44279f
Use _LRScheduler.state_dict/load_state_dict instead of accessing priv…
adamreeve Feb 10, 2022
b3eee59
Add test to reproduce crash when resuming with SWA and a custom sched…
adamreeve Feb 9, 2022
0107ff1
Prevent trying to restore scheduler state into the wrong type of sche…
adamreeve Feb 10, 2022
8067144
Merge branch 'master' into swa_checkpoint
adamreeve Feb 11, 2022
81ac195
Add test case where trainer.strategy.restore_checkpoint_after_setup i…
adamreeve Feb 14, 2022
20393b1
Minor test refactoring
carmocca Feb 15, 2022
14f9f20
Fix test_swa_resume_training_from_checkpoint[2]
carmocca Feb 15, 2022
c677141
Did not mean to remove this
carmocca Feb 15, 2022
5cf5e1b
Test tidy up from PR review comments
adamreeve Feb 15, 2022
fe79d6c
Store most recent update epoch in the SWA checkpoint data
adamreeve Feb 15, 2022
d799a62
Merge branch 'master' into swa_checkpoint
adamreeve Feb 27, 2022
c7c2818
Fix for master change that broke resuming without validation dataloaders
adamreeve Feb 27, 2022
d2ed468
Adjust SWA tests to account for current checkpoint resume behaviour
adamreeve Feb 27, 2022
a2143a8
Merge branch 'master' into swa_checkpoint
adamreeve Mar 14, 2022
00328e8
Merge branch 'master' into swa_checkpoint
adamreeve Mar 24, 2022
5dbfc2d
Merge branch 'master' into swa_checkpoint
adamreeve Mar 28, 2022
b71b690
Revert workarounds for first epoch after resume having no batches
adamreeve Mar 28, 2022
15e6334
Use state_dict/load_state_dict instead of on_save/load_checkpoint in SWA
adamreeve Mar 28, 2022
e3104bc
Remove unnecessary workaround for handling restore_checkpoint_after_s…
adamreeve Apr 20, 2022
6e9fbba
Merge branch 'master' into swa_checkpoint
adamreeve Apr 20, 2022
08eecbb
Merge branch 'master' into swa_checkpoint
krshrimali Apr 25, 2022
1e9dc33
Merge branch 'master' into swa_checkpoint
adamreeve May 17, 2022
f509178
Fix deprecation warning in tests
adamreeve May 17, 2022
0388aea
Merge branch 'master' into swa_checkpoint
adamreeve May 30, 2022
f7594d6
Merge branch 'master' into swa_checkpoint
Borda Jun 21, 2022
cb6ce90
Merge branch 'master' into swa_checkpoint
Borda Jun 27, 2022
ddcb607
Merge branch 'master' into swa_checkpoint
awaelchli Jul 25, 2022
77f137c
update runif
awaelchli Jul 25, 2022
324499e
Remove no-longer required minimum torch version from test
adamreeve Aug 2, 2022
ab8aca0
Remove redundant None check that could hide a bug
adamreeve Aug 2, 2022
7d6e7a8
Don't save scheduler configs as they will only be overridden
adamreeve Aug 2, 2022
9bf237e
Use state_dict/load_state_dict to save and load average model state
adamreeve Aug 2, 2022
a9b6334
Parametrize misconfiguration error tests
adamreeve Aug 2, 2022
c24522b
Remove DummyError and match exception message
adamreeve Aug 2, 2022
b6b7db9
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Aug 2, 2022
ba7cb5e
Fix state dict key
adamreeve Aug 2, 2022
8bde4f4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Aug 2, 2022
3ed8ea4
Type checking fixes
adamreeve Aug 2, 2022
085bb4a
Merge remote-tracking branch 'upstream/master' into swa_checkpoint
adamreeve Aug 2, 2022
afba59d
Merge branch 'master' into swa_checkpoint
carmocca Aug 3, 2022
807fadf
Merge branch 'master' into swa_checkpoint
awaelchli Aug 3, 2022
15fe88e
fix changelog conflicts
awaelchli Aug 3, 2022
dcf5fea
Merge branch 'master' into swa_checkpoint
rohitgr7 Aug 9, 2022
ce9bcea
Merge branch 'master' into swa_checkpoint
awaelchli Aug 9, 2022
244 changes: 202 additions & 42 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -16,16 +16,20 @@
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
"""
from copy import deepcopy
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, IO, List, Optional, Type, Union

import torch
from torch import nn
from torch.optim.swa_utils import SWALR
from torch.utils.data import DataLoader

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
@@ -40,6 +44,7 @@ def __init__(
annealing_strategy: str = "cos",
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
swa_validation: bool = False,
):
r"""

@@ -93,6 +98,9 @@
When None is provided, it will infer the `device` from ``pl_module``.
(default: ``"cpu"``)

swa_validation: if True, then the averaged model weights are used during validation
(default: ``False``)

"""

err_msg = "swa_epoch_start should be a >0 integer or a float between 0 and 1."
@@ -115,14 +123,21 @@
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self.n_averaged = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
self._annealing_epochs = annealing_epochs
self._annealing_strategy = annealing_strategy
self._avg_fn = avg_fn or self.avg_fn
self._swa_validation = swa_validation
self._device = device
self._model_contains_batch_norm = None
self._average_model = None
self._temp_model = None
self._initialized = False
self._swa_scheduler = None
self._batch_norm_moments = None
self._scheduler_step_count = None

@property
def swa_start(self) -> int:
@@ -140,6 +155,9 @@ def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module:
# copy the model before moving it to accelerator device.
with pl_module._prevent_trainer_and_dataloaders_deepcopy():
self._average_model = deepcopy(pl_module)
if self._swa_validation:
# Also create a model for temporarily copying weights to during validation
self._temp_model = deepcopy(pl_module)

def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
optimizers = trainer.optimizers
@@ -157,14 +175,16 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.fit_loop.max_epochs += 1

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if trainer.current_epoch == self.swa_start:
resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end)
if trainer.current_epoch == self.swa_start or resuming_after_start:
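# Initialization also runs when resuming from a checkpoint that was saved after SWA had
# already started (`resuming_after_start`): the freshly constructed callback has not been
# initialized yet even though current_epoch is already past swa_start.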
self._initialized = True

# move average model to requested device.
self._average_model = self._average_model.to(self._device or pl_module.device)
if self._temp_model:
self._temp_model = self._temp_model.to(self._device or pl_module.device)

optimizer = trainer.optimizers[0]
if self._swa_lrs is None:
@@ -182,6 +202,9 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
anneal_strategy=self._annealing_strategy,
last_epoch=trainer.max_epochs if self._annealing_strategy == "cos" else -1,
)
if self._scheduler_step_count is not None:
# Restore scheduler step count from checkpoint
self._swa_scheduler._step_count = self._scheduler_step_count
default_scheduler_cfg = _get_default_scheduler_config()
assert default_scheduler_cfg["interval"] == "epoch" and default_scheduler_cfg["frequency"] == 1
default_scheduler_cfg["scheduler"] = self._swa_scheduler
@@ -198,68 +221,97 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
else:
trainer.lr_schedulers.append(default_scheduler_cfg)

self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
if self.n_averaged is None:
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)

if self.swa_start <= trainer.current_epoch <= self.swa_end:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:

# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

# Reset BatchNorm for update
self.reset_batch_norm_and_save_state(pl_module)

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
# Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
trainer.num_training_batches += 1
trainer.fit_loop._skip_backward = True
self._accumulate_grad_batches = trainer.accumulate_grad_batches

trainer.accumulate_grad_batches = trainer.num_training_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
trainer.fit_loop._skip_backward = False

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch == self.swa_end:
if trainer.current_epoch == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)
if self._model_contains_batch_norm:
self._update_batch_norm_moments(trainer, pl_module, store_moments=False)

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
# Take a temporary copy of the model parameters
self.transfer_weights(pl_module, self._temp_model)
# Update the model with the averaged parameters
self.transfer_weights(self._average_model, pl_module)
if self._model_contains_batch_norm:
self._update_batch_norm_moments(trainer, pl_module)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
# Copy original model parameters back
self.transfer_weights(self._temp_model, pl_module)
if self._model_contains_batch_norm:
self._restore_batch_norm_moments()

@staticmethod
def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))

def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
self.momenta = {}
def _update_batch_norm_moments(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True
):
self._batch_norm_moments = {}

train_dataloader = trainer.train_dataloader
if train_dataloader is None:
# Training data not yet connected, could be in a validation sanity check
return

self._update_module_batch_norm_moments(
train_dataloader, pl_module, self._batch_norm_moments if store_moments else None
)

@staticmethod
def _update_module_batch_norm_moments(
data_loader: Union[DataLoader, CombinedLoader],
pl_module: "pl.LightningModule",
moment_cache: Optional[Dict[nn.Module, Any]] = None,
):
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166."""
prev_momenta = {}

was_training = pl_module.training
pl_module.train()

for module in pl_module.modules():
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
continue
prev_momenta[module] = module.momentum
if moment_cache is not None:
moment_cache[module] = (module.running_mean, module.running_var)
module.running_mean = torch.zeros_like(
module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
)
module.running_var = torch.ones_like(
module.running_var, device=pl_module.device, dtype=module.running_var.dtype
)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0

def reset_momenta(self):
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module]
# Recompute mean and variance for all batch norm layers by doing a full pass over the training data
for batch in data_loader:
if isinstance(batch, (list, tuple)):
batch = batch[0]
batch = batch.to(pl_module.device)
pl_module(batch)

# Reset model state
for bn_module, momenta in prev_momenta.items():
bn_module.momentum = momenta
pl_module.train(was_training)

def _restore_batch_norm_moments(self):
for bn_module, (mean, variance) in self._batch_norm_moments.items():
bn_module.running_mean = mean
bn_module.running_var = variance

@staticmethod
def update_parameters(
@@ -280,3 +332,111 @@ def avg_fn(
) -> torch.FloatTensor:
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L95-L97."""
return averaged_model_parameter + (model_parameter - averaged_model_parameter) / (num_averaged + 1)
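# This is the standard incremental (running) mean update, avg_{n+1} = avg_n + (x_{n+1} - avg_n) / (n + 1),
# so after N updates the averaged parameter equals the arithmetic mean of the N contributing weight snapshots.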

def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> dict:
checkpoint_data = {
"n_averaged": self.n_averaged,
"swa_lrs": self._swa_lrs,
"annealing_epochs": self._annealing_epochs,
"annealing_strategy": self._annealing_strategy,
"scheduler_step_count": None if self._swa_scheduler is None else self._swa_scheduler._step_count,
"average_model_parameters": self._get_average_model_parameters(trainer),
}
return checkpoint_data

def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
) -> None:
if callback_state:
self.n_averaged = callback_state["n_averaged"]
self._swa_lrs = callback_state["swa_lrs"]
self._annealing_strategy = callback_state["annealing_strategy"]
self._annealing_epochs = callback_state["annealing_epochs"]
self._scheduler_step_count = callback_state["scheduler_step_count"]
self._load_average_model_parameters(callback_state["average_model_parameters"])
else:
rank_zero_warn(
f"Checkpoint has no data for the {self.state_key} callback, not initializing the callback state."
)

@classmethod
def restore_average_parameters_from_checkpoint(
cls,
pl_module: "pl.LightningModule",
checkpoint_path: Union[str, IO],
map_location: Optional[Union[Dict[str, str], str, torch.device, int, Callable]] = None,
datamodule: Optional[LightningDataModule] = None,
) -> bool:
r"""
Set model weights to the SWA averaged weights saved in a checkpoint.

Arguments:
pl_module: The module to set weights on

checkpoint_path: Path to checkpoint. This can also be a URL, or file-like object

map_location: If your checkpoint saved a GPU model and you now load on CPUs
or a different number of GPUs, use this to map to the new setup.
The behaviour is the same as in :func:`torch.load`.

datamodule: If the module uses batch normalization and does not implement the train_dataloader method,
a data module must be provided in order to allow recomputing the batch normalization parameters after
loading the SWA weights.

Return:
A `bool` indicating whether averaged weights were loaded. If `False`, this means the checkpoint is
from an epoch before the SWA epoch start.
"""
if map_location is not None:
checkpoint = pl_load(checkpoint_path, map_location=map_location)
else:
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
callback_states: Dict[Union[Type, str], Dict] = checkpoint.get("callbacks")
if not callback_states:
raise ValueError("callback states are not present in the checkpoint")

state_key = cls.__qualname__ # Default state key defined in Callback base class
state = callback_states.get(state_key)
if not state:
raise ValueError(f"no {state_key} state found in the checkpoint")
state = deepcopy(state)
average_model_parameters = state["average_model_parameters"]

if not average_model_parameters:
return False

for p_model, p_swa in zip(pl_module.parameters(), average_model_parameters):
device = p_model.device
p_swa_ = p_swa.detach().to(device)
p_model.detach().copy_(p_swa_)

if cls.pl_module_contains_batch_norm(pl_module):
if datamodule is not None:
train_dataloaders = datamodule.train_dataloader()
else:
train_dataloaders = pl_module.train_dataloader()
train_dataloaders = CombinedLoader(train_dataloaders, mode="max_size_cycle")
cls._update_module_batch_norm_moments(train_dataloaders, pl_module)

return True

def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any:
if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end):
# If we're not within the SWA epochs then when loading checkpoint data we would want
# to use parameters from the underlying model rather than the SWA parameters.
return None
parameters = []
for p_swa in self._average_model.parameters():
parameters.append(p_swa.detach())
return parameters

def _load_average_model_parameters(self, parameter_state: Any):
if self._average_model is None:
return
for p_swa, p_checkpoint in zip(self._average_model.parameters(), parameter_state):
device = p_swa.device
p_swa_ = p_swa.detach()
p_checkpoint_ = p_checkpoint.detach().to(device)
p_swa_.copy_(p_checkpoint_)
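
For context, a minimal sketch of how the API at this revision of the diff (changes from 8 commits) could be used. `MyLightningModule`, `MyDataModule`, the checkpoint path, and the hyperparameter values are placeholders, and later commits in this PR (for example "Revert SWA validation changes") modify parts of this API, so the final merged behaviour may differ.

import pytorch_lightning as pl
from pytorch_lightning.callbacks import StochasticWeightAveraging

# MyLightningModule / MyDataModule stand in for any LightningModule / LightningDataModule.
model = MyLightningModule()
datamodule = MyDataModule()

# Start averaging at 75% of training; with this PR the callback state (n_averaged,
# SWA scheduler step count and the averaged parameters) is written into checkpoints,
# so an interrupted run can resume without losing the running average.
swa = StochasticWeightAveraging(swa_epoch_start=0.75, swa_lrs=0.05, swa_validation=True)
trainer = pl.Trainer(max_epochs=20, callbacks=[swa])
trainer.fit(model, datamodule=datamodule)

# Later, copy the averaged SWA weights stored in a checkpoint into a fresh module.
# Returns False when the checkpoint was saved before the SWA start epoch.
restored = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(
    model, "path/to/checkpoint.ckpt", datamodule=datamodule
)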