From fb45a7e228dee97a29b221b53dbaeafc6b51d3f5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 19:04:26 +0000 Subject: [PATCH 1/6] update --- pytorch_lightning/utilities/auto_restart.py | 19 +++++++++++-------- tests/utilities/test_auto_restart.py | 7 +++++++ 2 files changed, 18 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ef52717636d90..e986a18f37657 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -247,14 +247,7 @@ def __len__(self) -> int: def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None: # as workers aren't available, the ``state_dict``` is cached until workers are made available. state_dict = deepcopy(state_dict) - - if num_workers > 0: - # remap states to worker ids starting at 0 - next_worker_id = latest_worker_id + 1 - old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] - state_dict = { - new_id: state_dict[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state_dict - } + state_dict = _rotate_worker_indices(state_dict, latest_worker_id, num_workers) self._cached_state_dict = state_dict def state_dict(self) -> Dict[int, Dict[str, Any]]: @@ -571,3 +564,13 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + + +def _rotate_worker_indices(state, latest_worker_id: int, num_workers: int): + """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed + on.""" + if num_workers == 0: + return state + next_worker_id = latest_worker_id + 1 + old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] + return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b36a9d1d76941..ff99fdbc09f59 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,6 +39,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _rotate_worker_indices, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1192,3 +1193,9 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +def test_rotate_worker_indices(): + state_dict = {0: 0, 1: 1} + assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0} + assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1} From 8b0ab8d5e7808e2ecd1d3034df997deb71fab4db Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 19:05:14 +0000 Subject: [PATCH 2/6] update --- tests/utilities/test_auto_restart.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index ff99fdbc09f59..5e29acfcc72c2 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1196,6 +1196,7 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on def test_rotate_worker_indices(): + """This test ensures `worker_id` are rotated properly depending on which one was the latest.""" state_dict = {0: 0, 1: 1} assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0} assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1} From 8abcae2a6af2a19abd98c17c263faaeff014cdc3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 19:08:33 +0000 Subject: [PATCH 3/6] update changelog --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96830878d84ae..d011b92d32bc7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) -- +- Fault Tolerant Manual + * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) - From 1bb3b8aaac6b144b3af0d5c37dc5b8607d45147a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Nov 2021 15:23:51 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/utilities/test_auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 5613c31292088..e90c3c3afafcb 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1202,7 +1202,7 @@ def test_rotate_worker_indices(): assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0} assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1} - + def test_fault_tolerant_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode() From 289764ad37aba0c5afcb7a288fe9751c5506b54b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 18:02:03 +0000 Subject: [PATCH 5/6] update --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index e986a18f37657..9e64a933c4a95 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -566,7 +566,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _rotate_worker_indices(state, latest_worker_id: int, num_workers: int): +def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed on.""" if num_workers == 0: From 99d25961762489d4fc8444a4c76cb45654c240f2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 18:28:48 +0000 Subject: [PATCH 6/6] update --- pytorch_lightning/utilities/auto_restart.py | 6 +++++- tests/utilities/test_auto_restart.py | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 67352d9a10d91..2d04805e9dd35 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -565,11 +565,15 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _rotate_worker_indices(state, latest_worker_id: int, num_workers: int): +def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed on.""" if num_workers == 0: return state + if latest_worker_id > num_workers - 1: + raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].") + if len(state) != num_workers: + raise MisconfigurationException("The `state` should contain `num_workers - 1` values.") next_worker_id = latest_worker_id + 1 old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)] return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state} diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e90c3c3afafcb..5852c9fbdad45 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1202,6 +1202,12 @@ def test_rotate_worker_indices(): assert _rotate_worker_indices(state_dict, 0, 2) == {0: 1, 1: 0} assert _rotate_worker_indices(state_dict, 1, 2) == {0: 0, 1: 1} + with pytest.raises(MisconfigurationException, match="The `latest_worker_id` should be within"): + _rotate_worker_indices(state_dict, 2, 2) + + with pytest.raises(MisconfigurationException, match="The `state` should contain"): + _rotate_worker_indices(state_dict, 2, 3) + def test_fault_tolerant_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):