Refactor SWA batch norm moment update to work with validation
adamreeve committed Oct 18, 2021
1 parent 1696273 commit c8db9d8
Showing 2 changed files with 67 additions and 55 deletions.
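
For context, the change below lets the SWA callback refresh batch-norm statistics for the averaged weights whenever they are swapped in for validation, rather than only in an extra epoch appended after training. A minimal usage sketch follows; the `swa_validation` flag is inferred from the test changes in this commit rather than a documented signature, and `MyLightningModule` is a placeholder module with train and validation dataloaders.

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import StochasticWeightAveraging

    # Sketch only: `swa_validation` is inferred from the tests in this commit,
    # and MyLightningModule is a placeholder with train/val dataloaders.
    swa = StochasticWeightAveraging(swa_epoch_start=2, swa_lrs=0.1, swa_validation=True)
    trainer = pl.Trainer(max_epochs=5, callbacks=[swa])
    trainer.fit(MyLightningModule())
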
79 changes: 36 additions & 43 deletions pytorch_lightning/callbacks/stochastic_weight_avg.py
@@ -120,7 +120,6 @@ def __init__(
if device is not None and not isinstance(device, (torch.device, str)):
raise MisconfigurationException(f"device is expected to be a torch.device or a str. Found {device}")

self.momenta = None
self.n_averaged = None
self._swa_epoch_start = swa_epoch_start
self._swa_lrs = swa_lrs
@@ -134,6 +133,7 @@ def __init__(
self._temp_model = None
self._initialized = False
self._swa_scheduler = None
self._batch_norm_moments = None

@property
def swa_start(self) -> int:
@@ -171,12 +171,9 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
self._model_contains_batch_norm = self.pl_module_contains_batch_norm(pl_module)

self._max_epochs = trainer.max_epochs
if self._model_contains_batch_norm:
# virtually increase max_epochs to perform batch norm update on latest epoch.
trainer.fit_loop.max_epochs += 1

def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
resuming_after_start = trainer.current_epoch > self.swa_start and not self._initialized
resuming_after_start = (not self._initialized) and (self.swa_start < trainer.current_epoch <= self.swa_end)
if trainer.current_epoch == self.swa_start or resuming_after_start:
self._initialized = True

@@ -223,75 +220,73 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if self.swa_start <= trainer.current_epoch <= self.swa_end:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)

# Note: No > here in case the callback is saved with the model and training continues
if trainer.current_epoch == self.swa_end + 1:

# Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)

# Reset BatchNorm for update
self.reset_batch_norm_and_save_state(pl_module)

# There is no need to perform either backward or optimizer.step as we are
# performing only one pass over the train data-loader to compute activation statistics
# Therefore, we will virtually increase `num_training_batches` by 1 and skip backward.
trainer.num_training_batches += 1
trainer.fit_loop._skip_backward = True
self._accumulate_grad_batches = trainer.accumulate_grad_batches

trainer.accumulate_grad_batches = trainer.num_training_batches

def on_train_epoch_end(self, trainer: "pl.Trainer", *args):
trainer.fit_loop._skip_backward = False

def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule"):
if self._model_contains_batch_norm and trainer.current_epoch == self.swa_end + 1:
# BatchNorm epoch update. Reset state
trainer.accumulate_grad_batches = self._accumulate_grad_batches
trainer.num_training_batches -= 1
trainer.fit_loop.max_epochs -= 1
self.reset_momenta()
elif trainer.current_epoch == self.swa_end:
if trainer.current_epoch == self.swa_end:
# Last SWA epoch. Transfer weights from average model to pl_module
self.transfer_weights(self._average_model, pl_module)
if self._model_contains_batch_norm:
self._update_batch_norm_moments(trainer, pl_module, store_moments=False)

def on_validation_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
# Take a temporary copy of the model parameters
self.transfer_weights(pl_module, self._temp_model)
# Update the model with the averaged parameters
self.transfer_weights(self._average_model, pl_module)
if self._model_contains_batch_norm:
self._update_batch_norm_moments(trainer, pl_module)

def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
if self._swa_validation and (self.swa_start <= trainer.current_epoch <= self.swa_end):
# Copy original model parameters back
self.transfer_weights(self._temp_model, pl_module)
if self._model_contains_batch_norm:
self._restore_batch_norm_moments()

@staticmethod
def transfer_weights(src_pl_module: "pl.LightningModule", dst_pl_module: "pl.LightningModule"):
for src_param, dst_param in zip(src_pl_module.parameters(), dst_pl_module.parameters()):
dst_param.detach().copy_(src_param.to(dst_param.device))

def reset_batch_norm_and_save_state(self, pl_module: "pl.LightningModule"):
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L154."""
self.momenta = {}
def _update_batch_norm_moments(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", store_moments: bool = True
):
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L140-L166."""
prev_momenta = {}
self._batch_norm_moments = {}

was_training = pl_module.training
pl_module.train()

for module in pl_module.modules():
if not isinstance(module, nn.modules.batchnorm._BatchNorm):
continue
prev_momenta[module] = module.momentum
if store_moments:
self._batch_norm_moments[module] = (module.running_mean, module.running_var)
module.running_mean = torch.zeros_like(
module.running_mean, device=pl_module.device, dtype=module.running_mean.dtype
)
module.running_var = torch.ones_like(
module.running_var, device=pl_module.device, dtype=module.running_var.dtype
)
self.momenta[module] = module.momentum
module.momentum = None
module.num_batches_tracked *= 0

def reset_momenta(self):
"""Adapted from https://github.com/pytorch/pytorch/blob/v1.7.1/torch/optim/swa_utils.py#L164-L165."""
for bn_module in self.momenta:
bn_module.momentum = self.momenta[bn_module]
# Recompute mean and variance for all batch norm layers by doing a full pass over the training data
for batch, _ in trainer.data_connector.train_data_fetcher:
batch = batch.to(pl_module.device)
pl_module(batch)

# Reset model state
for bn_module, momenta in prev_momenta.items():
bn_module.momentum = momenta
pl_module.train(was_training)

def _restore_batch_norm_moments(self):
for bn_module, (mean, variance) in self._batch_norm_moments.items():
bn_module.running_mean = mean
bn_module.running_var = variance

@staticmethod
def update_parameters(
@@ -317,7 +312,6 @@ def on_save_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
) -> dict:
checkpoint_data = {
"momenta": self.momenta,
"n_averaged": self.n_averaged,
"swa_lrs": self._swa_lrs,
"annealing_epochs": self._annealing_epochs,
@@ -330,7 +324,6 @@ def on_load_checkpoint(
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", callback_state: Dict[str, Any]
) -> None:
if callback_state:
self.momenta = callback_state["momenta"]
self.n_averaged = callback_state["n_averaged"]
self._swa_lrs = callback_state["swa_lrs"]
self._annealing_strategy = callback_state["annealing_strategy"]
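
For reference, a standalone sketch of the running-statistics reset and recompute pattern that the new `_update_batch_norm_moments` method follows (itself adapted from `torch.optim.swa_utils`). `model` and `train_loader` are placeholder names; batches are assumed to be plain tensors, and gradients are disabled since only the forward pass is needed to update the statistics.

    import torch
    from torch import nn

    def recompute_bn_moments(model: nn.Module, train_loader) -> None:
        # Reset each batch-norm layer's running statistics and switch it to a
        # cumulative moving average (momentum=None), as in the diff above.
        prev_momenta = {}
        was_training = model.training
        model.train()
        for module in model.modules():
            if not isinstance(module, nn.modules.batchnorm._BatchNorm):
                continue
            if module.running_mean is None:  # track_running_stats=False
                continue
            prev_momenta[module] = module.momentum
            module.running_mean = torch.zeros_like(module.running_mean)
            module.running_var = torch.ones_like(module.running_var)
            module.momentum = None
            module.num_batches_tracked *= 0

        # One full pass over the training data accumulates fresh mean/variance.
        with torch.no_grad():
            for batch in train_loader:
                model(batch)

        # Restore the original momenta and training mode.
        for module, momentum in prev_momenta.items():
            module.momentum = momentum
        model.train(was_training)
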
43 changes: 31 additions & 12 deletions tests/callbacks/test_stochastic_weight_avg.py
@@ -52,13 +52,21 @@ def training_step(self, batch, batch_idx):
loss = self.loss(batch, output)
return {"loss": loss}

def validation_step(self, batch, batch_idx):
output = self.forward(batch)
loss = self.loss(batch, output)
return {"x": loss}

def train_dataloader(self):

dset_cls = RandomIterableDataset if self.iterable_dataset else RandomDataset
dset = dset_cls(32, 64)

return DataLoader(dset, batch_size=2)

def val_dataloader(self):
return self.train_dataloader()

def configure_optimizers(self):
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
return {
@@ -86,20 +94,24 @@ def __init__(self, *args, **kwargs):
self.resuming_from_epoch = 0
super().__init__(*args, **kwargs)

validation_calls: int = 0
update_parameters_calls: int = 0
transfer_weights_calls: int = 0

def update_parameters(self, *args, **kwargs):
self.update_parameters_calls += 1
return StochasticWeightAveraging.update_parameters(*args, **kwargs)

def on_validation_start(self, *args, **kwargs):
self.validation_calls += 1
return super().on_validation_start(*args, **kwargs)

def transfer_weights(self, *args, **kwargs):
self.transfer_weights_calls += 1
return StochasticWeightAveraging.transfer_weights(*args, **kwargs)

def on_train_epoch_start(self, trainer, *args):
super().on_train_epoch_start(trainer, *args)
assert trainer.fit_loop._skip_backward == (trainer.current_epoch > self.swa_end)
if self.swa_start <= trainer.current_epoch:
assert isinstance(trainer.lr_schedulers[0]["scheduler"], SWALR)
assert trainer.lr_schedulers[0]["interval"] == "epoch"
@@ -116,11 +128,6 @@ def on_train_epoch_end(self, trainer, *args):
def on_train_end(self, trainer, pl_module):
super().on_train_end(trainer, pl_module)

# make sure these are correctly set again
assert not trainer.fit_loop._skip_backward
assert trainer.accumulate_grad_batches == 2
assert trainer.num_training_batches == 5

if not isinstance(trainer.training_type_plugin, DDPSpawnPlugin):
# check backward call count. the batchnorm update epoch should not backward
assert trainer.accelerator.backward.call_count == (
@@ -133,16 +140,27 @@ def on_train_end(self, trainer, pl_module):
else:
expected_update_calls = trainer.max_epochs - (self._swa_epoch_start - 1)
assert self.update_parameters_calls == expected_update_calls
assert self.transfer_weights_calls == 1
if self._swa_validation:
# 3 weight transfers are needed per SWA validation step
assert self.transfer_weights_calls == (self.validation_calls - self._swa_epoch_start) * 3 + 1
else:
assert self.transfer_weights_calls == 1


def train_with_swa(
tmpdir, batchnorm=True, strategy=None, gpus=None, num_processes=1, interval="epoch", iterable_dataset=False
tmpdir,
batchnorm=True,
strategy=None,
gpus=None,
num_processes=1,
interval="epoch",
iterable_dataset=False,
validation=False,
):
model = SwaTestModel(batchnorm=batchnorm, interval=interval, iterable_dataset=iterable_dataset)
swa_start = 2
max_epochs = 5
swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1)
swa_callback = SwaTestCallback(swa_epoch_start=swa_start, swa_lrs=0.1, swa_validation=validation)
assert swa_callback.update_parameters_calls == 0
assert swa_callback.transfer_weights_calls == 0

@@ -151,7 +169,7 @@ def train_with_swa(
enable_progress_bar=False,
max_epochs=max_epochs,
limit_train_batches=5,
limit_val_batches=0,
limit_val_batches=1.0 if validation else 0.0,
callbacks=[swa_callback],
accumulate_grad_batches=2,
strategy=strategy,
@@ -188,8 +206,9 @@ def test_swa_callback_1_gpu(tmpdir):

@pytest.mark.parametrize("batchnorm", (True, False))
@pytest.mark.parametrize("iterable_dataset", (True, False))
def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool):
train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset)
@pytest.mark.parametrize("validation", (True, False))
def test_swa_callback(tmpdir, batchnorm: bool, iterable_dataset: bool, validation: bool):
train_with_swa(tmpdir, batchnorm=batchnorm, iterable_dataset=iterable_dataset, validation=validation)


@pytest.mark.parametrize("interval", ("epoch", "step"))
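
As a recap of what the tests above exercise, the validation-time flow during the SWA phase mirrors `on_validation_start`/`on_validation_end` from the first file: stash the current weights, swap in the averaged ones (refreshing batch-norm statistics if needed), run validation, then restore. The sketch below reuses the callback's own attribute names purely for illustration.

    # Sketch of the per-validation swap/restore flow during the SWA phase.
    def run_swa_validation(callback, trainer, pl_module):
        # 1. Stash the current training weights in the temporary model copy.
        callback.transfer_weights(pl_module, callback._temp_model)
        # 2. Load the averaged SWA weights and refresh batch-norm statistics.
        callback.transfer_weights(callback._average_model, pl_module)
        if callback._model_contains_batch_norm:
            callback._update_batch_norm_moments(trainer, pl_module)
        # ... the validation loop runs here with the averaged weights ...
        # 3. Restore the original weights and the saved batch-norm statistics.
        callback.transfer_weights(callback._temp_model, pl_module)
        if callback._model_contains_batch_norm:
            callback._restore_batch_norm_moments()

This is also why the test expects three `transfer_weights` calls per SWA-phase validation run, plus the single final transfer at the end of training.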
