Add test for loading a model from a checkpoint with SWA parameters
adamreeve committed Oct 19, 2021
1 parent c8db9d8 commit 004959b
Showing 2 changed files with 54 additions and 7 deletions.
15 changes: 11 additions & 4 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -255,6 +255,11 @@ def _update_batch_norm_moments(
         prev_momenta = {}
         self._batch_norm_moments = {}

+        train_data_fetcher = trainer.data_connector.train_data_fetcher
+        if train_data_fetcher is None:
+            # Training data not yet connected, could be in a validation sanity check
+            return
+
         was_training = pl_module.training
         pl_module.train()

@@ -274,7 +279,7 @@ def _update_batch_norm_moments(
             module.num_batches_tracked *= 0

         # Recompute mean and variance for all batch norm layers by doing a full pass over the training data
-        for batch, _ in trainer.data_connector.train_data_fetcher:
+        for batch, _ in train_data_fetcher:
            batch = batch.to(pl_module.device)
            pl_module(batch)

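For context, the recomputation in this hunk follows the standard SWA recipe (as in torch.optim.swa_utils.update_bn): reset each batch-norm layer's running statistics, switch it to a cumulative moving average, and run one full pass over the training data in train mode. Below is a minimal standalone sketch of that pattern in plain PyTorch; it is independent of the Lightning callback, and the function name and data loader are illustrative only.

import torch
from torch import nn


def recompute_batch_norm_moments(model: nn.Module, train_loader) -> None:
    """Reset and recompute BatchNorm running statistics with one full pass over the data."""
    bn_layers = [m for m in model.modules() if isinstance(m, nn.modules.batchnorm._BatchNorm)]
    if not bn_layers:
        return

    was_training = model.training
    model.train()

    prev_momenta = {}
    for module in bn_layers:
        prev_momenta[module] = module.momentum
        module.reset_running_stats()  # zero running_mean/running_var and num_batches_tracked
        module.momentum = None  # momentum=None means a cumulative moving average over all batches

    with torch.no_grad():
        for batch, _ in train_loader:
            model(batch)

    # Restore the original momenta and training mode
    for module, momentum in prev_momenta.items():
        module.momentum = momentum
    model.train(was_training)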
@@ -316,7 +321,7 @@ def on_save_checkpoint(
             "swa_lrs": self._swa_lrs,
             "annealing_epochs": self._annealing_epochs,
             "annealing_strategy": self._annealing_strategy,
-            "average_model_parameters": self._get_average_model_parameters(),
+            "average_model_parameters": self._get_average_model_parameters(trainer),
         }
         return checkpoint_data

@@ -380,8 +385,10 @@ def restore_average_parameters_from_checkpoint(
                p_model.detach().copy_(p_swa_)
        return True

-    def _get_average_model_parameters(self) -> Any:
-        if self._average_model is None:
+    def _get_average_model_parameters(self, trainer: "pl.Trainer") -> Any:
+        if self._average_model is None or not (self.swa_start <= trainer.current_epoch <= self.swa_end):
+            # If we're not within the SWA epochs then when loading checkpoint data we would want
+            # to use parameters from the underlying model rather than the SWA parameters.
            return None
        parameters = []
        for p_swa in self._average_model.parameters():
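The effect of passing trainer here is that a checkpoint saved outside the SWA epoch window stores None for "average_model_parameters", so whatever loads the checkpoint falls back to the model's own weights. A rough sketch of that consumer-side check follows; it assumes the callback state lives in the checkpoint's callback section under a key named after the callback, which is an assumption for illustration rather than the layout taken from this diff.

import torch


def restore_swa_parameters(model, checkpoint_path, callback_key="StochasticWeightAveraging"):
    """Copy averaged SWA parameters from a checkpoint into `model`, if any were saved."""
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    # Assumed layout: callback state stored under checkpoint["callbacks"][callback_key]
    callback_state = checkpoint.get("callbacks", {}).get(callback_key, {})
    average_parameters = callback_state.get("average_model_parameters")
    if average_parameters is None:
        # Checkpoint was written outside the SWA epochs: keep the plain model weights.
        return False
    for p_model, p_swa in zip(model.parameters(), average_parameters):
        p_model.detach().copy_(p_swa.detach().to(p_model.device))
    return True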
46 changes: 43 additions & 3 deletions tests/callbacks/test_stochastic_weight_avg.py
@@ -24,7 +24,7 @@

 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.accelerators import Accelerator
-from pytorch_lightning.callbacks import StochasticWeightAveraging
+from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging
 from pytorch_lightning.plugins import DDPSpawnPlugin
 from pytorch_lightning.trainer.connectors.data_connector import _PatchDataLoader
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -46,6 +46,7 @@ def __init__(
         self.iterable_dataset = iterable_dataset
         self.crash_after_epoch = crash_after_epoch
         self._epoch_count = 0
+        self.save_hyperparameters()

     def training_step(self, batch, batch_idx):
         output = self.forward(batch)
@@ -55,6 +56,7 @@ def training_step(self, batch, batch_idx):
     def validation_step(self, batch, batch_idx):
         output = self.forward(batch)
         loss = self.loss(batch, output)
+        self.log("val_loss", loss)
         return {"x": loss}

     def train_dataloader(self):
@@ -142,7 +144,7 @@ def on_train_end(self, trainer, pl_module):
         assert self.update_parameters_calls == expected_update_calls
         if self._swa_validation:
             # 3 weight transfers are needed per SWA validation step
-            assert self.transfer_weights_calls == (self.validation_calls - self._swa_epoch_start) * 3 + 1
+            assert self.transfer_weights_calls == (self.validation_calls - self.swa_start) * 3 + 1
         else:
             assert self.transfer_weights_calls == 1

@@ -169,7 +171,8 @@ def train_with_swa(
         enable_progress_bar=False,
         max_epochs=max_epochs,
         limit_train_batches=5,
-        limit_val_batches=1.0 if validation else 0.0,
+        limit_val_batches=5 if validation else 0,
+        num_sanity_val_steps=0,
         callbacks=[swa_callback],
         accumulate_grad_batches=2,
         strategy=strategy,
@@ -362,3 +365,40 @@ def test_swa_resume_training_from_checkpoint(tmpdir):

     with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward):
         trainer.fit(model)
+
+
+@pytest.mark.parametrize("batchnorm", (True, False))
+@pytest.mark.parametrize("within_swa_epochs", (True, False))
+def test_swa_load_best_checkpoint(tmpdir, batchnorm: bool, within_swa_epochs: bool):
+    model = SwaTestModel(batchnorm=batchnorm)
+    if within_swa_epochs:
+        # Start at epoch 1 so we can guarantee the best checkpoint should be saved with SWA weights
+        swa_start = 1
+    else:
+        # Start after the last epoch, so we never save a checkpoint with SWA parameters
+        swa_start = 6
+    max_epochs = 5
+
+    swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=True)
+    checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        enable_progress_bar=False,
+        max_epochs=max_epochs,
+        limit_train_batches=5,
+        limit_val_batches=5,
+        num_sanity_val_steps=0,
+        callbacks=[swa_callback, checkpoint_callback],
+        accumulate_grad_batches=2,
+        num_processes=1,
+    )
+
+    with mock.patch.object(Accelerator, "backward", wraps=trainer.accelerator.backward):
+        trainer.fit(model)
+
+    checkpoint_path = checkpoint_callback.best_model_path
+    new_model = SwaTestModel.load_from_checkpoint(checkpoint_path)
+    parameters_loaded = SwaTestCallback.restore_average_parameters_from_checkpoint(new_model, checkpoint_path)
+
+    assert parameters_loaded == within_swa_epochs
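Outside the test suite, the round trip that test_swa_load_best_checkpoint exercises would look roughly like the sketch below. It assumes a user-defined LightningModule named MyModel (hypothetical) and that restore_average_parameters_from_checkpoint is exposed on the callback class as used in the test above; this is an illustration of the intended usage, not documented library API.

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

# MyModel is a placeholder LightningModule that logs "val_loss" in validation_step.
model = MyModel()
swa_callback = StochasticWeightAveraging(swa_epoch_start=1, swa_lrs=0.1)
checkpoint_callback = ModelCheckpoint(monitor="val_loss", mode="min")

trainer = Trainer(max_epochs=5, callbacks=[swa_callback, checkpoint_callback])
trainer.fit(model)

# Reload the best checkpoint; if it was saved within the SWA epochs the averaged
# parameters are copied into the model, otherwise its own weights are kept.
best_model = MyModel.load_from_checkpoint(checkpoint_callback.best_model_path)
used_swa_weights = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(
    best_model, checkpoint_callback.best_model_path
)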
