From cfafb56c8894f26f82f1f1549aa10f6d8f7fd9c5 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 23:21:38 +0100 Subject: [PATCH 1/6] Add callback to save fault-tolerance checkpoints --- .../callbacks/fault_tolerance.py | 49 +++++++++++++++++++ .../trainer/connectors/callback_connector.py | 29 ++++++----- .../connectors/checkpoint_connector.py | 8 +-- pytorch_lightning/trainer/trainer.py | 20 ++++---- tests/utilities/test_auto_restart.py | 9 ++-- 5 files changed, 85 insertions(+), 30 deletions(-) create mode 100644 pytorch_lightning/callbacks/fault_tolerance.py diff --git a/pytorch_lightning/callbacks/fault_tolerance.py b/pytorch_lightning/callbacks/fault_tolerance.py new file mode 100644 index 0000000000000..ded1272a7e7db --- /dev/null +++ b/pytorch_lightning/callbacks/fault_tolerance.py @@ -0,0 +1,49 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +r""" +Fault-Tolerance +^^^^^^^^^^^^^^^ + +Contains callbacks for fault-tolerance support. These are not meant to be used publicly. +""" +import os +from typing import Any + +import pytorch_lightning as pl +from pytorch_lightning.utilities.types import _PATH + + +class _FaultToleranceCheckpoint(pl.Callback): + """Used to save a fault-tolerance checkpoint on exception.""" + + FILE_EXTENSION = ".ckpt" + + def __init__(self, dirpath: _PATH, filename: str = ".pl_auto_save") -> None: + super().__init__() + # not optional because an exception could occur at any moment, so we cannot wait until the `setup` hook + self.dirpath = dirpath + if not filename: + raise ValueError("The filename cannot be empty") + self.filename = filename + + @property + def ckpt_path(self) -> str: + return os.path.join(self.dirpath, self.filename + self.FILE_EXTENSION) + + def on_exception(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: + # overwrite if necessary + trainer.save_checkpoint(self.ckpt_path) + + def teardown(self, trainer: "pl.Trainer", *_: Any, **__: Any) -> None: + trainer.strategy.remove_checkpoint(self.ckpt_path) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 45a05a446ba77..53fb237fdd7e4 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -26,7 +26,7 @@ ) from pytorch_lightning.callbacks.rich_model_summary import RichModelSummary from pytorch_lightning.callbacks.timer import Timer -from pytorch_lightning.utilities import ModelSummaryMode +from pytorch_lightning.utilities.enums import ModelSummaryMode from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info @@ -102,7 +102,10 @@ def on_trainer_init( # accumulated grads self._configure_accumulated_gradients(accumulate_grad_batches) - # push all checkpoint callbacks to the end + if self.trainer.state._fault_tolerant_mode.is_enabled: + 
self._configure_fault_tolerance_callbacks() + + # push all model checkpoint callbacks to the end # it is important that these are the last callbacks to run self.trainer.callbacks = self._reorder_callbacks(self.trainer.callbacks) @@ -146,13 +149,12 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e # if both are set then checkpoint only if both are True enable_checkpointing = checkpoint_callback and enable_checkpointing - if self._trainer_has_checkpoint_callbacks() and enable_checkpointing is False: - raise MisconfigurationException( - "Trainer was configured with `enable_checkpointing=False`" - " but found `ModelCheckpoint` in callbacks list." - ) - - if not self._trainer_has_checkpoint_callbacks() and enable_checkpointing is True: + if self.trainer.checkpoint_callbacks: + if not enable_checkpointing: + raise MisconfigurationException( + "Trainer was configured with `enable_checkpointing=False`" + " but found `ModelCheckpoint` in callbacks list." + ) self.trainer.callbacks.append(ModelCheckpoint()) def _configure_model_summary_callback( @@ -252,8 +254,13 @@ def _configure_timer_callback(self, max_time: Optional[Union[str, timedelta, Dic timer = Timer(duration=max_time, interval="step") self.trainer.callbacks.append(timer) - def _trainer_has_checkpoint_callbacks(self): - return len(self.trainer.checkpoint_callbacks) > 0 + def _configure_fault_tolerance_callbacks(self): + from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint + + if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks): + raise RuntimeError("There should be only 1 fault-tolerance checkpoint callback.") + # don't use `log_dir` to minimize the chances of failure + self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir)) def _attach_model_logging_functions(self): lightning_module = self.trainer.lightning_module diff --git a/pytorch_lightning/trainer/connectors/checkpoint_connector.py b/pytorch_lightning/trainer/connectors/checkpoint_connector.py index bb6889d797f5c..d5a059b720ee5 100644 --- a/pytorch_lightning/trainer/connectors/checkpoint_connector.py +++ b/pytorch_lightning/trainer/connectors/checkpoint_connector.py @@ -63,12 +63,6 @@ def _hpc_resume_path(self) -> Optional[str]: if max_version is not None: return os.path.join(dir_path_hpc, f"hpc_ckpt_{max_version}.ckpt") - @property - def _fault_tolerant_auto_resume_path(self) -> Optional[str]: - auto_saved_path = os.path.join(str(self.trainer.weights_save_path), ".pl_auto_save.ckpt") - fs = get_filesystem(auto_saved_path) - return auto_saved_path if fs.exists(auto_saved_path) else None - def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: """Attempts to pre-load the checkpoint file to memory, with the source path determined in this priority: @@ -77,7 +71,7 @@ def resume_start(self, checkpoint_path: Optional[_PATH] = None) -> None: 3. from `checkpoint_path` file if provided 4. don't restore """ - self.resume_checkpoint_path = self._hpc_resume_path or self._fault_tolerant_auto_resume_path or checkpoint_path + self.resume_checkpoint_path = self._hpc_resume_path or checkpoint_path checkpoint_path = self.resume_checkpoint_path if not checkpoint_path: log.detail("`checkpoint_path` not specified. 
Skipping checkpoint loading.") diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 3663c211e4445..94f166c9f22fb 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -688,7 +688,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: if distributed_available() and self.world_size > 1: # try syncing remaining processes, kill otherwise self.strategy.reconciliate_processes(traceback.format_exc()) - self._on_exception() self._call_callback_hooks("on_exception", exception) self._teardown() # teardown might access the stage so we reset it after @@ -761,7 +760,7 @@ def _fit_impl( # TODO: ckpt_path only in v2.0 ckpt_path = ckpt_path or self.resume_from_checkpoint self._ckpt_path = self.__set_ckpt_path( - ckpt_path, model_provided=model, model_connected=self.lightning_module is not None + ckpt_path, model_provided=True, model_connected=self.lightning_module is not None ) results = self._run(model, ckpt_path=self.ckpt_path) @@ -1380,6 +1379,16 @@ def _run_sanity_check(self) -> None: self.state.stage = stage def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_connected: bool) -> Optional[str]: + # fault-tolerance takes precedence + from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint + + ft_checkpoints = [cb for cb in self.callbacks if isinstance(cb, _FaultToleranceCheckpoint)] + if ft_checkpoints: + ft_ckpt_path = ft_checkpoints[0].ckpt_path + fs = get_filesystem(ft_ckpt_path) + if fs.exists(ft_ckpt_path): + return ft_ckpt_path + if model_provided and ckpt_path is None: # use passed model to function without loading weights return @@ -1753,13 +1762,6 @@ def _log_device_info(self) -> None: " `Trainer(ipus=8)` or script `--ipus=8`." ) - def _on_exception(self) -> None: - if not _fault_tolerant_training(): - return - # save a checkpoint for fault tolerant training. we don't use `log_dir` to minimize the chances of failure. 
- file_path = os.path.join(self.default_root_dir, ".pl_auto_save.ckpt") - self.save_checkpoint(file_path) - """ Data loading methods """ diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 5faefbc9f71b6..a40ea568f9256 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1150,8 +1150,12 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on model_signaled = _fit_model( tmpdir, True, val_check_interval, failure_on_step, failure_on_training, on_last_batch, status=status ) - checkpoint_path = str(tmpdir / ".pl_auto_save.ckpt") - assert os.path.exists(checkpoint_path) + # we saved a ft-checkpoint + signaled_ckpt_path = str(tmpdir / ".pl_auto_save.ckpt") + assert os.path.exists(signaled_ckpt_path) + # load for later as the next fit call will delete it + checkpoint = torch.load(signaled_ckpt_path)["loops"]["fit_loop"] + model_restarted = _fit_model(tmpdir, False, val_check_interval, failure_on_step, failure_on_training, on_last_batch) # check the batches @@ -1164,7 +1168,6 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert not torch.equal(model_total.layer.weight, model_signaled.layer.weight) assert torch.equal(model_restarted.layer.weight, model_total.layer.weight) - checkpoint = torch.load(checkpoint_path)["loops"]["fit_loop"] p = checkpoint["epoch_loop.batch_progress"] if p["is_last_batch"] and p["current"]["completed"] == 4: assert "dataloader_state_dict" not in checkpoint["epoch_loop.state_dict"] From fecee3a0a06aec861575eb5b55bbe1155b7b146f Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Thu, 10 Feb 2022 23:30:26 +0100 Subject: [PATCH 2/6] Update CHANGELOG --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2782a4cb1d9f1..a3c84a373ab33 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/pull/10465)) +- Added a private callback to manage the creation and deletion of fault-toelrance checkpoints ([#11862](https://github.com/PyTorchLightning/pytorch-lightning/pull/11862)) + + - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/pull/10719)) From c28834ba0e237e7a61445b2025cfb81f8e932651 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 11 Feb 2022 00:06:21 +0100 Subject: [PATCH 3/6] Fix --- pytorch_lightning/trainer/connectors/callback_connector.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index 53fb237fdd7e4..c87b0efd42605 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -155,6 +155,7 @@ def _configure_checkpoint_callbacks(self, checkpoint_callback: Optional[bool], e "Trainer was configured with `enable_checkpointing=False`" " but found `ModelCheckpoint` in callbacks list." 
) + elif enable_checkpointing: self.trainer.callbacks.append(ModelCheckpoint()) def _configure_model_summary_callback( From 56871ea98d615f173df74849444f1867cb9540fa Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 11 Feb 2022 01:59:17 +0100 Subject: [PATCH 4/6] Fix test --- tests/trainer/test_trainer.py | 39 ++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 587ff0b7b9f72..36ce57a423199 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -2009,31 +2009,36 @@ def on_predict_start(self) -> None: ) def test_error_handling_all_stages(tmpdir, strategy, num_processes): model = TrainerStagesErrorsModel() - trainer = Trainer(default_root_dir=tmpdir, strategy=strategy, num_processes=num_processes, fast_dev_run=True) + exceptions = 0 + + class TestException(Callback): + def on_exception(self, *_): + nonlocal exceptions + exceptions += 1 + + trainer = Trainer( + default_root_dir=tmpdir, + strategy=strategy, + num_processes=num_processes, + callbacks=TestException(), + fast_dev_run=True, + ) - with pytest.raises(Exception, match=r"Error during train"), patch( - "pytorch_lightning.Trainer._on_exception" - ) as exception_hook: + with pytest.raises(Exception, match=r"Error during train"): trainer.fit(model) - exception_hook.assert_called() + assert exceptions == 1 - with pytest.raises(Exception, match=r"Error during validation"), patch( - "pytorch_lightning.Trainer._on_exception" - ) as exception_hook: + with pytest.raises(Exception, match=r"Error during validation"): trainer.validate(model) - exception_hook.assert_called() + assert exceptions == 2 - with pytest.raises(Exception, match=r"Error during test"), patch( - "pytorch_lightning.Trainer._on_exception" - ) as exception_hook: + with pytest.raises(Exception, match=r"Error during test"): trainer.test(model) - exception_hook.assert_called() + assert exceptions == 3 - with pytest.raises(Exception, match=r"Error during predict"), patch( - "pytorch_lightning.Trainer._on_exception" - ) as exception_hook: + with pytest.raises(Exception, match=r"Error during predict"): trainer.predict(model, model.val_dataloader(), return_predictions=False) - exception_hook.assert_called() + assert exceptions == 4 def test_trainer_metrics_reset_before_each_task(tmpdir): From 2ef811b830845038dddbc37b5d4b4b9b40bacefe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Sat, 12 Feb 2022 13:52:46 +0100 Subject: [PATCH 5/6] Update CHANGELOG.md Co-authored-by: Danielle Pintz <38207072+daniellepintz@users.noreply.github.com> --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index a3c84a373ab33..ff74820dbee3b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,7 +33,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a function to validate if fault tolerant training is supported. 
([#10465](https://github.com/PyTorchLightning/pytorch-lightning/pull/10465)) -- Added a private callback to manage the creation and deletion of fault-toelrance checkpoints ([#11862](https://github.com/PyTorchLightning/pytorch-lightning/pull/11862)) +- Added a private callback to manage the creation and deletion of fault-tolerance checkpoints ([#11862](https://github.com/PyTorchLightning/pytorch-lightning/pull/11862)) - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/pull/10719)) From 1750674cce910dd66acb829f6dbf1cf94440ffb9 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Tue, 22 Feb 2022 14:16:14 +0100 Subject: [PATCH 6/6] Fix test --- .../callbacks/fault_tolerance.py | 3 +- .../trainer/connectors/callback_connector.py | 2 +- tests/trainer/test_trainer.py | 34 ++++++++----------- 3 files changed, 18 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/callbacks/fault_tolerance.py b/pytorch_lightning/callbacks/fault_tolerance.py index ded1272a7e7db..59b8d31f46506 100644 --- a/pytorch_lightning/callbacks/fault_tolerance.py +++ b/pytorch_lightning/callbacks/fault_tolerance.py @@ -21,10 +21,11 @@ from typing import Any import pytorch_lightning as pl +from pytorch_lightning import Callback from pytorch_lightning.utilities.types import _PATH -class _FaultToleranceCheckpoint(pl.Callback): +class _FaultToleranceCheckpoint(Callback): """Used to save a fault-tolerance checkpoint on exception.""" FILE_EXTENSION = ".ckpt" diff --git a/pytorch_lightning/trainer/connectors/callback_connector.py b/pytorch_lightning/trainer/connectors/callback_connector.py index c87b0efd42605..062d1237a0b84 100644 --- a/pytorch_lightning/trainer/connectors/callback_connector.py +++ b/pytorch_lightning/trainer/connectors/callback_connector.py @@ -259,7 +259,7 @@ def _configure_fault_tolerance_callbacks(self): from pytorch_lightning.callbacks.fault_tolerance import _FaultToleranceCheckpoint if any(isinstance(cb, _FaultToleranceCheckpoint) for cb in self.trainer.callbacks): - raise RuntimeError("There should be only 1 fault-tolerance checkpoint callback.") + raise RuntimeError("There should be only one fault-tolerance checkpoint callback.") # don't use `log_dir` to minimize the chances of failure self.trainer.callbacks.append(_FaultToleranceCheckpoint(dirpath=self.trainer.default_root_dir)) diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 7d960a39a44c6..4b50230d2de48 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1943,45 +1943,41 @@ def on_predict_start(self) -> None: raise Exception("Error during predict") -@pytest.mark.parametrize( - "strategy,num_processes", - [ - (None, 1), - pytest.param("ddp_spawn", 1, marks=RunIf(skip_windows=True)), - ], -) -def test_error_handling_all_stages(tmpdir, strategy, num_processes): - model = TrainerStagesErrorsModel() +class ExceptionCounter(Callback): exceptions = 0 - class TestException(Callback): - def on_exception(self, *_): - nonlocal exceptions - exceptions += 1 + def on_exception(self, *_): + self.exceptions += 1 + + +@pytest.mark.parametrize("strategy", [None, pytest.param("ddp_spawn", marks=RunIf(skip_windows=True))]) +def test_error_handling_all_stages(tmpdir, strategy): + model = TrainerStagesErrorsModel() + counter = ExceptionCounter() trainer = Trainer( default_root_dir=tmpdir, strategy=strategy, - num_processes=num_processes, - callbacks=TestException(), + 
devices=1, + callbacks=counter, fast_dev_run=True, ) with pytest.raises(Exception, match=r"Error during train"): trainer.fit(model) - assert exceptions == 1 + assert counter.exceptions == 1 with pytest.raises(Exception, match=r"Error during validation"): trainer.validate(model) - assert exceptions == 2 + assert counter.exceptions == 2 with pytest.raises(Exception, match=r"Error during test"): trainer.test(model) - assert exceptions == 3 + assert counter.exceptions == 3 with pytest.raises(Exception, match=r"Error during predict"): trainer.predict(model, model.val_dataloader(), return_predictions=False) - assert exceptions == 4 + assert counter.exceptions == 4 def test_trainer_metrics_reset_before_each_task(tmpdir):
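
Note (illustrative, not part of the patch series): the pattern the new private callback implements — write an auto-save checkpoint the moment a run raises, give that file priority on the next restart, and delete it once it is no longer needed — can be sketched as a small public callback. The `AutoSaveOnException` name below is hypothetical; in the patch itself the equivalent private `_FaultToleranceCheckpoint` is appended automatically by the callback connector when `trainer.state._fault_tolerant_mode.is_enabled` is true (typically driven by the `PL_FAULT_TOLERANT_TRAINING=1` environment variable, which is an assumption here and not shown in this diff). Only hooks and attributes that appear in the diff above are used.

import os

import pytorch_lightning as pl


class AutoSaveOnException(pl.Callback):
    """Sketch of the fault-tolerance pattern: auto-save on exception, clean up on teardown."""

    def __init__(self, dirpath: str, filename: str = ".pl_auto_save") -> None:
        super().__init__()
        if not filename:
            raise ValueError("The filename cannot be empty")
        self.dirpath = dirpath
        self.filename = filename

    @property
    def ckpt_path(self) -> str:
        return os.path.join(self.dirpath, self.filename + ".ckpt")

    def on_exception(self, trainer: "pl.Trainer", *_, **__) -> None:
        # Save (overwriting any previous auto-save) as soon as a run raises.
        trainer.save_checkpoint(self.ckpt_path)

    def teardown(self, trainer: "pl.Trainer", *_, **__) -> None:
        # Remove the auto-save when the callback tears down, so a run that
        # finished (or successfully resumed) does not leave the file behind.
        trainer.strategy.remove_checkpoint(self.ckpt_path)


if __name__ == "__main__":
    # Manual wiring for illustration only; the patch adds its private
    # equivalent automatically and does not expose a public entry point.
    root = os.getcwd()
    ft_callback = AutoSaveOnException(dirpath=root)
    trainer = pl.Trainer(
        default_root_dir=root,
        callbacks=[ft_callback],
        max_epochs=1,
    )
    print("auto-save target:", ft_callback.ckpt_path)

Two design choices visible in the diff are worth noting: the auto-save is written under `default_root_dir` rather than `log_dir` to minimize the chance that the save itself fails, and resume handling moves from the checkpoint connector's removed `_fault_tolerant_auto_resume_path` property into `Trainer.__set_ckpt_path`, where an existing fault-tolerance checkpoint takes precedence over any user-supplied `ckpt_path`.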