diff --git a/CHANGELOG.md b/CHANGELOG.md index adb1b070dc386..fd9cf54a9730e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) + * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) - diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index d7354a8294b37..6044f1320286c 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -15,7 +15,6 @@ import os from abc import ABC from copy import deepcopy -from functools import partial from typing import Any, Callable, Collection, List, Optional, Tuple, Union from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler @@ -29,7 +28,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import _capture_metadata_collate +from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _update_dataloader, @@ -215,7 +214,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # add collate_fn to collect metadata for fault tolerant training if _fault_tolerant_training(): - apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate) + apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) @@ -437,14 +436,6 @@ def request_dataloader( self.training_type_plugin.barrier("get_dataloaders") return dataloader - @staticmethod - def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: - """Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is - enabled.""" - dataloader.collate_fn = partial( - _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn - ) - @staticmethod def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]: all_have_sequential_sampler = True diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 4cb1793643c1d..4e984f7ecb2aa 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -21,11 +21,17 @@ import numpy as np import torch from torch.utils.data import Dataset, get_worker_info, Sampler -from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from torch.utils.data.dataloader import ( + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _SingleProcessDataLoaderIter, + DataLoader, + IterableDataset, +) from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys +from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -435,8 +441,10 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: return {"num_workers": num_workers, "previous_worker": previous_worker} -def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict: - """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or +def _capture_metadata_collate( + samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode +) -> Any: + """A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker processes. The structure will be: @@ -447,10 +455,25 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ - data = default_collate(samples) - if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)): - return data - metadata = dataset.state_dict() + data = collate_fn(samples) + metadata = None + if fault_tolerant_mode.is_automatic: + metadata = dataset.state_dict() + else: + state_dict_fn = getattr(dataset, "state_dict", None) + info = get_worker_info() + worker_id = info.id if info else 0 + if state_dict_fn is not None: + metadata = state_dict_fn() + if worker_id not in metadata: + if info and info.num_workers > 1: + raise MisconfigurationException( + f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys." + ) + metadata = {0: metadata} + if metadata is None: + metadata = {worker_id: {}} + return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata} @@ -480,6 +503,9 @@ def patch_dataloader_iterator( will extract the current iteration as part of the metadata returned by a custom batch. """ + if not _FaultTolerantMode.detect_current_mode().is_automatic: + return + assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset)) def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable: @@ -527,8 +553,14 @@ def wrapper(): def _add_capture_metadata_collate(dataloader: DataLoader) -> None: """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled.""" + faut_tolerant_mode = _FaultTolerantMode.detect_current_mode() + if not faut_tolerant_mode.is_enabled: + return dataloader.collate_fn = partial( - _capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn + _capture_metadata_collate, + dataset=dataloader.dataset, + collate_fn=dataloader.collate_fn, + fault_tolerant_mode=faut_tolerant_mode, ) @@ -589,3 +621,106 @@ def state_dict(self) -> Dict[str, Any]: def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... + + +class _StatefulDataLoaderIter: + """This mixin is used to make PyTorch DataLoaderIter stateful.""" + + def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: + # store sampler state within a queue alongside its idx. + self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1 + self._sampler_state.append((sampler_state, self._sampler_state_idx)) + + def _store_sampler_state(self) -> None: + """This function is used to extract the sampler states if any.""" + sampler_state = { + k: v.state_dict() + for k, v in self._loader.__dict__.items() + if isinstance(v, _SupportsStateDict) and k != "dataset" + } + + self.__accumulate_state(sampler_state) + + def _next_index(self) -> Any: + indexes = super()._next_index() + self._store_sampler_state() + return indexes + + def _prepare_loader(self, loader): + if not isinstance(loader.collate_fn, partial): + loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn) + self._loader = loader + self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher + self.num_batches_fetched = 0 + self._sampler_state = [] + self._sampler_state_idx = 0 + + def __del__(self) -> None: + if isinstance(self._loader.collate_fn, partial): + self._loader.collate_fn = self._loader.collate_fn.keywords["collate_fn"] + + def _next_data(self) -> Any: + combined_batch = super()._next_data() + + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] + + self.num_batches_fetched += 1 + + sampler_state, sampler_state_idx = self._sampler_state.pop(0) + # there is no workers within the samplers + worker_id = list(state.keys())[0] + + state = [ + IteratorState( + num_workers=self._loader.num_workers, + sampler_state=sampler_state, + dataset_state=state, + worker_id=worker_id, + num_batches_fetched=self.num_batches_fetched, + ) + ] + # ensures there is an alignement between the sampler state and currently fetched batch + assert sampler_state_idx == self.num_batches_fetched + self._data_fetcher._store_dataloader_iter_state(self, state) + return batch + + +class _SingleProcessDataLoaderIterStateful(_StatefulDataLoaderIter, _SingleProcessDataLoaderIter): + def __init__(self, loader: DataLoader): + self._prepare_loader(loader) + super().__init__(loader) + + +class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProcessingDataLoaderIter): + def __init__(self, loader: DataLoader): + self._prepare_loader(loader) + super().__init__(loader) + + +def _get_iterator(self) -> "_BaseDataLoaderIter": + if not hasattr(self, "_lightning_fetcher"): + raise MisconfigurationException( + "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader." + ) + if self.num_workers == 0: + return _SingleProcessDataLoaderIterStateful(self) + else: + if hasattr(self, "check_worker_number_rationality"): + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIterStateful(self) + + +def _patch_dataloader_get_iterators() -> None: + """This function is used to replace the DataLoader iterator by their stateful version.""" + if not hasattr(DataLoader, "_ori_get_iterator"): + DataLoader._ori_get_iterator = DataLoader._get_iterator + DataLoader._get_iterator = _get_iterator + + +def _teardown_dataloader_get_iterators() -> None: + """This function is used to restore the DataLoader `get_iterator` with its original one.""" + # cleanup the get_iterator replacement in case of Fault Tolerant Training. + get_iterator = getattr(DataLoader, "_ori_get_iterator", None) + if get_iterator: + DataLoader._get_iterator = get_iterator + del DataLoader._ori_get_iterator diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 9b80d2f9874c7..f5bb4be032d10 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -99,6 +99,8 @@ def setup( if self.profiler is not None and stage is None: raise MisconfigurationException("When providing a profiler, the stage should be provided too.") + self._attach_data_fetcher() + @staticmethod def _add_capture_metadata_collate(dataloader: Iterable) -> None: if not isinstance(dataloader, (DataLoader, CombinedLoader)): @@ -190,6 +192,16 @@ def collect_state(iterator: Iterator): return apply_to_collection(self.loader_iters, Iterator, collect_state) + def _attach_data_fetcher(self): + def _attach_data_fetcher_fn(loader: DataLoader): + if isinstance(loader, CycleIterator): + loader = loader.loader + + if isinstance(loader, DataLoader) and _fault_tolerant_training(): + loader._lightning_fetcher = self + + apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn) + def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: if self.dataloader is None: raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.") diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d9063f90db377..25f80ec6817a5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,8 +40,12 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _MultiProcessingDataLoaderIterStateful, + _patch_dataloader_get_iterators, _rotate_worker_indices, + _SingleProcessDataLoaderIterStateful, _SupportsStateDict, + _teardown_dataloader_get_iterators, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -245,8 +249,10 @@ def __next__(self): return self.data[next(iter_sampler)] +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI") @pytest.mark.parametrize("num_workers", [0, 1, 2]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_fast_forward_sampler_over_iterable_dataset(num_workers): """This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being used to capture workers states.""" @@ -626,11 +632,13 @@ def all_gather(tensor, world_size): assert torch.equal(t, tr) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 45 sec and should be skipped in Azure CI") def test_fast_forward_sampler_iterative_dataset(): _test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(0, 1) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 55 sec and should be skipped in Azure CI") @RunIf(skip_windows=True) def test_fast_forward_sampler_with_distributed_sampler_and_iterative_dataset(): @@ -1251,3 +1259,95 @@ def test_fault_tolerant_mode_enum(): ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): _FaultTolerantMode.detect_current_mode() + + +class StatefulRandomSampler(RandomSampler): + + counter = 0 + + def state_dict(self): + self.counter += 1 + return {"counter": self.counter} + + def load_state_dict(self, state_dict): + self.counter = state_dict["counter"] + + +class StatefulRandomDataset(RandomDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = 0 + + def __getitem__(self, index): + self.counter += 1 + return super().__getitem__(index) + + def state_dict(self): + info = get_worker_info() + if info: + return {info.id: {"counter": self.counter}} + return {"counter": self.counter} + + def load_state_dict(self, state_dict): + self.counter = state_dict["counter"] + + +@pytest.mark.parametrize("num_workers", [0]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) +def test_stateful_workers(num_workers): + + seed_everything(42) + + _get_iterator_fn = DataLoader._get_iterator + _patch_dataloader_get_iterators() + assert DataLoader._ori_get_iterator is not None + + data_fetcher = DataFetcher() + dataset = StatefulRandomDataset(1, 64) + dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers) + + with pytest.raises(MisconfigurationException, match="A stateful iterator should be used"): + iter(dataloader) + + # This would attach the `data_fetcher` to the DataLoader. + data_fetcher.setup(dataloader) + + data_fetcher_iter = iter(data_fetcher) + + dataloader_iter = data_fetcher.dataloader_iter + worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful + assert isinstance(dataloader_iter, worker_type) + + next(data_fetcher_iter) + state = data_fetcher.dataloader_iter.state.state + assert state[0].dataset_state == {0: {"counter": 1}} + assert state[0].sampler_state["sampler"] == {"counter": 1} + + next(data_fetcher_iter) + previous_state = data_fetcher.dataloader_iter.previous_state.state + state = data_fetcher.dataloader_iter.state.state + assert previous_state[0].dataset_state == {0: {"counter": 1}} + assert previous_state[0].sampler_state["sampler"] == {"counter": 1} + # TODO: Resolve the previous `sampler_state` associated to `worker_id: 0`. + worker_id = 1 if num_workers else 0 + assert state[worker_id].sampler_state["sampler"] == {"counter": 2} + + # each worker has its own copy of the dataset + assert state[0].dataset_state == ({0: {"counter": 2}} if num_workers == 0 else {0: {"counter": 1}}) + target_previous_state = deepcopy(state) + + next(data_fetcher_iter) + latest_worker_id = data_fetcher.dataloader_iter.state.latest_worker_id + assert latest_worker_id == 0 + previous_state = data_fetcher.dataloader_iter.previous_state.state + state = data_fetcher.dataloader_iter.state.state + + assert target_previous_state == previous_state + assert state[0].sampler_state["sampler"] == {"counter": 3} + assert state[0].dataset_state == ({0: {"counter": 3}} if num_workers == 0 else {0: {"counter": 2}}) + + _teardown_dataloader_get_iterators() + assert not hasattr(DataLoader, "_ori_get_iterator") + assert DataLoader._get_iterator == _get_iterator_fn + + data_fetcher.teardown()