Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fault Tolerant Manual: Add _rotate_worker_indices utility #10647

Merged
merged 9 commits into from Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fault Tolerant Manual
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
* Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))


-
Expand Down
23 changes: 15 additions & 8 deletions pytorch_lightning/utilities/auto_restart.py
Expand Up @@ -247,14 +247,7 @@ def __len__(self) -> int:
def load_state_dict(self, state_dict: Dict[int, Any], latest_worker_id: int, num_workers: int) -> None:
    """Cache a re-indexed copy of ``state_dict`` on ``self._cached_state_dict``.

    Args:
        state_dict: mapping from worker id to the state saved for that worker.
        latest_worker_id: id of the worker that was used last before the state was saved.
        num_workers: number of workers the state is being reloaded for.
    """
    # as workers aren't available, the ``state_dict`` is cached until workers are made available.
    # The copy keeps the caller's dictionary untouched; the worker-id remapping is done exactly
    # once, by ``_rotate_worker_indices`` (which is a no-op when ``num_workers == 0``).
    self._cached_state_dict = _rotate_worker_indices(deepcopy(state_dict), latest_worker_id, num_workers)

def state_dict(self) -> Dict[int, Dict[str, Any]]:
Expand Down Expand Up @@ -573,6 +566,20 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")


def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
"""This function is used to rotate the worker indices based on the `latest_worker_id` the training failed
on."""
if num_workers == 0:
return state
tchaton marked this conversation as resolved.
Show resolved Hide resolved
if latest_worker_id > num_workers - 1:
raise MisconfigurationException("The `latest_worker_id` should be within [0, num_workers - 1].")
if len(state) != num_workers:
raise MisconfigurationException("The `state` should contain `num_workers - 1` values.")
next_worker_id = latest_worker_id + 1
old_to_new_worker_id_map = [((next_worker_id + i) % num_workers, i) for i in range(num_workers)]
return {new_id: state[old_id] for old_id, new_id in old_to_new_worker_id_map if old_id in state}
tchaton marked this conversation as resolved.
Show resolved Hide resolved


@runtime_checkable
class _SupportsStateDict(Protocol):
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""
Expand Down
14 changes: 14 additions & 0 deletions tests/utilities/test_auto_restart.py
Expand Up @@ -40,6 +40,7 @@
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
_rotate_worker_indices,
_SupportsStateDict,
CaptureIterableDataset,
CaptureMapDataset,
Expand Down Expand Up @@ -1196,6 +1197,19 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict


def test_rotate_worker_indices():
    """Check that worker ids are re-indexed relative to the worker that was used last."""
    two_worker_state = {0: 0, 1: 1}

    # the worker following the latest one becomes the new worker ``0``
    rotated_after_first = _rotate_worker_indices(two_worker_state, 0, 2)
    assert rotated_after_first == {0: 1, 1: 0}
    rotated_after_second = _rotate_worker_indices(two_worker_state, 1, 2)
    assert rotated_after_second == {0: 0, 1: 1}

    # a ``latest_worker_id`` beyond the last worker is rejected
    with pytest.raises(MisconfigurationException, match="The `latest_worker_id` should be within"):
        _rotate_worker_indices(two_worker_state, 2, 2)

    # the number of entries in the state has to match the number of workers
    with pytest.raises(MisconfigurationException, match="The `state` should contain"):
        _rotate_worker_indices(two_worker_state, 2, 3)


def test_supports_state_dict_protocol():
class StatefulClass:
def state_dict(self):
Expand Down