diff --git a/CHANGELOG.md b/CHANGELOG.md
index da15315e6a7f2..1f7014a71d9a3 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -18,7 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
     * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
     * Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
-
+    * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
-
diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
index 102603f20302b..2fc572ea252e6 100644
--- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
+++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -22,7 +22,7 @@ from pytorch_lightning.loops.base import Loop
 from pytorch_lightning.loops.utilities import _update_dataloader_iter
 from pytorch_lightning.trainer.progress import BatchProgress
-from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict
+from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
 from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
@@ -182,7 +182,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None:

     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
         if not self.trainer.sanity_checking and self._dataloader_state_dict:
-            reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
+            _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict)
         self._dataloader_state_dict = None

     def _num_completed_batches_reached(self) -> bool:
diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py
index 6e2e51e82bbf1..d65bc08e6689e 100644
--- a/pytorch_lightning/trainer/supporters.py
+++ b/pytorch_lightning/trainer/supporters.py
@@ -24,9 +24,9 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
 from pytorch_lightning.utilities.auto_restart import (
+    _reload_dataloader_state_dict,
     MergedIteratorState,
     patch_dataloader_iterator,
-    reload_dataloader_state_dict,
 )
 from pytorch_lightning.utilities.data import get_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -403,7 +403,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
             if isinstance(dataloader, CycleIterator):
                 dataloader = dataloader_to_iter_on.loader

-            reload_dataloader_state_dict(dataloader, state_dict)
+            _reload_dataloader_state_dict(dataloader, state_dict)

             # We finally spawned the workers if any.
             it = iter(dataloader_to_iter_on)
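Before the `auto_restart.py` changes below, a quick orientation on the switch they introduce: the renamed `_reload_dataloader_state_dict` no longer uses the removed `_fault_tolerant_training()` check and instead dispatches on `_FaultTolerantMode`. Here is a minimal sketch (not part of the diff) of how that mode is selected; the `PL_FAULT_TOLERANT_TRAINING` env var name and the accepted values are assumptions based on `_FaultTolerantMode.detect_current_mode()` and should be verified against `pytorch_lightning/utilities/enums.py`:

```python
import os

from pytorch_lightning.utilities.enums import _FaultTolerantMode

# Assumed accepted values: "0"/"disabled", "1"/"automatic", "2"/"manual".
os.environ["PL_FAULT_TOLERANT_TRAINING"] = "2"

mode = _FaultTolerantMode.detect_current_mode()
assert mode.is_enabled and mode.is_manual  # drives the dispatch added below
```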
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 4e984f7ecb2aa..3fa32bc72da5e 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -33,7 +33,6 @@ import pytorch_lightning as pl
 from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training


 class FastForwardSampler(Sampler):
@@ -564,38 +563,90 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
     )


-def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
-    """Utility to reload state_dict within dataloader for fault tolerance."""
+def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    iterator_state = state_dict["state"][0]

-    if not _fault_tolerant_training():
-        return
+    if not isinstance(iterator_state, IteratorState):
+        iterator_state = IteratorState.from_state_dict(iterator_state)

-    dataset = dataloader.dataset
+    # reload sampler state
+    ff_sampler = _find_fast_forward_samplers(dataloader)
+    ff_sampler.load_state_dict(iterator_state.sampler_state)

-    if isinstance(dataset, CaptureMapDataset):
-        iterator_state = state_dict["state"][0]
+    # reload dataset state
+    dataloader.dataset.load_state_dict(
+        iterator_state.dataset_state,
+        latest_worker_id=state_dict["latest_worker_id"],
+        num_workers=iterator_state.num_workers,
+    )

-        if not isinstance(iterator_state, IteratorState):
-            iterator_state = IteratorState.from_state_dict(iterator_state)

-        # reload sampler state
-        ff_sampler = _find_fast_forward_samplers(dataloader)
-        ff_sampler.load_state_dict(iterator_state.sampler_state)
+def _reload_dataloader_state_dict_automatic_iterable_dataset(
+    dataset: CaptureIterableDataset, state_dict: Dict[str, Any]
+) -> None:
+    dataset.load_state_dict(
+        {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
+    )

-        # reload dataset state
-        dataset.load_state_dict(
-            iterator_state.dataset_state,
-            latest_worker_id=state_dict["latest_worker_id"],
-            num_workers=iterator_state.num_workers,
-        )
+
+def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    dataset = dataloader.dataset
+    if isinstance(dataset, CaptureMapDataset):
+        _reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict)

     elif isinstance(dataset, CaptureIterableDataset):
-        dataset.load_state_dict(
-            {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()}
-        )
+        _reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict)
+
+    else:
+        raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")
+
+
+def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset`
+    # therefore, we need to reload the states manually.
+
+    latest_worker_id = state_dict["latest_worker_id"]
+    num_workers = state_dict["state"][latest_worker_id]["num_workers"]
+    sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None)
+    if sampler_state:
+        # `sampler_state` keys contain all the DataLoader attribute names
+        # which matched `_SupportsStateDict` API interface while collecting the `state_dict`.
+        for dataloader_attr_name in sampler_state:
+            obj = getattr(dataloader, dataloader_attr_name)
+            if not isinstance(obj, _SupportsStateDict):
+                raise MisconfigurationException(
+                    f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method."
+                )
+
+            obj.load_state_dict(sampler_state[dataloader_attr_name])
+
+    if not isinstance(dataloader.dataset, _SupportsStateDict):
+        return
+
+    dataset_state = {
+        worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id]
+        for worker_id in state_dict["state"].keys()
+    }
+
+    dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers))
+
+
+def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
+    """Utility to reload state_dict within dataloader for fault tolerance."""
+
+    fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
+
+    if not fault_tolerant_mode.is_enabled:
+        return
+
+    if fault_tolerant_mode.is_automatic:
+        _reload_dataloader_state_dict_automatic(dataloader, state_dict)
+
+    elif fault_tolerant_mode.is_manual:
+        _reload_dataloader_state_dict_manual(dataloader, state_dict)

     else:
-        raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")
+        raise MisconfigurationException("This shouldn't be happening. Please, open an issue.")


 def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
@@ -638,7 +689,6 @@ def _store_sampler_state(self) -> None:
             for k, v in self._loader.__dict__.items()
             if isinstance(v, _SupportsStateDict) and k != "dataset"
         }
-
         self.__accumulate_state(sampler_state)

     def _next_index(self) -> Any:
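The manual branch above never sees the `CaptureMapDataset`/`CaptureIterableDataset` wrappers; it duck-types instead: every `DataLoader` attribute named in `sampler_state` (and the dataset, if it participates) must satisfy the runtime-checkable `_SupportsStateDict` protocol, i.e. expose `state_dict` and `load_state_dict`. A minimal conforming sampler as a sketch; `CounterSampler` is an invented name, not part of the diff:

```python
from torch.utils.data import Sampler

from pytorch_lightning.utilities.auto_restart import _SupportsStateDict


class CounterSampler(Sampler):
    """Sequential sampler that can checkpoint and resume its position."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.counter = 0

    def __iter__(self):
        # Resuming from `counter` means a reloaded sampler fast-forwards itself.
        while self.counter < self.size:
            yield self.counter
            self.counter += 1

    def __len__(self) -> int:
        return self.size

    def state_dict(self) -> dict:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: dict) -> None:
        self.counter = state_dict["counter"]


# This isinstance check is exactly what `_reload_dataloader_state_dict_manual` performs.
assert isinstance(CounterSampler(4), _SupportsStateDict)
```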
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index 25f80ec6817a5..1c27d582cc6a5 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -19,6 +19,7 @@ from collections.abc import Iterable
 from contextlib import suppress
 from copy import deepcopy
+from dataclasses import asdict
 from typing import List, Optional
 from unittest import mock
 from unittest.mock import ANY
@@ -42,6 +43,7 @@
     _dataloader_to_state_dict,
     _MultiProcessingDataLoaderIterStateful,
     _patch_dataloader_get_iterators,
+    _reload_dataloader_state_dict,
     _rotate_worker_indices,
     _SingleProcessDataLoaderIterStateful,
     _SupportsStateDict,
@@ -1289,7 +1291,7 @@ def state_dict(self):
         return {"counter": self.counter}

     def load_state_dict(self, state_dict):
-        self.counter = state_dict["counter"]
+        self.counter = state_dict[0]["counter"]


 @pytest.mark.parametrize("num_workers", [0])
@@ -1319,7 +1321,9 @@ def test_stateful_workers(num_workers):
     assert isinstance(dataloader_iter, worker_type)

     next(data_fetcher_iter)
-    state = data_fetcher.dataloader_iter.state.state
+
+    reloaded_state = deepcopy(data_fetcher.dataloader_iter.state)
+    state = reloaded_state.state
     assert state[0].dataset_state == {0: {"counter": 1}}
     assert state[0].sampler_state["sampler"] == {"counter": 1}
@@ -1350,4 +1354,6 @@
     assert not hasattr(DataLoader, "_ori_get_iterator")
     assert DataLoader._get_iterator == _get_iterator_fn

+    _reload_dataloader_state_dict(dataloader, asdict(reloaded_state))
+    assert dataloader.sampler.counter == dataloader.dataset.counter == 1
     data_fetcher.teardown()
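To close the loop, here is a hedged, self-contained sketch of the round trip the two new test lines assert: manual mode with `num_workers=0`, where the payload (produced by `asdict(reloaded_state)` in the test) is hand-built with only the keys `_reload_dataloader_state_dict_manual` actually reads. The class names, the literal payload, and the `"2"` env value are illustrative assumptions, not part of the diff:

```python
import os

from torch.utils.data import DataLoader, Dataset, Sampler

from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict


class CounterDataset(Dataset):
    """Map-style dataset whose read position can be checkpointed."""

    def __init__(self) -> None:
        self.counter = 0

    def __len__(self) -> int:
        return 4

    def __getitem__(self, index: int) -> int:
        self.counter += 1
        return index

    def state_dict(self) -> dict:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: dict) -> None:
        # Dataset state arrives keyed by worker id, hence the `[0]`
        # (mirroring the updated `load_state_dict` in the test above).
        self.counter = state_dict[0]["counter"]


class CounterSampler(Sampler):
    """Sequential sampler exposing the `_SupportsStateDict` interface."""

    def __init__(self, size: int) -> None:
        self.size = size
        self.counter = 0

    def __iter__(self):
        while self.counter < self.size:
            yield self.counter
            self.counter += 1

    def __len__(self) -> int:
        return self.size

    def state_dict(self) -> dict:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: dict) -> None:
        self.counter = state_dict["counter"]


os.environ["PL_FAULT_TOLERANT_TRAINING"] = "2"  # manual mode (assumed value)

dataset = CounterDataset()
dataloader = DataLoader(dataset, sampler=CounterSampler(len(dataset)))

# `sampler_state` is keyed by DataLoader attribute name, `dataset_state` by
# worker id; `latest_worker_id` and `num_workers` drive the worker rotation.
state_dict = {
    "latest_worker_id": 0,
    "state": {
        0: {
            "num_workers": 0,
            "sampler_state": {"sampler": {"counter": 1}},
            "dataset_state": {0: {"counter": 1}},
        }
    },
}

_reload_dataloader_state_dict(dataloader, state_dict)
assert dataloader.sampler.counter == dataloader.dataset.counter == 1
```

Note the asymmetry this makes visible: the sampler receives its own sub-dict directly, while the dataset receives the worker-keyed mapping, which is why the test's `load_state_dict` was changed to index `state_dict[0]["counter"]`.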