Add callback to manage fault-tolerance checkpoints #11862

Merged (7 commits) on Feb 22, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/pull/10465))


- Added a private callback to manage the creation and deletion of fault-tolerance checkpoints ([#11862](https://github.com/PyTorchLightning/pytorch-lightning/pull/11862))


- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/pull/10719))


50 changes: 50 additions & 0 deletions pytorch_lightning/callbacks/fault_tolerance.py
@@ -0,0 +1,50 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
r"""
Fault-Tolerance
^^^^^^^^^^^^^^^

Contains callbacks for fault-tolerance support. These are not meant to be used publicly.
"""
import os
from typing import Any

import pytorch_lightning as pl
from pytorch_lightning import Callback
from pytorch_lightning.utilities.types import _PATH


class _FaultToleranceCheckpoint(Callback):
"""Used to save a fault-tolerance checkpoint on exception."""

FILE_EXTENSION = ".ckpt"

def __init__(self, dirpath: _PATH, filename: str = ".pl_auto_save") -> None:
super().__init__()
# not optional because an exception could occur at any moment, so we cannot wait until the `setup` hook
self.dirpath = dirpath
if not filename:
raise ValueError("The filename cannot be empty")
self.filename = filename

@property
def ckpt_path(self) -> str:
return os.path.join(self.dirpath, self.filename + self.FILE_EXTENSION)

def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None:
# overwrite if necessary
trainer.save_checkpoint(self.ckpt_path)

def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None:
trainer.strategy.remove_checkpoint(self.ckpt_path)
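As a review aid, here is a minimal sketch (not part of the diff) of the new callback's lifecycle. The `MagicMock` trainer is an illustrative stand-in so the save/remove hooks can be observed without running a real fit:

```python
import os
from unittest.mock import MagicMock

from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint

cb = _FaultToleranceCheckpoint(dirpath="/tmp/run")
# the checkpoint always lives at <dirpath>/.pl_auto_save.ckpt
assert cb.ckpt_path == os.path.join("/tmp/run", ".pl_auto_save.ckpt")

trainer = MagicMock()
cb.on_exception(trainer)  # any exception triggers a checkpoint save
trainer.save_checkpoint.assert_called_once_with(cb.ckpt_path)

cb.teardown(trainer)  # a clean teardown removes the checkpoint again
trainer.strategy.remove_checkpoint.assert_called_once_with(cb.ckpt_path)
```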
30 changes: 19 additions & 11 deletions pytorch_lightning/trainer/connectors/callback_connector.py
@@ -26,7 +26,7 @@
)
from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary
from pytorch_lightning.callbacks.timer import Timer
from pytorch_lightning.utilities import ModelSummaryMode
from pytorch_lightning.utilities.enums import ModelSummaryMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info

@@ -102,7 +102,10 @@ def on_trainer_init(
# accumulated grads
self._configure_accumulated_gradients(accumulate_grad_batches)

# push all checkpoint callbacks to the end
if self.trainer.state._fault_tolerant_mode.is_enabled:
self._configure_fault_tolerance_callbacks()

# push all model checkpoint callbacks to the end
# it is important that these are the last callbacks to run
self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks)

@@ -146,13 +149,13 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e
# if both are set then checkpoint only if both are True
enable_checkpointing = checkpoint_callback and enable_checkpointing

if self._trainer_has_checkpoint_callbacks() and enable_checkpointing is False:
raise MisconfigurationException(
"Trainer was configured with `enable_checkpointing=False`"
" but found `ModelCheckpoint` in callbacks list."
)

if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True:
if self.trainer.checkpoint_callbacks:
if not enable_checkpointing:
raise MisconfigurationException(
"Trainer was configured with `enable_checkpointing=False`"
" but found `ModelCheckpoint` in callbacks list."
)
elif enable_checkpointing:
self.trainer.callbacks.append(ModelCheckpoint())

def _configure_model_summary_callback(
@@ -252,8 +255,13 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic
timer = Timer(duration=max_time, interval="step")
self.trainer.callbacks.append(timer)

def _trainer_has_checkpoint_callbacks(self):
return len(self.trainer.checkpoint_callbacks) > 0
def _configure_fault_tolerance_callbacks(self):
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint

if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks):
raise RuntimeError("There should be only one fault-tolerance checkpoint callback.")
# don't use `log_dir` to minimize the chances of failure
self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir))

def _attach_model_logging_functions(self):
lightning_module = self.trainer.lightning_module
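For context on the connector change, the sketch below is an assumption about typical usage rather than part of the diff: fault-tolerant mode is toggled through the `PL_FAULT_TOLERANT_TRAINING` environment variable, which must be set before the `Trainer` is constructed. It then checks that exactly one `_FaultToleranceCheckpoint` is attached, rooted at `default_root_dir` rather than `log_dir`:

```python
import os

# Assumption: this environment variable enables fault-tolerant mode and is read
# when the Trainer (and its state) is constructed, so it must be set beforehand.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint

trainer = Trainer(default_root_dir="/tmp/run", max_epochs=1)
ft_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, _FaultToleranceCheckpoint)]

assert len(ft_callbacks) == 1  # the connector appends exactly one instance
assert ft_callbacks[0].dirpath == trainer.default_root_dir  # rooted at default_root_dir, not log_dir
```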
pytorch_lightning/trainer/connectors/checkpoint_connector.py
@@ -64,12 +64,6 @@ def _hpc_resume_path(self) -> Optional[str]:
if max_version is not None:
return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt")

@property
def _fault_tolerant_auto_resume_path(self) -> Optional[str]:
auto_saved_path = os.path.join(str(self.trainer.weights_save_path), ".pl_auto_save.ckpt")
fs = get_filesystem(auto_saved_path)
return auto_saved_path if fs.exists(auto_saved_path) else None

def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
"""Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority:

@@ -78,7 +72,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None:
3. from `checkpoint_path` file if provided
4. don't restore
"""
self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path
self.resume_checkpoint_path = self._hpc_resume_path or checkpoint_path
checkpoint_path = self.resume_checkpoint_path
if not checkpoint_path:
log.detail("`checkpoint_path` not specified. Skipping checkpoint loading.")
20 changes: 11 additions & 9 deletions pytorch_lightning/trainer/trainer.py
@@ -687,7 +686,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
if distributed_available() and self.world_size > 1:
# try syncing remaining processes, kill otherwise
self.strategy.reconciliate_processes(traceback.format_exc())
self._on_exception()
self._call_callback_hooks("on_exception", exception)
self._teardown()
# teardown might access the stage so we reset it after
@@ -760,7 +759,7 @@ def _fit_impl(
# TODO: ckpt_path only in v2.0
ckpt_path = ckpt_path or self.resume_from_checkpoint
self._ckpt_path = self.__set_ckpt_path(
ckpt_path, model_provided=model, model_connected=self.lightning_module is not None
ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
)
results = self._run(model, ckpt_path=self.ckpt_path)

@@ -1377,6 +1376,16 @@ def _run_sanity_check(self) -> None:
self.state.stage = stage

def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]:
# fault-tolerance takes precedence
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint

ft_checkpoints = [cb for cb in self.callbacks if isinstance(cb, _FaultToleranceCheckpoint)]
if ft_checkpoints:
ft_ckpt_path = ft_checkpoints[0].ckpt_path
fs = get_filesystem(ft_ckpt_path)
if fs.exists(ft_ckpt_path):
return ft_ckpt_path

if model_provided and ckpt_path is None:
# use passed model to function without loading weights
return
@@ -1750,13 +1759,6 @@ def _log_device_info(self) -> None:
" `Trainer(ipus=8)` or script `--ipus=8`."
)

def _on_exception(self) -> None:
if not _fault_tolerant_training():
return
# save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure.
file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt")
self.save_checkpoint(file_path)

"""
Data loading methods
"""
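The net effect of the `__set_ckpt_path` change is a resume-priority chain: an existing fault-tolerance checkpoint wins over any user-provided `ckpt_path`, while the HPC path is still resolved earlier in `resume_start`. The sketch below is a simplified restatement of that precedence, not the actual implementation; it uses `os.path.exists` in place of the fsspec filesystem lookup the real code performs, and the helper name is illustrative:

```python
import os
from typing import Optional

import pytorch_lightning as pl
from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint


def resolve_fit_ckpt_path(trainer: "pl.Trainer", user_ckpt_path: Optional[str]) -> Optional[str]:
    """Sketch of the precedence applied on fit: an on-disk fault-tolerance checkpoint
    is restored before any user-provided ``ckpt_path``."""
    ft_callbacks = [cb for cb in trainer.callbacks if isinstance(cb, _FaultToleranceCheckpoint)]
    if ft_callbacks and os.path.exists(ft_callbacks[0].ckpt_path):
        return ft_callbacks[0].ckpt_path  # typically <default_root_dir>/.pl_auto_save.ckpt
    return user_ckpt_path  # may be None, meaning "start from the passed model's weights"
```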
51 changes: 26 additions & 25 deletions tests/trainer/test_trainer.py
@@ -1943,40 +1943,41 @@ def on_predict_start(self) -> None:
raise Exception("Error during predict")


@pytest.mark.parametrize(
"strategy,num_processes",
[
(None, 1),
pytest.param("ddp_spawn", 1, marks=RunIf(skip_windows=True)),
],
)
def test_error_handling_all_stages(tmpdir, strategy, num_processes):
class ExceptionCounter(Callback):
exceptions = 0

def on_exception(self, *_):
self.exceptions += 1


@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True))])
def test_error_handling_all_stages(tmpdir, strategy):
model = TrainerStagesErrorsModel()
trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, num_processes=num_processes, fast_dev_run=True)
counter = ExceptionCounter()

trainer = Trainer(
default_root_dir=tmpdir,
strategy=strategy,
devices=1,
callbacks=counter,
fast_dev_run=True,
)

with pytest.raises(Exception, match=r"Error during train"), patch(
"pytorch_lightning.Trainer._on_exception"
) as exception_hook:
with pytest.raises(Exception, match=r"Error during train"):
trainer.fit(model)
exception_hook.assert_called()
assert counter.exceptions == 1

with pytest.raises(Exception, match=r"Error during validation"), patch(
"pytorch_lightning.Trainer._on_exception"
) as exception_hook:
with pytest.raises(Exception, match=r"Error during validation"):
trainer.validate(model)
exception_hook.assert_called()
assert counter.exceptions == 2

with pytest.raises(Exception, match=r"Error during test"), patch(
"pytorch_lightning.Trainer._on_exception"
) as exception_hook:
with pytest.raises(Exception, match=r"Error during test"):
trainer.test(model)
exception_hook.assert_called()
assert counter.exceptions == 3

with pytest.raises(Exception, match=r"Error during predict"), patch(
"pytorch_lightning.Trainer._on_exception"
) as exception_hook:
with pytest.raises(Exception, match=r"Error during predict"):
trainer.predict(model, model.val_dataloader(), return_predictions=False)
exception_hook.assert_called()
assert counter.exceptions == 4


def test_trainer_metrics_reset_before_each_task(tmpdir):
9 changes: 6 additions & 3 deletions tests/utilities/test_auto_restart.py
@@ -1150,8 +1150,12 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
model_signaled = _fit_model(
tmpdir, True, val_check_interval, failure_on_step, failure_on_training, on_last_batch, status=status
)
checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt")
assert os.path.exists(checkpoint_path)
# we saved a ft-checkpoint
signaled_ckpt_path = str(tmpdir / ".pl_auto_save.ckpt")
assert os.path.exists(signaled_ckpt_path)
# load for later as the next fit call will delete it
checkpoint = torch.load(signaled_ckpt_path)["loops"]["fit_loop"]

model_restarted = _fit_model(tmpdir, False, val_check_interval, failure_on_step, failure_on_training, on_last_batch)

# check the batches
@@ -1164,7 +1168,6 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert not torch.equal(model_total.layer.weight, model_signaled.layer.weight)
assert torch.equal(model_restarted.layer.weight, model_total.layer.weight)

checkpoint = torch.load(checkpoint_path)["loops"]["fit_loop"]
p = checkpoint["epoch_loop.batch_progress"]
if p["is_last_batch"] and p["current"]["completed"] == 4:
assert "dataloader_state_dict" not in checkpoint["epoch_loop.state_dict"]