Skip to content

Commit

Permalink
Call on_load_checkpoint before loading state_dict (#4057)
Browse files · Browse the repository at this point in the history
  • Loading branch information
rohitgr7 committed Oct 14, 2020
1 parent f967fbb commit dec31b3
Showing 1 changed file with 4 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,16 @@ def restore(self, checkpoint_path: str, on_gpu: bool):
# load model state
model = self.trainer.get_model()

# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

# give the datamodule a chance to load something
if self.trainer.datamodule is not None:
self.trainer.datamodule.on_load_checkpoint(checkpoint)

# give model a chance to load something
model.on_load_checkpoint(checkpoint)

# load the state_dict on the model automatically
model.load_state_dict(checkpoint['state_dict'])

if on_gpu:
model.cuda(self.trainer.root_gpu)

Expand Down

0 comments on commit dec31b3

Please sign in to comment.