Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fault Tolerant Manual: Add loading to reload the states #10699

Merged
merged 53 commits into from Nov 23, 2021
Merged
Show file tree
Hide file tree
Changes from 50 commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
ca8dbbb
update
tchaton Nov 19, 2021
3f0d28a
update
tchaton Nov 19, 2021
0c47670
update
tchaton Nov 19, 2021
3e5b52e
update
tchaton Nov 19, 2021
24c8245
update
tchaton Nov 19, 2021
bcd5569
update
tchaton Nov 19, 2021
8d3844d
update
tchaton Nov 19, 2021
1829b46
update
tchaton Nov 19, 2021
a1a364a
typo
tchaton Nov 19, 2021
de41675
update on comments
tchaton Nov 22, 2021
8178a32
Update pytorch_lightning/utilities/auto_restart.py
kaushikb11 Nov 22, 2021
00b9355
update
tchaton Nov 22, 2021
96f0517
update
tchaton Nov 22, 2021
297fd67
Merge branch 'fault_tolerant_enum' of https://github.com/PyTorchLight…
tchaton Nov 22, 2021
9800cba
update
tchaton Nov 22, 2021
427ed03
docstring improvement
tchaton Nov 22, 2021
ae712b0
update
tchaton Nov 22, 2021
9a5166d
Rename and simplify
carmocca Nov 22, 2021
b5fa819
Add comment
carmocca Nov 22, 2021
c82b2f2
update
tchaton Nov 22, 2021
2ede205
update
tchaton Nov 22, 2021
b16c4c0
update
tchaton Nov 22, 2021
ce9c23c
update
tchaton Nov 22, 2021
2baddb9
update
tchaton Nov 22, 2021
97548bb
update
tchaton Nov 22, 2021
d953ae9
update
tchaton Nov 22, 2021
41ffbab
use_teardown
tchaton Nov 22, 2021
d04596d
Use `Protocol`
carmocca Nov 22, 2021
ff7b836
Simplify test
carmocca Nov 22, 2021
a5698e6
Update CHANGELOG.md
carmocca Nov 22, 2021
79fdacc
update
tchaton Nov 22, 2021
916b520
update
tchaton Nov 22, 2021
4b67fbf
update
tchaton Nov 22, 2021
c9481e2
update
tchaton Nov 22, 2021
ef29342
update
tchaton Nov 22, 2021
4a1fff7
update
tchaton Nov 22, 2021
cb27e30
update
tchaton Nov 22, 2021
7903d24
resolve tests
tchaton Nov 22, 2021
20d19a1
update
tchaton Nov 22, 2021
1104cbc
update
tchaton Nov 23, 2021
f071f9a
change to 0
tchaton Nov 23, 2021
b777dc3
update
tchaton Nov 23, 2021
2da1674
update
tchaton Nov 23, 2021
647bebd
merge with master
tchaton Nov 23, 2021
dbcfa65
update changelog
tchaton Nov 23, 2021
ae18166
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2021
a527929
update
tchaton Nov 23, 2021
5827e8b
Merge branch 'add_reloading' of https://github.com/PyTorchLightning/p…
tchaton Nov 23, 2021
51cf75b
update
tchaton Nov 23, 2021
35644b8
update on comments
tchaton Nov 23, 2021
2aeda3b
update on comments
tchaton Nov 23, 2021
fe3afd8
update
tchaton Nov 23, 2021
421d869
update
tchaton Nov 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
* Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))


-
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
Expand Up @@ -22,7 +22,7 @@
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _update_dataloader_iter
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict
from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
Expand Down Expand Up @@ -182,7 +182,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:

def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
    """Restore a previously captured dataloader state into the fetcher's dataloader.

    Reconstructed from the PR diff: the pre-change call to the public
    ``reload_dataloader_state_dict`` was removed in favor of the private
    ``_reload_dataloader_state_dict`` utility; only the post-change line is kept.

    Args:
        data_fetcher: the fetcher whose ``dataloader`` receives the restored state.

    The restore is skipped while sanity checking, and the stored state dict is
    consumed (reset to ``None``) after a successful reload so it is applied once.
    """
    if not self.trainer.sanity_checking and self._dataloader_state_dict:
        _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
        self._dataloader_state_dict = None

def _num_completed_batches_reached(self) -> bool:
Expand Down
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/supporters.py
Expand Up @@ -24,9 +24,9 @@

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_reload_dataloader_state_dict,
MergedIteratorState,
patch_dataloader_iterator,
reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -403,7 +403,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader

reload_dataloader_state_dict(dataloader, state_dict)
_reload_dataloader_state_dict(dataloader, state_dict)

# We finally spawned the workers if any.
it = iter(dataloader_to_iter_on)
Expand Down
74 changes: 53 additions & 21 deletions pytorch_lightning/utilities/auto_restart.py
Expand Up @@ -33,7 +33,6 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class FastForwardSampler(Sampler):
Expand Down Expand Up @@ -564,35 +563,69 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
)


