Support training resume and saving best model for SWA #6074
Comments
Hi, I'm also interested in this feature and have had a go at implementing it. I've opened a draft PR at #9938 and would appreciate some feedback from the PyTorch Lightning devs to make sure you're happy with this approach before I tidy it up and add tests and documentation. I've made using the SWA weights for validation optional and defaulted it to false for backwards compatibility. Loading the best model from a checkpoint seems a bit awkward, though. With my changes you can use:

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

# Track the best checkpoints by validation loss and run validation with the averaged weights.
checkpoint_callback = ModelCheckpoint(monitor="val_loss", save_top_k=3, mode="min")
swa_callback = StochasticWeightAveraging(swa_epoch_start=0.8, swa_validation=True)
trainer = Trainer(..., callbacks=[swa_callback, checkpoint_callback])
trainer.fit(model, data_model)

# Load the best checkpoint, then restore the averaged (SWA) parameters into it.
checkpoint_path = checkpoint_callback.best_model_path
new_model = MyModel.load_from_checkpoint(checkpoint_path=checkpoint_path)
parameters_loaded: bool = StochasticWeightAveraging.restore_average_parameters_from_checkpoint(
    new_model, checkpoint_path
)
```

One remaining issue is that when resuming from training, the SWALR scheduler is recreated with … I also need to work out how to handle models with batch normalization.
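For background on that resume issue, here is a rough plain-PyTorch sketch (not code from the PR; the model, optimizer, and the checkpoint["swa_scheduler"] key are made up for illustration) of the pieces involved: if SWALR is simply reconstructed when training resumes, its annealing schedule starts over unless its state is also restored.

```python
import torch
from torch.optim.swa_utils import AveragedModel, SWALR

# Illustrative model/optimizer; in Lightning these are managed by the module/trainer.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

swa_model = AveragedModel(model)
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5, anneal_strategy="cos")

# Resuming by re-running the constructor above resets the annealing progress.
# To continue from where training stopped, the scheduler state would also have
# to be saved and restored, e.g. (hypothetical checkpoint key):
# swa_scheduler.load_state_dict(checkpoint["swa_scheduler"])
```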
I've been looking into how batch normalization currently works when using SWA, to better understand how this should be handled during validation. The mean and variance of the inputs to the BatchNorm layers depend on the weights of the lower layers, so the running estimates computed by the underlying model won't be accurate once the weights are replaced with the averaged weights. Therefore the mean and variance are computed with a full pass over the training set at the end of training, after replacing the weights with the averaged weights. This is done by increasing the number of epochs by one but setting …

The validation passes should probably also update the batch norm parameters in order to accurately represent the model performance. Using the same approach of adding extra non-optimization training epochs before validation seems like it could get quite awkward, though. I'm wondering whether it would be simpler to do this without modifying the training loop parameters, by just using the training data fetcher directly.

I've updated my draft PR to use this approach, but I'm not that familiar with the PyTorch Lightning code base, so I'm not sure what I might be missing by using the data fetcher directly rather than letting the training loop handle this. The …
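As a concrete illustration of that batch-norm refresh, here is a minimal plain-PyTorch sketch using torch.optim.swa_utils.update_bn (the model and dataloader are placeholders and this is not the PR's code; in the callback the equivalent pass would run over the trainer's training data with the averaged weights swapped in):

```python
import torch
from torch.optim.swa_utils import AveragedModel, update_bn

# Placeholder model and data; in Lightning this would be the LightningModule
# with the averaged weights swapped in, and the training dataloader.
model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.BatchNorm1d(16), torch.nn.ReLU())
swa_model = AveragedModel(model)
train_dataloader = torch.utils.data.DataLoader(torch.randn(64, 8), batch_size=16)

# Recompute the BatchNorm running mean/var under the averaged weights by doing one
# full forward pass over the training data (update_bn resets the running stats and
# uses cumulative statistics for the pass).
update_bn(train_dataloader, swa_model)
```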
Hi @tchaton, it looks like you implemented SWA in PyTorch Lightning originally. Do you have any thoughts on this?
Implemented in #9938
🚀 Feature
Support resuming training after a stop (stop-resume) and saving the model with the best validation performance when using Stochastic Weight Averaging (SWA).
Motivation
If training does not run to the end, the SWA model is not saved to the checkpoint.
For example:
I observe that the saved checkpoint still contains the original (non-averaged) model weights even when I use the SWA callback.
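A minimal sketch of the kind of setup where this shows up (MyLightningModule, my_datamodule, and the argument values are illustrative, not from the original report): if fit() is interrupted before the SWA weights are copied into the model at the end of training, the checkpoints on disk hold the ordinary weights rather than the running average.

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging

# Illustrative module and datamodule; any LightningModule/DataModule would do.
model = MyLightningModule()

trainer = Trainer(
    max_epochs=100,
    callbacks=[
        StochasticWeightAveraging(swa_lrs=1e-2, swa_epoch_start=0.75),  # averaging starts at 75%
        ModelCheckpoint(monitor="val_loss", mode="min"),                # checkpoints saved mid-training
    ],
)
trainer.fit(model, datamodule=my_datamodule)
# If training stops early (or a mid-training checkpoint is restored), the saved
# "state_dict" contains the current model weights, not the SWA running average,
# because the averaged weights only replace the model weights at the end of training.
```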
Pitch
Save the SWA model weights to checkpoint['state_dict'] and keep the original model weights as callback state; a rough sketch of what this could look like follows below.
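A minimal sketch of one way this could be wired into the callback's checkpoint hook. This is not the existing implementation: the _average_model attribute and the original_state_dict key are assumptions made for illustration.

```python
from copy import deepcopy

from pytorch_lightning.callbacks import StochasticWeightAveraging


class SWASavingAveragedWeights(StochasticWeightAveraging):
    """Illustrative subclass: write the averaged weights into 'state_dict' and
    stash the raw (non-averaged) weights so training can still be resumed."""

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        # Assumed attribute: the running-average copy of the module kept by the callback.
        average_model = getattr(self, "_average_model", None)
        if average_model is not None:
            # Keep the original weights (here under a made-up top-level key; a real
            # implementation might store them in the callback's own state instead).
            checkpoint["original_state_dict"] = deepcopy(checkpoint["state_dict"])
            # Make load_from_checkpoint() return the SWA-averaged model.
            checkpoint["state_dict"] = average_model.state_dict()
```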
Alternatives
Additional context
@Borda @MilesCranmer, I do not have much coding experience, nor am I familiar with PyTorch Lightning. Would you help? Thanks a lot. 😉