From e94aff1c5bed2c616143dc694343c71a2be0b1bd Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 26 Nov 2021 19:33:47 +0000
Subject: [PATCH] Fault Tolerant: Add support for fault tolerant dataloader validator (#10465)

---
 CHANGELOG.md                                |   3 +
 pytorch_lightning/trainer/data_loading.py   |   3 +-
 pytorch_lightning/utilities/auto_restart.py |  96 ++++++++++++++++-
 tests/utilities/test_auto_restart.py        | 112 +++++++++++++++++++-
 4 files changed, 208 insertions(+), 6 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f52369b443164..264e66e278b6e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
 
+- Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465))
+
+
 - Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))
 
 
diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py
index bfba0229660a6..455c2719b124a 100644
--- a/pytorch_lightning/trainer/data_loading.py
+++ b/pytorch_lightning/trainer/data_loading.py
@@ -28,7 +28,7 @@
 from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection
-from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
+from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
     _replace_dataloader_init_method,
@@ -441,6 +441,7 @@ def request_dataloader(
         if isinstance(dataloader, tuple):
             dataloader = list(dataloader)
         self.training_type_plugin.barrier("get_dataloaders")
+        _validate_fault_tolerant_automatic(dataloader, stage)
         return dataloader
 
     @staticmethod
diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py
index 84f0c9decefea..9d26f4a6e0736 100644
--- a/pytorch_lightning/utilities/auto_restart.py
+++ b/pytorch_lightning/utilities/auto_restart.py
@@ -16,11 +16,19 @@
 from functools import partial, wraps
 from random import getstate as python_get_rng_state
 from random import setstate as python_set_rng_state
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from torch.utils.data import Dataset, get_worker_info, Sampler
+from torch.utils.data import (
+    BatchSampler,
+    Dataset,
+    DistributedSampler,
+    get_worker_info,
+    RandomSampler,
+    Sampler,
+    SequentialSampler,
+)
 from torch.utils.data.dataloader import (
     _BaseDataLoaderIter,
     _MultiProcessingDataLoaderIter,
@@ -370,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
     # get current num workers
     num_workers = getattr(iter_dataloader, "_num_workers", 0)
     # as `state_dict` are workers dependent, Lightning doesn't support changing
-    # the `num_workers` for fault tolerant training
+    # the `num_workers` for fault tolerance
     if state_dict["num_workers"] != num_workers:
         raise MisconfigurationException(
             f"The provided `num_workers` {num_workers} doesn't match the one used "
@@ -734,13 +742,93 @@ def _patch_dataloader_get_iterators() -> None:
 
 
 def _teardown_dataloader_get_iterators() -> None:
     """This function is used to restore the DataLoader `get_iterator` with its original one."""
-    # cleanup the get_iterator replacement in case of Fault Tolerant Training.
+    # clean up the `_get_iterator` replacement used for fault tolerance.
     get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
     if get_iterator:
         DataLoader._get_iterator = get_iterator
         del DataLoader._ori_get_iterator
 
 
+def _validate_iterable_dataset(dataloader: DataLoader) -> None:
+    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)
+
+    dataset = dataloader.dataset
+
+    if getattr(dataset, "__next__", None) is None:
+        raise AttributeError(
+            "Fault-tolerance doesn't support an `IterableDataset` without `__next__` "
+            "method implemented. Hint: We recommend you move your logic from `__iter__`"
+            " into `__next__` and rely on a sampler to perform the sampling."
+        )
+
+    samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}
+
+    if not samplers:
+        raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")
+
+    sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]
+
+    if not sampler:
+        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
+
+    if len(sampler) > 1:
+        raise ValueError(f"Only a single sampler is supported within an `IterableDataset`. Found {sampler}.")
+
+    if type(sampler[0]) is DistributedSampler and sampler[0].shuffle:
+        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
+    elif type(sampler[0]) is not SequentialSampler:
+        raise TypeError("Only `SequentialSampler` is supported.")
+
+
+def _validate_map_dataset(dataloader: DataLoader) -> None:
+    SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)
+
+    sampler = getattr(dataloader, "sampler", None)
+    if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
+        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
+
+    batch_sampler = getattr(dataloader, "batch_sampler", None)
+    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
+        raise TypeError("Fault-tolerance supports only a `BatchSampler`.")
+
+    if type(sampler) is DistributedSampler and sampler.shuffle:
+        raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
+    elif type(sampler) is RandomSampler:
+        raise TypeError("Only `SequentialSampler` is supported.")
+
+
+def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
+    """This function is used to validate that fault-tolerant training is possible with the user's data."""
+    if not _FaultTolerantMode.detect_current_mode().is_automatic:
+        return
+
+    from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
+
+    if isinstance(dataloader, CombinedLoader):
+        dataloaders = dataloader.loaders
+    else:
+        dataloaders = dataloader
+
+    dl_loaders = []
+
+    def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
+        nonlocal dl_loaders
+        if isinstance(dataloader, CycleIterator):
+            dataloader = dataloader.loader
+        dl_loaders.append(dataloader)
+
+    apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)
+
+    if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
+        raise ValueError("Fault-tolerance supports only a single dataloader.")
+
+    for dataloader in dl_loaders:
+        validator_fn = (
+            _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
+        )
+        validator_fn(dataloader)
+
+
 def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
     """This utility collects the state across processes for a collection of state."""
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index 1a479af05aa3f..4c2c440797dd2 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -34,10 +34,12 @@
 from torch.utils.data._utils.worker import get_worker_info
 from torch.utils.data.dataloader import DataLoader, default_collate
 from torch.utils.data.dataset import Dataset, IterableDataset
+from torch.utils.data.sampler import Sampler
 
 import tests.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
-from pytorch_lightning.trainer.states import TrainerState
+from pytorch_lightning.trainer.states import RunningStage, TrainerState
+from pytorch_lightning.trainer.supporters import CombinedLoader
 from pytorch_lightning.utilities.auto_restart import (
     _add_capture_metadata_collate,
     _collect_states_on_rank_zero_over_collection,
@@ -48,6 +50,7 @@
     _SingleProcessDataLoaderIterStateful,
     _SupportsStateDict,
     _teardown_dataloader_get_iterators,
+    _validate_fault_tolerant_automatic,
     CaptureIterableDataset,
     CaptureMapDataset,
     FastForwardSampler,
@@ -665,6 +668,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w
     return dataset
 
 
+@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
 @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
 def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
     """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""
@@ -893,6 +897,10 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_
     return model.seen_batches, model.parameters()
 
 
+# This test would fail because fault-tolerant training doesn't support multiple datasets,
+# so the validator is mocked out below. The test works as the dataset is fully deterministic
+# and therefore there is no overlap between the seeds.
+@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 @pytest.mark.parametrize(
     "dataset_classes",
@@ -1180,6 +1188,108 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
     assert "dataloader_state_dict" in state_dict
 
 
+@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
+def test_validate_fault_tolerant(tmpdir):
+    def data():
+        return range(10)
+
+    def dataloader():
+        return DataLoader(data())
+
+    _validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING)
+
+    dataloaders = CombinedLoader([dataloader(), dataloader()])
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle")
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [dataloader(), dataloader()]
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
+
+    dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))]
+    with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))]
+    _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataset = SequentialGetItemDataset(2)
+    dataloaders = [
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
+    ]
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)),
+        DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
+    ]
+    with pytest.raises(ValueError, match="Fault-tolerance supports only a single."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)
+
+    dataloaders = [
+        DataLoader(dataset, sampler=RandomSampler(dataset)),
+        DataLoader(dataset, sampler=SequentialSampler(dataset)),
+    ]
+
+    with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
+
+    class CustomRandomSampler(RandomSampler):
+        pass
+
+    dl = DataLoader(data(), sampler=CustomRandomSampler(data()))
+    with pytest.raises(TypeError, match="RandomSampler"):
+        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
+
+    class CustomBatchSampler(BatchSampler):
+        pass
+
+    sampler = Sampler(data())
+    batch_sampler = CustomBatchSampler(sampler, 2, False)
+    dl = DataLoader(data(), batch_sampler=batch_sampler)
+    with pytest.raises(TypeError, match="BatchSampler"):
+        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
+
+    class CustomIterable(IterableDataset):
+        pass
+
+    iterable_dataloader = DataLoader(CustomIterable())
+    with pytest.raises(AttributeError, match="without `__next__` method"):
+        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)
+
+    class CustomIterable(IterableDataset):
+        def __next__(self):
+            return torch.tensor(0)
+
+    iterable_dataloader = DataLoader(CustomIterable())
+    with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"):
+        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)
+
+    class CustomIterable(IterableDataset):
+        def __init__(self):
+            super().__init__()
+            self.sampler = CustomRandomSampler(data())
+
+        def __next__(self):
+            return torch.tensor(0)
+
+    iterable_dataloader = DataLoader(CustomIterable())
+    with pytest.raises(TypeError, match="RandomSampler"):
+        _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)
+
+    dataloaders = [iterable_dataloader, DataLoader(CustomIterable())]
+    with pytest.raises(TypeError, match="RandomSampler"):
+        _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
+
+
 def test_rotate_worker_indices():
     """This test ensures `worker_id` are rotated properly depending on which one was the latest."""
     state_dict = {0: 0, 1: 1}
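
For illustration only (not part of the patch itself), here is a minimal sketch of how the new validator is expected to behave when automatic fault tolerance is enabled; it mirrors the `test_validate_fault_tolerant` cases above and assumes `PL_FAULT_TOLERANT_TRAINING=1` is set in the environment:

import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # enable automatic fault-tolerant mode

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

# A single map-style dataloader with the default `SequentialSampler` passes validation.
_validate_fault_tolerant_automatic(DataLoader(range(10)), RunningStage.TRAINING)

# Multiple dataloaders are rejected for the training stage.
try:
    _validate_fault_tolerant_automatic(
        [DataLoader(range(10)), DataLoader(range(10))], RunningStage.TRAINING
    )
except ValueError as err:
    print(err)  # Fault-tolerance supports only a single dataloader.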