def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    """Utility to reload a captured ``state_dict`` into a dataloader for fault tolerance.

    Reconstructed from the PR diff with the interleaved GitHub review-UI lines
    ("marked this conversation as resolved", etc.) and the removed pre-change
    ``def`` line stripped out.

    Args:
        dataloader: the dataloader whose sampler/dataset state should be restored.
        state_dict: the captured fault-tolerant state, with ``"state"`` keyed by
            worker id and a ``"latest_worker_id"`` entry.

    Raises:
        MisconfigurationException: if the dataset type (automatic mode) or a
            dataloader attribute (manual mode) does not support state reloading.
    """
    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()

    # No-op when fault-tolerant training is disabled.
    if not fault_tolerant_mode.is_enabled:
        return

    dataset = dataloader.dataset

    if fault_tolerant_mode.is_automatic:
        if isinstance(dataset, CaptureMapDataset):
            iterator_state = state_dict["state"][0]

            if not isinstance(iterator_state, IteratorState):
                iterator_state = IteratorState.from_state_dict(iterator_state)

            # reload sampler state
            ff_sampler = _find_fast_forward_samplers(dataloader)
            ff_sampler.load_state_dict(iterator_state.sampler_state)

            # reload dataset state
            dataset.load_state_dict(
                iterator_state.dataset_state,
                latest_worker_id=state_dict["latest_worker_id"],
                num_workers=iterator_state.num_workers,
            )

        elif isinstance(dataset, CaptureIterableDataset):
            dataset.load_state_dict(
                {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
            )

        else:
            raise MisconfigurationException(
                "This shouldn't happen. Please, open an issue on PyTorch Lightning Github."
            )

    elif fault_tolerant_mode.is_manual:
        # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`
        # therefore, we need to reload the states manually.
        latest_worker_id = state_dict["latest_worker_id"]
        num_workers = state_dict["state"][latest_worker_id]["num_workers"]
        sampler_state = state_dict["state"][latest_worker_id]["sampler_state"]
        if sampler_state:
            # `sampler_state` keys name dataloader attributes that must implement
            # the `_SupportsStateDict` protocol (state_dict/load_state_dict).
            for k in sampler_state:
                obj = getattr(dataloader, k)
                if not isinstance(obj, _SupportsStateDict):
                    raise MisconfigurationException(
                        f"The DataLoader attribute {k}:{obj} should have a `load_state_dict` method."
                    )

                obj.load_state_dict(sampler_state[k])

        # A dataset without the stateful protocol has nothing to restore.
        if not isinstance(dataset, _SupportsStateDict):
            return

        # Collect the per-worker dataset states, then rotate the worker indices
        # so the reload lines up with the worker that produced the latest batch.
        dataset_state = {
            worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
            for worker_id in state_dict["state"].keys()
        }

        dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers))

    else:
        raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
Expand Down Expand Up @@ -638,7 +671,6 @@ def _store_sampler_state(self) -> None:
for k, v in self._loader.__dict__.items()
if isinstance(v, _SupportsStateDict) and k != "dataset"
}

self.__accumulate_state(sampler_state)

def _next_index(self) -> Any:
Expand Down
10 changes: 8 additions & 2 deletions tests/utilities/test_auto_restart.py
Expand Up @@ -19,6 +19,7 @@
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
from dataclasses import asdict
from typing import List, Optional
from unittest import mock
from unittest.mock import ANY
Expand All @@ -42,6 +43,7 @@
_dataloader_to_state_dict,
_MultiProcessingDataLoaderIterStateful,
_patch_dataloader_get_iterators,
_reload_dataloader_state_dict,
_rotate_worker_indices,
_SingleProcessDataLoaderIterStateful,
_SupportsStateDict,
Expand Down Expand Up @@ -1289,7 +1291,7 @@ def state_dict(self):
return {"counter": self.counter}

def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]
self.counter = state_dict[0]["counter"]


@pytest.mark.parametrize("num_workers", [0])
Expand Down Expand Up @@ -1319,7 +1321,9 @@ def test_stateful_workers(num_workers):
assert isinstance(dataloader_iter, worker_type)

next(data_fetcher_iter)
state = data_fetcher.dataloader_iter.state.state

reloaded_state = deepcopy(data_fetcher.dataloader_iter.state)
state = reloaded_state.state
assert state[0].dataset_state == {0: {"counter": 1}}
assert state[0].sampler_state["sampler"] == {"counter": 1}

Expand Down Expand Up @@ -1350,4 +1354,6 @@ def test_stateful_workers(num_workers):
assert not hasattr(DataLoader, "_ori_get_iterator")
assert DataLoader._get_iterator == _get_iterator_fn

_reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
assert dataloader.sampler.counter == dataloader.dataset.counter == 1
data_fetcher.teardown()