Fault Tolerant Manual: Add loading to reload the states (#10699)
Co-authored-by: Kaushik B <45285388+kaushikb11@users.noreply.github.com>
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
4 people committed Nov 23, 2021
1 parent dca1776 commit b28ab34
Showing 5 changed files with 87 additions and 31 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
* Add a utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))

* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))

-

4 changes: 2 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -22,7 +22,7 @@
from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _update_dataloader_iter
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict
from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -182,7 +182,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:

def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
if not self.trainer.sanity_checking and self._dataloader_state_dict:
reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
_reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
self._dataloader_state_dict = None

def _num_completed_batches_reached(self) -> bool:
4 changes: 2 additions & 2 deletions pytorch_lightning/trainer/supporters.py
@@ -24,9 +24,9 @@

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
from pytorch_lightning.utilities.auto_restart import (
_reload_dataloader_state_dict,
MergedIteratorState,
patch_dataloader_iterator,
reload_dataloader_state_dict,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -403,7 +403,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
if isinstance(dataloader, CycleIterator):
dataloader = dataloader_to_iter_on.loader

reload_dataloader_state_dict(dataloader, state_dict)
_reload_dataloader_state_dict(dataloader, state_dict)

# Finally, spawn the workers, if any.
it = iter(dataloader_to_iter_on)
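A note on ordering in `create_loader_iters` above: the state has to be restored before `iter(dataloader_to_iter_on)` spawns the workers, because each worker process receives a copy of the dataset as it is at spawn time. A minimal plain-PyTorch sketch of that constraint (no Lightning internals; `CountingDataset` is illustrative):

```python
from torch.utils.data import DataLoader, Dataset


class CountingDataset(Dataset):
    """Illustrative dataset whose state must reach the worker processes."""

    def __init__(self):
        self.counter = 0

    def load_state_dict(self, state):
        self.counter = state["counter"]

    def __len__(self):
        return 4

    def __getitem__(self, idx):
        return self.counter + idx


if __name__ == "__main__":
    dataset = CountingDataset()
    loader = DataLoader(dataset, num_workers=2)

    # Restore *before* iter(): the workers copy `dataset` when they are
    # spawned, so a later load_state_dict on the parent would never reach them.
    dataset.load_state_dict({"counter": 10})
    it = iter(loader)
    assert next(it).item() == 10
```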
98 changes: 74 additions & 24 deletions pytorch_lightning/utilities/auto_restart.py
@@ -33,7 +33,6 @@
import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training


class FastForwardSampler(Sampler):
@@ -564,38 +563,90 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
)


def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
"""Utility to reload state_dict within dataloader for fault tolerance."""
def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
iterator_state = state_dict["state"][0]

if not _fault_tolerant_training():
return
if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)

dataset = dataloader.dataset
# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)

if isinstance(dataset, CaptureMapDataset):
iterator_state = state_dict["state"][0]
# reload dataset state
dataloader.dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)

if not isinstance(iterator_state, IteratorState):
iterator_state = IteratorState.from_state_dict(iterator_state)

# reload sampler state
ff_sampler = _find_fast_forward_samplers(dataloader)
ff_sampler.load_state_dict(iterator_state.sampler_state)
def _reload_dataloader_state_dict_automatic_iterable_dataset(
dataset: CaptureIterableDataset, state_dict: Dict[str, Any]
) -> None:
dataset.load_state_dict(
{sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
)

# reload dataset state
dataset.load_state_dict(
iterator_state.dataset_state,
latest_worker_id=state_dict["latest_worker_id"],
num_workers=iterator_state.num_workers,
)

def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
dataset = dataloader.dataset
if isinstance(dataset, CaptureMapDataset):
_reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict)

elif isinstance(dataset, CaptureIterableDataset):
dataset.load_state_dict(
{sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
)
_reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict)

else:
raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")
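For orientation, a hedged sketch of the payload shape the automatic branch consumes. Only the keys actually read above (`latest_worker_id`, `state`, and the `sampler_state` / `dataset_state` / `num_workers` fields of each per-worker entry) are grounded in this diff; the inner values are invented:

```python
# Illustrative payload only: the nesting beyond the fields read by
# `_reload_dataloader_state_dict_automatic_map_dataset` is an assumption,
# and every value is made up.
state_dict = {
    "latest_worker_id": 0,
    "state": {
        0: {  # one IteratorState-like entry per worker id
            "sampler_state": {0: {"current_iteration": 16}},
            "dataset_state": {0: {"counter": 16}},
            "num_workers": 2,
        }
    },
}
```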


def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
# In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`;
# therefore, we need to reload the states manually.

latest_worker_id = state_dict["latest_worker_id"]
num_workers = state_dict["state"][latest_worker_id]["num_workers"]
sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
if sampler_state:
# The keys of `sampler_state` are the names of the DataLoader attributes that
# matched the `_SupportsStateDict` interface when the `state_dict` was collected.
for dataloader_attr_name in sampler_state:
obj = getattr(dataloader, dataloader_attr_name)
if not isinstance(obj, _SupportsStateDict):
raise MisconfigurationException(
f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
)

obj.load_state_dict(sampler_state[dataloader_attr_name])

if not isinstance(dataloader.dataset, _SupportsStateDict):
return

dataset_state = {
worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
for worker_id in state_dict["state"].keys()
}

dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers))
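In manual mode the reload walks the DataLoader attributes that satisfy `_SupportsStateDict` (a `state_dict`/`load_state_dict` pair) and restores each one, then hands the dataset a mapping keyed by rotated worker id. A minimal compatible sampler/dataset pair, modeled on the stateful classes in the tests further down (class names are illustrative):

```python
from torch.utils.data import Dataset, Sampler


class StatefulSampler(Sampler):
    """Found by the attribute scan above because it satisfies `_SupportsStateDict`."""

    def __init__(self, data_source):
        super().__init__(data_source)
        self.data_source = data_source
        self.counter = 0

    def __iter__(self):
        for i in range(len(self.data_source)):
            self.counter = i + 1
            yield i

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        # the sampler gets its own collected state back: {"counter": ...}
        self.counter = state_dict["counter"]


class StatefulDataset(Dataset):
    def __init__(self):
        self.data = list(range(4))
        self.counter = 0

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        self.counter += 1
        return self.data[idx]

    def state_dict(self):
        return {"counter": self.counter}

    def load_state_dict(self, state_dict):
        # the dataset state arrives keyed by (rotated) worker id, cf. `dataset_state` above
        self.counter = state_dict[0]["counter"]
```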


def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
"""Utility to reload state_dict within dataloader for fault tolerance."""

fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()

if not fault_tolerant_mode.is_enabled:
return

if fault_tolerant_mode.is_automatic:
_reload_dataloader_state_dict_automatic(dataloader, state_dict)

elif fault_tolerant_mode.is_manual:
_reload_dataloader_state_dict_manual(dataloader, state_dict)

else:
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")
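The dispatcher above is a no-op unless fault tolerance is switched on. A short sketch of how the mode is typically selected; the environment variable name and the accepted string values are an assumption, as they are not shown in this diff:

```python
import os

# Assumption: `PL_FAULT_TOLERANT_TRAINING` toggles fault tolerance, with
# "manual" selecting the manual path exercised by this commit.
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "manual"

from pytorch_lightning.utilities.enums import _FaultTolerantMode

mode = _FaultTolerantMode.detect_current_mode()
assert mode.is_enabled and mode.is_manual
```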


def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
@@ -638,7 +689,6 @@ def _store_sampler_state(self) -> None:
for k, v in self._loader.__dict__.items()
if isinstance(v, _SupportsStateDict) and k != "dataset"
}

self.__accumulate_state(sampler_state)

def _next_index(self) -> Any:
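The body of `_rotate_worker_indices` is collapsed above. As a reminder of what the manual path relies on: after a restart the DataLoader assigns worker ids from scratch, so per-worker states are shifted such that the worker that produced the last batch lands in the last slot and its successor resumes first. A hedged sketch of that idea only (`rotate_worker_indices` below is illustrative, not the library's implementation):

```python
from typing import Any, Dict


def rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
    """Illustrative: remap worker states so `latest_worker_id` maps to the last slot."""
    if num_workers == 0:
        return state  # single-process loading: nothing to rotate
    shift = num_workers - 1 - latest_worker_id
    return {(worker_id + shift) % num_workers: s for worker_id, s in state.items()}


# e.g. 3 workers, failure after worker 1 produced a batch:
# {0: s0, 1: s1, 2: s2} -> {1: s0, 2: s1, 0: s2}, so worker 2's state resumes first.
```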
10 changes: 8 additions & 2 deletions tests/utilities/test_auto_restart.py
@@ -19,6 +19,7 @@
from collections.abc import Iterable
from contextlib import suppress
from copy import deepcopy
from dataclasses import asdict
from typing import List, Optional
from unittest import mock
from unittest.mock import ANY
@@ -42,6 +43,7 @@
_dataloader_to_state_dict,
_MultiProcessingDataLoaderIterStateful,
_patch_dataloader_get_iterators,
_reload_dataloader_state_dict,
_rotate_worker_indices,
_SingleProcessDataLoaderIterStateful,
_SupportsStateDict,
@@ -1289,7 +1291,7 @@ def state_dict(self):
return {"counter": self.counter}

def load_state_dict(self, state_dict):
self.counter = state_dict["counter"]
self.counter = state_dict[0]["counter"]


@pytest.mark.parametrize("num_workers", [0])
@@ -1319,7 +1321,9 @@ def test_stateful_workers(num_workers):
assert isinstance(dataloader_iter, worker_type)

next(data_fetcher_iter)
state = data_fetcher.dataloader_iter.state.state

reloaded_state = deepcopy(data_fetcher.dataloader_iter.state)
state = reloaded_state.state
assert state[0].dataset_state == {0: {"counter": 1}}
assert state[0].sampler_state["sampler"] == {"counter": 1}

@@ -1350,4 +1354,6 @@ def test_stateful_workers(num_workers):
assert not hasattr(DataLoader, "_ori_get_iterator")
assert DataLoader._get_iterator == _get_iterator_fn

_reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
assert dataloader.sampler.counter == dataloader.dataset.counter == 1
data_fetcher.teardown()
