Skip to content

Commit

Permalink
Added changeable extension variable for model checkpoints (#4977)
Browse files Browse the repository at this point in the history
* Added changeable extension variable for model checkpoints

* Removed whitespace

* Removed the last bit of whitespace

* Wrote tests for FILE_EXTENSION

* Fixed formatting issues

* More formatting issues

* Simplify test by just using defaults

* Formatting to PEP8

* Added dummy class that inherits ModelCheckpoint; run only one batch instead of epoch for integration test

* Fixed too much whitespace formatting

* some changes

Co-authored-by: rohitgr7 <rohitgr1998@gmail.com>
  • Loading branch information
janhenriklambrechts and rohitgr7 committed Dec 6, 2020
1 parent 2e838e6 commit b00991e
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 3 deletions.
7 changes: 4 additions & 3 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class ModelCheckpoint(Callback):
Example::
# custom path
# saves a file like: my/path/epoch=0.ckpt
# saves a file like: my/path/epoch=0-step=10.ckpt
>>> checkpoint_callback = ModelCheckpoint(dirpath='my/path/')
By default, dirpath is ``None`` and will be set at runtime to the location
Expand Down Expand Up @@ -140,6 +140,7 @@ class ModelCheckpoint(Callback):

CHECKPOINT_JOIN_CHAR = "-"
CHECKPOINT_NAME_LAST = "last"
FILE_EXTENSION = ".ckpt"

def __init__(
self,
Expand Down Expand Up @@ -442,7 +443,7 @@ def format_checkpoint_name(
)
if ver is not None:
filename = self.CHECKPOINT_JOIN_CHAR.join((filename, f"v{ver}"))
ckpt_name = f"{filename}.ckpt"
ckpt_name = f"{filename}{self.FILE_EXTENSION}"
return os.path.join(self.dirpath, ckpt_name) if self.dirpath else ckpt_name

def __resolve_ckpt_dir(self, trainer, pl_module):
Expand Down Expand Up @@ -545,7 +546,7 @@ def _save_last_checkpoint(self, trainer, pl_module, ckpt_name_metrics, filepath)
ckpt_name_metrics,
prefix=self.prefix
)
last_filepath = os.path.join(self.dirpath, f"{last_filepath}.ckpt")
last_filepath = os.path.join(self.dirpath, f"{last_filepath}{self.FILE_EXTENSION}")

self._save_model(last_filepath, trainer, pl_module)
if (
Expand Down
23 changes: 23 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,29 @@ def test_model_checkpoint_format_checkpoint_name(tmpdir):
assert ckpt_name == filepath / 'test-epoch=3-step=2.ckpt'


class ModelCheckpointExtensionTest(ModelCheckpoint):
FILE_EXTENSION = '.tpkc'


def test_model_checkpoint_file_extension(tmpdir):
"""
Test ModelCheckpoint with different file extension.
"""

model = LogInTwoMethods()
model_checkpoint = ModelCheckpointExtensionTest(monitor='early_stop_on', dirpath=tmpdir, save_top_k=1, save_last=True)
trainer = Trainer(
default_root_dir=tmpdir,
callbacks=[model_checkpoint],
max_steps=1,
logger=False,
)
trainer.fit(model)

expected = ['epoch=0-step=0.tpkc', 'last.tpkc']
assert set(expected) == set(os.listdir(tmpdir))


def test_model_checkpoint_save_last(tmpdir):
"""Tests that save_last produces only one last checkpoint."""
seed_everything()
Expand Down

0 comments on commit b00991e

Please sign in to comment.