From 1b3eb63ac19a03f15e6921a2de199ffba3dfd7bb Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 10 Nov 2021 17:50:21 +0000 Subject: [PATCH 01/22] update --- .../trainer/connectors/data_connector.py | 2 ++ pytorch_lightning/utilities/auto_restart.py | 11 ++++++++++- tests/utilities/test_auto_restart.py | 10 ++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 90c398087578d..2ab98bc6f8c49 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -19,6 +19,7 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_training from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( AbstractDataFetcher, @@ -119,6 +120,7 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: return DataFetcher() def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: + _validate_fault_tolerant_training(self.trainer, dataloader) stage: str = self.trainer.state.stage.value data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() data_fetcher.setup( diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index f0b50103cf2f2..7b2f293796dd5 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -17,7 +17,7 @@ from functools import partial, wraps from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import torch @@ -588,3 +588,12 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") + + +def _validate_fault_tolerant_training(trainer: "pl.Trainer", dataloader: Iterable) -> None: + from pytorch_lightning.trainer.supporters import CombinedLoader + + if not _fault_tolerant_training(): + return + if isinstance(dataloader, CombinedLoader): + pass diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 4e3385cebecbc..65847f5f02cec 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,6 +39,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _validate_fault_tolerant_training, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1196,3 +1197,12 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +@pytest.mark.parametrize("val_check_interval", [0.5, 1.0]) +def test_validate_fault_tolerant(val_check_interval, tmpdir): + + trainer = Trainer(default_root_dir=tmpdir, max_epohs=1, val_check_interval=val_check_interval) + + dataloaders = DataLoader(range(10)) + _validate_fault_tolerant_training(trainer, dataloaders) From a58c5f0b0fe2cfbd5a7f1d9bf00b8919e4a6777c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 10 Nov 2021 19:59:52 +0000 Subject: [PATCH 02/22] update --- pytorch_lightning/trainer/data_loading.py | 2 + pytorch_lightning/utilities/auto_restart.py | 72 +++++++++++++++- tests/utilities/test_auto_restart.py | 96 +++++++++++++++++++-- 3 files changed, 161 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 37a234f32f711..9bc7b066fcd89 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -33,6 +33,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection from pytorch_lightning.utilities.auto_restart import ( _capture_metadata_collate, + _validate_fault_tolerant_training, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -579,6 +580,7 @@ def request_dataloader( if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") + _validate_fault_tolerant_training(self, dataloader, stage) return dataloader @staticmethod diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 7b2f293796dd5..f2cbbbc570633 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -17,14 +17,18 @@ from functools import partial, wraps from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Sequence, Tuple, Union import numpy as np import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from torch.utils.data.distributed import DistributedSampler +from torch.utils.data.sampler import BatchSampler, RandomSampler, SequentialSampler import pytorch_lightning as pl +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.utilities.apply_func import apply_to_collection from 
pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -590,10 +594,70 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _validate_fault_tolerant_training(trainer: "pl.Trainer", dataloader: Iterable) -> None: - from pytorch_lightning.trainer.supporters import CombinedLoader +def _validate_fault_tolerant_training(trainer: "pl.Trainer", dataloader: Iterable, stage: RunningStage) -> None: + """This function is used to validate fault tolerant training is possible with the user data.""" + from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator if not _fault_tolerant_training(): return + + if trainer.val_check_interval != 1.0: + raise MisconfigurationException( + "Fault Tolerant Training isn't support for `val_check_interval` different than 1.0." + ) + if isinstance(dataloader, CombinedLoader): - pass + count_dataloader = 0 + + def increment_count_dataloader(dataloader: Union[DataLoader, CycleIterator]) -> None: + nonlocal count_dataloader + count_dataloader += 1 + + apply_to_collection(dataloader.loaders, (DataLoader, CycleIterator), increment_count_dataloader) + if count_dataloader > 1: + raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") + + dataloader = dataloader.loaders + + elif isinstance(dataloader, Sequence): + if len(dataloader) > 1 and stage == RunningStage.TRAINING: + raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") + + else: + dataloader = [dataloader] + + supported_samplers = (RandomSampler, SequentialSampler, DistributedSampler) + + for dataloader in dataloader: + + if isinstance(dataloader.dataset, IterableDataset): + dataset = dataloader.dataset + + next_fn = getattr(dataset, "__next__", None) + if not next_fn: + raise MisconfigurationException( + "Fault Tolerant Training doesn't support an IterableDataset without a `__next__`" + " method implemented. Hint: We recommend you to move your logic inside and rely " + "on a sampler, generator to perform the iteration." + ) + + samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)} + + if not samplers: + raise MisconfigurationException( + "Fault Tolerant Training doesn't support an IterableDataset without a sampler as attribute." 
+ ) + + for v in samplers.keys(): + if v.__class__ not in supported_samplers: + raise MisconfigurationException(f"Fault Tolerant Training supports only {supported_samplers}.") + + else: + supported_samplers = (RandomSampler, SequentialSampler, DistributedSampler) + sampler = getattr(dataloader, "sampler", None) + if sampler is not None and sampler.__class__ not in supported_samplers: + raise MisconfigurationException(f"Fault Tolerant Training supports only {supported_samplers}.") + + batch_sampler = getattr(dataloader, "batch_sampler", None) + if batch_sampler is not None and batch_sampler.__class__ is not BatchSampler: + raise MisconfigurationException("Fault Tolerant Training supports only a BatchSampler.") diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 65847f5f02cec..3250f7234ed59 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -32,9 +32,12 @@ from torch.utils.data._utils.worker import get_worker_info from torch.utils.data.dataloader import DataLoader, default_collate from torch.utils.data.dataset import Dataset, IterableDataset +from torch.utils.data.sampler import Sampler import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from pytorch_lightning.trainer.states import RunningStage +from pytorch_lightning.trainer.supporters import CombinedLoader from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, @@ -1199,10 +1202,93 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict -@pytest.mark.parametrize("val_check_interval", [0.5, 1.0]) -def test_validate_fault_tolerant(val_check_interval, tmpdir): +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_validate_fault_tolerant(tmpdir): + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5) + data = range(10) + dataloader = DataLoader(data) + + with pytest.raises( + MisconfigurationException, + match="Fault Tolerant Training isn't support for `val_check_interval` different than 1.0.", + ): + _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + + trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) + _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + + with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))]) + _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.TRAINING) + + with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle") + _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.TRAINING) + + dataloaders = [DataLoader(data), DataLoader(range(10))] + with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.TRAINING) + + _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.VALIDATING) + + with pytest.raises(MisconfigurationException, match="RandomSampler"): + + class CustomRandomSampler(RandomSampler): + pass + + dataloader = DataLoader(data, sampler=CustomRandomSampler(data)) + 
_validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + + with pytest.raises(MisconfigurationException, match="BatchSampler"): + + class CustomBatchSampler(BatchSampler): + pass + + sampler = Sampler(data) + batch_sampler = CustomBatchSampler(sampler, 2, False) + dataloader = DataLoader(data, batch_sampler=batch_sampler) + _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + + with pytest.raises(MisconfigurationException, match="without a `__next__` method"): + + class CustomIterable(IterableDataset): + def __iter__(self): + while True: + yield 0 + + dataloader = DataLoader(CustomIterable()) + _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + + with pytest.raises(MisconfigurationException, match="IterableDataset without a sampler as attribute"): + + class CustomIterable(IterableDataset): + def __iter__(self): + return self + + def __next__(self): + return torch.tensor(0) + + dataloader = DataLoader(CustomIterable()) + _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + + with pytest.raises(MisconfigurationException, match="RandomSampler"): + + class CustomIterable(IterableDataset): + def __init__(self): + super().__init__() + self.data = data + self.sampler = CustomRandomSampler(self.data) + + def __iter__(self): + return self + + def __next__(self): + return torch.tensor(0) - trainer = Trainer(default_root_dir=tmpdir, max_epohs=1, val_check_interval=val_check_interval) + dataloader = DataLoader(CustomIterable()) + _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) - dataloaders = DataLoader(range(10)) - _validate_fault_tolerant_training(trainer, dataloaders) + dataloaders = [DataLoader(data), DataLoader(CustomIterable())] + with pytest.raises(MisconfigurationException, match="RandomSampler"): + _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.VALIDATING) From 740e7c56ff240d8044631c080629f5df9388c53f Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 10 Nov 2021 20:02:29 +0000 Subject: [PATCH 03/22] update --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 495c9e2398df0..41c58e788da97 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,7 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added -- +- Added a function to validate if fault tolerant training is supported. 
([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465)) + - From d5c7912412ce87d36ec5d1c9adf315f7334403b8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 10 Nov 2021 20:04:15 +0000 Subject: [PATCH 04/22] update --- pytorch_lightning/trainer/connectors/data_connector.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/data_connector.py b/pytorch_lightning/trainer/connectors/data_connector.py index 2ab98bc6f8c49..90c398087578d 100644 --- a/pytorch_lightning/trainer/connectors/data_connector.py +++ b/pytorch_lightning/trainer/connectors/data_connector.py @@ -19,7 +19,6 @@ import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_deprecation -from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_training from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.fetching import ( AbstractDataFetcher, @@ -120,7 +119,6 @@ def _select_data_fetcher(self) -> AbstractDataFetcher: return DataFetcher() def get_profiled_dataloader(self, dataloader: Iterable, dataloader_idx: int = 0) -> Iterable: - _validate_fault_tolerant_training(self.trainer, dataloader) stage: str = self.trainer.state.stage.value data_fetcher = setattr(self, f"{stage}_data_fetcher", None) or self._select_data_fetcher() data_fetcher.setup( From 9d6d9a59f63192158a6139dfd117b7d51e3bb3f7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 10 Nov 2021 20:06:40 +0000 Subject: [PATCH 05/22] update --- pytorch_lightning/utilities/auto_restart.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index f2cbbbc570633..8a0c6200546a6 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -623,12 +623,11 @@ def increment_count_dataloader(dataloader: Union[DataLoader, CycleIterator]) -> if len(dataloader) > 1 and stage == RunningStage.TRAINING: raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") - else: - dataloader = [dataloader] - supported_samplers = (RandomSampler, SequentialSampler, DistributedSampler) - for dataloader in dataloader: + dataloaders = dataloader if isinstance(dataloader, Sequence) else [dataloader] + + for dataloader in dataloaders: if isinstance(dataloader.dataset, IterableDataset): dataset = dataloader.dataset @@ -653,7 +652,6 @@ def increment_count_dataloader(dataloader: Union[DataLoader, CycleIterator]) -> raise MisconfigurationException(f"Fault Tolerant Training supports only {supported_samplers}.") else: - supported_samplers = (RandomSampler, SequentialSampler, DistributedSampler) sampler = getattr(dataloader, "sampler", None) if sampler is not None and sampler.__class__ not in supported_samplers: raise MisconfigurationException(f"Fault Tolerant Training supports only {supported_samplers}.") From 9a63f144d1ffb7044e38810a75d99ad666b56bf3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Nov 2021 08:27:55 +0000 Subject: [PATCH 06/22] update --- pytorch_lightning/trainer/data_loading.py | 2 +- pytorch_lightning/utilities/auto_restart.py | 7 +---- tests/utilities/test_auto_restart.py | 30 ++++++++------------- 3 files changed, 13 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 9bc7b066fcd89..727ea5f6d3617 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ 
b/pytorch_lightning/trainer/data_loading.py @@ -580,7 +580,7 @@ def request_dataloader( if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") - _validate_fault_tolerant_training(self, dataloader, stage) + _validate_fault_tolerant_training(dataloader, stage) return dataloader @staticmethod diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 8a0c6200546a6..edfb0e85943f3 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -594,18 +594,13 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _validate_fault_tolerant_training(trainer: "pl.Trainer", dataloader: Iterable, stage: RunningStage) -> None: +def _validate_fault_tolerant_training(dataloader: Iterable, stage: RunningStage) -> None: """This function is used to validate fault tolerant training is possible with the user data.""" from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator if not _fault_tolerant_training(): return - if trainer.val_check_interval != 1.0: - raise MisconfigurationException( - "Fault Tolerant Training isn't support for `val_check_interval` different than 1.0." - ) - if isinstance(dataloader, CombinedLoader): count_dataloader = 0 diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 3250f7234ed59..0b98536241257 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1205,32 +1205,24 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_validate_fault_tolerant(tmpdir): - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5) data = range(10) dataloader = DataLoader(data) - with pytest.raises( - MisconfigurationException, - match="Fault Tolerant Training isn't support for `val_check_interval` different than 1.0.", - ): - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) - - trainer = Trainer(default_root_dir=tmpdir, max_epochs=1) - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))]) - _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle") - _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(range(10))] with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): - _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) - _validate_fault_tolerant_training(trainer, dataloaders, 
RunningStage.VALIDATING) + _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) with pytest.raises(MisconfigurationException, match="RandomSampler"): @@ -1238,7 +1230,7 @@ class CustomRandomSampler(RandomSampler): pass dataloader = DataLoader(data, sampler=CustomRandomSampler(data)) - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="BatchSampler"): @@ -1248,7 +1240,7 @@ class CustomBatchSampler(BatchSampler): sampler = Sampler(data) batch_sampler = CustomBatchSampler(sampler, 2, False) dataloader = DataLoader(data, batch_sampler=batch_sampler) - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="without a `__next__` method"): @@ -1258,7 +1250,7 @@ def __iter__(self): yield 0 dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="IterableDataset without a sampler as attribute"): @@ -1270,7 +1262,7 @@ def __next__(self): return torch.tensor(0) dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="RandomSampler"): @@ -1287,8 +1279,8 @@ def __next__(self): return torch.tensor(0) dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_training(trainer, dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(CustomIterable())] with pytest.raises(MisconfigurationException, match="RandomSampler"): - _validate_fault_tolerant_training(trainer, dataloaders, RunningStage.VALIDATING) + _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) From 57760fa310fdb2a2af24b01454cc247425cd368a Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 11 Nov 2021 13:00:51 +0000 Subject: [PATCH 07/22] update --- pytorch_lightning/trainer/data_loading.py | 6 +- pytorch_lightning/utilities/auto_restart.py | 121 ++++++++++++-------- tests/utilities/test_auto_restart.py | 29 +++++ 3 files changed, 110 insertions(+), 46 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 727ea5f6d3617..a0aef45017bf1 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -580,9 +580,13 @@ def request_dataloader( if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") - _validate_fault_tolerant_training(dataloader, stage) + self._validate_fault_tolerant_training(dataloader, stage) return dataloader + @staticmethod + def _validate_fault_tolerant_training(dataloader: Any, stage: RunningStage) -> None: + _validate_fault_tolerant_training(dataloader, stage) + @staticmethod def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: """Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 
edfb0e85943f3..d982833d9fc8e 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -17,7 +17,7 @@ from functools import partial, wraps from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union import numpy as np import torch @@ -594,63 +594,94 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _validate_fault_tolerant_training(dataloader: Iterable, stage: RunningStage) -> None: - """This function is used to validate fault tolerant training is possible with the user data.""" - from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator +def _validate_iterable_dataset(dataloader: DataLoader) -> bool: - if not _fault_tolerant_training(): - return + SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) - if isinstance(dataloader, CombinedLoader): - count_dataloader = 0 + dataset = dataloader.dataset - def increment_count_dataloader(dataloader: Union[DataLoader, CycleIterator]) -> None: - nonlocal count_dataloader - count_dataloader += 1 + next_fn = getattr(dataset, "__next__", None) + if not next_fn: + raise MisconfigurationException( + "Fault Tolerant Training doesn't support an IterableDataset without a `__next__`" + " method implemented. Hint: We recommend you to move your logic inside and rely " + "on a sampler, generator to perform the iteration." + ) - apply_to_collection(dataloader.loaders, (DataLoader, CycleIterator), increment_count_dataloader) - if count_dataloader > 1: - raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") + samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)} - dataloader = dataloader.loaders + if not samplers: + raise MisconfigurationException( + "Fault Tolerant Training doesn't support an IterableDataset without a sampler as attribute." + ) - elif isinstance(dataloader, Sequence): - if len(dataloader) > 1 and stage == RunningStage.TRAINING: - raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") + sampler = [v for v in samplers.values() if v.__class__ in SUPPORTED_SAMPLERS] - supported_samplers = (RandomSampler, SequentialSampler, DistributedSampler) + if not sampler: + raise MisconfigurationException(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.") - dataloaders = dataloader if isinstance(dataloader, Sequence) else [dataloader] + if len(sampler) > 1: + raise MisconfigurationException(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.") - for dataloader in dataloaders: + if sampler[0].__class__ == DistributedSampler: + return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True." - if isinstance(dataloader.dataset, IterableDataset): - dataset = dataloader.dataset + return sampler[0].__class__ == SequentialSampler, "" - next_fn = getattr(dataset, "__next__", None) - if not next_fn: - raise MisconfigurationException( - "Fault Tolerant Training doesn't support an IterableDataset without a `__next__`" - " method implemented. 
Hint: We recommend you to move your logic inside and rely " - "on a sampler, generator to perform the iteration." - ) - samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)} +def _validate_map_dataset(dataloader: DataLoader) -> bool: + SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) - if not samplers: - raise MisconfigurationException( - "Fault Tolerant Training doesn't support an IterableDataset without a sampler as attribute." - ) + sampler = getattr(dataloader, "sampler", None) + if sampler is not None and sampler.__class__ not in SUPPORTED_SAMPLERS: + raise MisconfigurationException(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.") - for v in samplers.keys(): - if v.__class__ not in supported_samplers: - raise MisconfigurationException(f"Fault Tolerant Training supports only {supported_samplers}.") + batch_sampler = getattr(dataloader, "batch_sampler", None) + if batch_sampler is not None and batch_sampler.__class__ is not BatchSampler: + raise MisconfigurationException("Fault Tolerant Training supports only a BatchSampler.") - else: - sampler = getattr(dataloader, "sampler", None) - if sampler is not None and sampler.__class__ not in supported_samplers: - raise MisconfigurationException(f"Fault Tolerant Training supports only {supported_samplers}.") + if sampler.__class__ == DistributedSampler: + return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True." + + return sampler.__class__ == SequentialSampler, "" + + +def _validate_fault_tolerant_training(dataloader: Iterable, stage: RunningStage) -> None: + """This function is used to validate fault tolerant training is possible with the user data.""" + from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator + + if not _fault_tolerant_training(): + return + + if isinstance(dataloader, CombinedLoader): + dataloaders = dataloader.loaders + else: + dataloaders = dataloader + + dl_loaders = [] + + def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None: + nonlocal dl_loaders + if isinstance(dataloader, CycleIterator): + dataloader = dataloader.loader + dl_loaders.append(dataloader) + + apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader) + + if len(dl_loaders) > 1 and stage == RunningStage.TRAINING: + if not all(getattr(dl.dataset, "deterministic", False) for dl in dl_loaders): + raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") + + supported = [] + messages = [] + + for dataloader in dl_loaders: + validator_fn = ( + _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset + ) + is_supported, message = validator_fn(dataloader) + supported.append(is_supported) + messages.append(message) - batch_sampler = getattr(dataloader, "batch_sampler", None) - if batch_sampler is not None and batch_sampler.__class__ is not BatchSampler: - raise MisconfigurationException("Fault Tolerant Training supports only a BatchSampler.") + if len(dl_loaders) > 1 and sum(supported) != len(dl_loaders): + raise MisconfigurationException(f"The current combinaison of DataLoaders isn't supported. 
Messages: {messages}") diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 0b98536241257..d31cf93330236 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -755,10 +755,14 @@ def on_train_batch_start(self, trainer, *_) -> None: model = TestModel() model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) + trainer._validate_fault_tolerant_training = lambda x, y: None trainer.fit(model) class SequentialGetItemDataset(Dataset): + + deterministic = True + def __init__(self, length, *_): self.len = length @@ -873,6 +877,9 @@ class CustomException(Exception): class SequentialIterableDataset(IterableDataset): + + deterministic = True + def __init__(self, length, *_): self.len = length self.sampler = SequentialSampler(range(self.len)) @@ -1224,6 +1231,28 @@ def test_validate_fault_tolerant(tmpdir): _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) + dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))] + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + + dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))] + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + + dataset = SequentialGetItemDataset(2) + dataloaders = [ + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), + ] + + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + + dataloaders = [ + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)), + DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), + ] + + with pytest.raises(MisconfigurationException, match="The current combinaison of DataLoaders isn't supported."): + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + with pytest.raises(MisconfigurationException, match="RandomSampler"): class CustomRandomSampler(RandomSampler): From a7f4d0a1913a53b900fccb068bcc96f8ee86294a Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 15 Nov 2021 17:34:24 +0000 Subject: [PATCH 08/22] update --- pytorch_lightning/utilities/auto_restart.py | 16 ++++++++-------- tests/utilities/test_auto_restart.py | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index d982833d9fc8e..b0d683e2ca508 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -594,17 +594,16 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _validate_iterable_dataset(dataloader: DataLoader) -> bool: +def _validate_iterable_dataset(dataloader: DataLoader) -> Tuple[bool, str]: SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) dataset = dataloader.dataset - next_fn = getattr(dataset, "__next__", None) - if not next_fn: + if getattr(dataset, "__next__", None) is None: raise MisconfigurationException( - "Fault Tolerant Training doesn't support an IterableDataset without a `__next__`" - " method implemented. 
Hint: We recommend you to move your logic inside and rely " + "Fault Tolerant Training doesn't support an IterableDataset without a `__next__` " + "method implemented. Hint: We recommend you to move your logic inside and rely " "on a sampler, generator to perform the iteration." ) @@ -629,7 +628,7 @@ def _validate_iterable_dataset(dataloader: DataLoader) -> bool: return sampler[0].__class__ == SequentialSampler, "" -def _validate_map_dataset(dataloader: DataLoader) -> bool: +def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]: SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) sampler = getattr(dataloader, "sampler", None) @@ -647,7 +646,7 @@ def _validate_map_dataset(dataloader: DataLoader) -> bool: def _validate_fault_tolerant_training(dataloader: Iterable, stage: RunningStage) -> None: - """This function is used to validate fault tolerant training is possible with the user data.""" + """This function is used to validate that fault tolerant training is possible with the user data.""" from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator if not _fault_tolerant_training(): @@ -669,7 +668,8 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) - apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader) if len(dl_loaders) > 1 and stage == RunningStage.TRAINING: - if not all(getattr(dl.dataset, "deterministic", False) for dl in dl_loaders): + # Fixme: Find a better API or refactor the auto_restart test. + if not all(getattr(dl.dataset, "_deterministic", False) for dl in dl_loaders): raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") supported = [] diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d31cf93330236..c692497c69a08 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -761,7 +761,7 @@ def on_train_batch_start(self, trainer, *_) -> None: class SequentialGetItemDataset(Dataset): - deterministic = True + _deterministic = True def __init__(self, length, *_): self.len = length @@ -878,7 +878,7 @@ class CustomException(Exception): class SequentialIterableDataset(IterableDataset): - deterministic = True + _deterministic = True def __init__(self, length, *_): self.len = length From d1575e55f0d3b888fb12cdf2f96d5880466328af Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 18:46:42 +0000 Subject: [PATCH 09/22] update --- tests/utilities/test_auto_restart.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 28949b89d86e7..4e09f2dd79c91 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1210,14 +1210,15 @@ def test_validate_fault_tolerant(tmpdir): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) dataloaders = [ DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)), DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(MisconfigurationException, match="The current combinaison of 
DataLoaders isn't supported."): + with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single."): _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) with pytest.raises(MisconfigurationException, match="RandomSampler"): From e83ea724ac4eb62191f565727d7b7456e61ff26c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 18:49:20 +0000 Subject: [PATCH 10/22] resolve test --- tests/utilities/test_auto_restart.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 4e09f2dd79c91..c9c56895476d1 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -722,11 +722,12 @@ def on_train_batch_start(self, trainer, *_) -> None: assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset) - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}), mock.patch( + "pytorch_lightning.trainer.data_loading._validate_fault_tolerant_training", lambda x, y: None + ): model = TestModel() model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check()) - trainer._validate_fault_tolerant_training = lambda x, y: None trainer.fit(model) From 043193a20ad751e4cb32dd1cabc5bca61ed23922 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 19:17:29 +0000 Subject: [PATCH 11/22] update --- pytorch_lightning/utilities/auto_restart.py | 18 ++++++++--------- tests/utilities/test_auto_restart.py | 22 ++++++++++----------- 2 files changed, 20 insertions(+), 20 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 5b3c811bd47dc..b25466faa2510 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -755,7 +755,7 @@ def _validate_iterable_dataset(dataloader: DataLoader) -> Tuple[bool, str]: dataset = dataloader.dataset if getattr(dataset, "__next__", None) is None: - raise MisconfigurationException( + raise AttributeError( "Fault Tolerant Training doesn't support an IterableDataset without a `__next__` " "method implemented. Hint: We recommend you to move your logic inside and rely " "on a sampler, generator to perform the iteration." @@ -764,17 +764,17 @@ def _validate_iterable_dataset(dataloader: DataLoader) -> Tuple[bool, str]: samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)} if not samplers: - raise MisconfigurationException( + raise AttributeError( "Fault Tolerant Training doesn't support an IterableDataset without a sampler as attribute." ) sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS] if not sampler: - raise MisconfigurationException(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.") + raise TypeError(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.") if len(sampler) > 1: - raise MisconfigurationException(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.") + raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.") if type(sampler[0]) is DistributedSampler: return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True." 
@@ -786,12 +786,12 @@ def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]: SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) sampler = getattr(dataloader, "sampler", None) - if sampler is not None and sampler.__class__ not in SUPPORTED_SAMPLERS: - raise MisconfigurationException(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.") + if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS: + raise ValueError(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.") batch_sampler = getattr(dataloader, "batch_sampler", None) if batch_sampler is not None and type(batch_sampler) is not BatchSampler: - raise MisconfigurationException("Fault Tolerant Training supports only a BatchSampler.") + raise ValueError("Fault Tolerant Training supports only a BatchSampler.") if type(sampler) is DistributedSampler: return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True." @@ -822,7 +822,7 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) - apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader) if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING: - raise MisconfigurationException("Fault Tolerant Training supports only a single dataloader.") + raise ValueError("Fault Tolerant Training supports only a single dataloader.") supported = [] messages = [] @@ -836,4 +836,4 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) - messages.append(message) if len(dl_loaders) > 1 and sum(supported) != len(dl_loaders): - raise MisconfigurationException(f"The current combinaison of DataLoaders isn't supported. Messages: {messages}") + raise ValueError(f"The current combinaison of DataLoaders isn't supported. 
Messages: {messages}") diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c9c56895476d1..d9e33c8f0e7d0 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1185,16 +1185,16 @@ def test_validate_fault_tolerant(tmpdir): _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))]) _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle") _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(range(10))] - with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) @@ -1211,7 +1211,7 @@ def test_validate_fault_tolerant(tmpdir): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) dataloaders = [ @@ -1219,10 +1219,10 @@ def test_validate_fault_tolerant(tmpdir): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(MisconfigurationException, match="Fault Tolerant Training supports only a single."): + with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single."): _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="RandomSampler"): + with pytest.raises(ValueError, match="RandomSampler"): class CustomRandomSampler(RandomSampler): pass @@ -1230,7 +1230,7 @@ class CustomRandomSampler(RandomSampler): dataloader = DataLoader(data, sampler=CustomRandomSampler(data)) _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="BatchSampler"): + with pytest.raises(ValueError, match="BatchSampler"): class CustomBatchSampler(BatchSampler): pass @@ -1240,7 +1240,7 @@ class CustomBatchSampler(BatchSampler): dataloader = DataLoader(data, batch_sampler=batch_sampler) _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="without a `__next__` method"): + with pytest.raises(AttributeError, match="without a `__next__` method"): class CustomIterable(IterableDataset): def __iter__(self): @@ -1250,7 +1250,7 @@ def __iter__(self): dataloader = DataLoader(CustomIterable()) _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) - with 
pytest.raises(MisconfigurationException, match="IterableDataset without a sampler as attribute"): + with pytest.raises(AttributeError, match="IterableDataset without a sampler as attribute"): class CustomIterable(IterableDataset): def __iter__(self): @@ -1262,7 +1262,7 @@ def __next__(self): dataloader = DataLoader(CustomIterable()) _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) - with pytest.raises(MisconfigurationException, match="RandomSampler"): + with pytest.raises(TypeError, match="RandomSampler"): class CustomIterable(IterableDataset): def __init__(self): @@ -1280,7 +1280,7 @@ def __next__(self): _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(CustomIterable())] - with pytest.raises(MisconfigurationException, match="RandomSampler"): + with pytest.raises(TypeError, match="RandomSampler"): _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) From c068ff6f17327db8a779731620d5f42c8577e920 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 19:21:04 +0000 Subject: [PATCH 12/22] update --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index b25466faa2510..2b65397c55f56 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -803,7 +803,7 @@ def _validate_fault_tolerant_training(dataloader: Iterable, stage: "pl.trainer.s """This function is used to validate that fault tolerant training is possible with the user data.""" from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator - if not _FaultTolerantMode.detect_current_mode().is_enabled: + if not _FaultTolerantMode.detect_current_mode().is_automatic: return if isinstance(dataloader, CombinedLoader): From 0cd756660d2f2872d72e49e3cc6397dbde562ce2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 19:21:33 +0000 Subject: [PATCH 13/22] update --- pytorch_lightning/trainer/data_loading.py | 4 +-- pytorch_lightning/utilities/auto_restart.py | 2 +- tests/utilities/test_auto_restart.py | 36 ++++++++++----------- 3 files changed, 21 insertions(+), 21 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 92488b715cf53..b769e7e444300 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -28,7 +28,7 @@ from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator from pytorch_lightning.utilities import rank_zero_warn from pytorch_lightning.utilities.apply_func import apply_to_collection -from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_training +from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _replace_dataloader_init_method, @@ -438,7 +438,7 @@ def request_dataloader( if isinstance(dataloader, tuple): dataloader = list(dataloader) self.training_type_plugin.barrier("get_dataloaders") - _validate_fault_tolerant_training(dataloader, stage) + _validate_fault_tolerant_automatic(dataloader, stage) return dataloader @staticmethod diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 2b65397c55f56..0ab49f3a61251 100644 --- 
a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -799,7 +799,7 @@ def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]: return type(sampler) is SequentialSampler, "" -def _validate_fault_tolerant_training(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None: +def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None: """This function is used to validate that fault tolerant training is possible with the user data.""" from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d9e33c8f0e7d0..ed24f7148cc92 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -48,7 +48,7 @@ _SingleProcessDataLoaderIterStateful, _SupportsStateDict, _teardown_dataloader_get_iterators, - _validate_fault_tolerant_training, + _validate_fault_tolerant_automatic, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -723,7 +723,7 @@ def on_train_batch_start(self, trainer, *_) -> None: assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset) with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}), mock.patch( - "pytorch_lightning.trainer.data_loading._validate_fault_tolerant_training", lambda x, y: None + "pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None ): model = TestModel() model.training_epoch_end = None @@ -924,7 +924,7 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult multiple_trainloader_mode=multiple_trainloader_mode, ) - with mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_training", lambda x, y: None): + with mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None): all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes) all_batches = torch.stack(all_batches) @@ -1183,27 +1183,27 @@ def test_validate_fault_tolerant(tmpdir): data = range(10) dataloader = DataLoader(data) - _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))]) - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle") - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(range(10))] with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) - _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))] 
- _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))] - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataset = SequentialGetItemDataset(2) dataloaders = [ @@ -1212,7 +1212,7 @@ def test_validate_fault_tolerant(tmpdir): ] with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [ DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)), @@ -1220,7 +1220,7 @@ def test_validate_fault_tolerant(tmpdir): ] with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single."): - _validate_fault_tolerant_training(dataloaders, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) with pytest.raises(ValueError, match="RandomSampler"): @@ -1228,7 +1228,7 @@ class CustomRandomSampler(RandomSampler): pass dataloader = DataLoader(data, sampler=CustomRandomSampler(data)) - _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) with pytest.raises(ValueError, match="BatchSampler"): @@ -1238,7 +1238,7 @@ class CustomBatchSampler(BatchSampler): sampler = Sampler(data) batch_sampler = CustomBatchSampler(sampler, 2, False) dataloader = DataLoader(data, batch_sampler=batch_sampler) - _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) with pytest.raises(AttributeError, match="without a `__next__` method"): @@ -1248,7 +1248,7 @@ def __iter__(self): yield 0 dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) with pytest.raises(AttributeError, match="IterableDataset without a sampler as attribute"): @@ -1260,7 +1260,7 @@ def __next__(self): return torch.tensor(0) dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) with pytest.raises(TypeError, match="RandomSampler"): @@ -1277,11 +1277,11 @@ def __next__(self): return torch.tensor(0) dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_training(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(CustomIterable())] with pytest.raises(TypeError, match="RandomSampler"): - _validate_fault_tolerant_training(dataloaders, RunningStage.VALIDATING) + _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) def test_rotate_worker_indices(): From fb11d8726f2bbc10ded6235aebe26926e2f9b20b Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 19:48:43 +0000 Subject: [PATCH 14/22] update --- tests/trainer/connectors/test_signal_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 
fbfce158e3675..e96cdac196636 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("register_handler", [False, True]) -@pytest.mark.parametrize("terminate_gracefully", [False, True]) +@pytest.mark.parametrize("terminate_gracefully", [True]) @RunIf(skip_windows=True) def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir): From 4d046c6827fa0b5c6af96b696d6ae6a20ae41b51 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Nov 2021 19:50:10 +0000 Subject: [PATCH 15/22] update --- tests/trainer/connectors/test_signal_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index e96cdac196636..fbfce158e3675 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -28,7 +28,7 @@ @pytest.mark.parametrize("register_handler", [False, True]) -@pytest.mark.parametrize("terminate_gracefully", [True]) +@pytest.mark.parametrize("terminate_gracefully", [False, True]) @RunIf(skip_windows=True) def test_fault_tolerant_sig_handler(register_handler, terminate_gracefully, tmpdir): From 3ee013ceb8751442e391eea8e708b1b50512084b Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 25 Nov 2021 21:13:47 +0000 Subject: [PATCH 16/22] update --- pytorch_lightning/utilities/auto_restart.py | 24 ++++++++++----------- tests/utilities/test_auto_restart.py | 6 +++--- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 79c0548db8255..9cb2cf6f01a5d 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -378,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str # get current num workers num_workers = getattr(iter_dataloader, "_num_workers", 0) # as `state_dict` are workers dependent, Lightning doesn't support changing - # the `num_workers` for fault tolerant training + # the `num_workers` for Fault-tolerance if state_dict["num_workers"] != num_workers: raise MisconfigurationException( f"The provided `num_workers` {num_workers} doesn't match the one used " @@ -742,7 +742,7 @@ def _patch_dataloader_get_iterators() -> None: def _teardown_dataloader_get_iterators() -> None: """This function is used to restore the DataLoader `get_iterator` with its original one.""" - # cleanup the get_iterator replacement in case of Fault Tolerant Training. + # cleanup the get_iterator replacement in case of Fault-tolerance. get_iterator = getattr(DataLoader, "_ori_get_iterator", None) if get_iterator: DataLoader._get_iterator = get_iterator @@ -757,22 +757,20 @@ def _validate_iterable_dataset(dataloader: DataLoader) -> Tuple[bool, str]: if getattr(dataset, "__next__", None) is None: raise AttributeError( - "Fault Tolerant Training doesn't support an IterableDataset without a `__next__` " - "method implemented. Hint: We recommend you to move your logic inside and rely " - "on a sampler, generator to perform the iteration." + "Fault-tolerance doesn't support an `IterableDataset` without `__next__` " + "method implemented. Hint: We recommend you to move your logic from `__iter__`" + " inside and rely on a sampler to perform the sample sampling." 
)
 
     samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}
 
     if not samplers:
-        raise AttributeError(
-            "Fault Tolerant Training doesn't support an IterableDataset without a sampler as attribute."
-        )
+        raise AttributeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")
 
     sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]
 
     if not sampler:
-        raise TypeError(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.")
+        raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
 
     if len(sampler) > 1:
         raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.")
@@ -788,11 +786,11 @@ def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]:
     sampler = getattr(dataloader, "sampler", None)
     if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
-        raise ValueError(f"Fault Tolerant Training supports only {SUPPORTED_SAMPLERS}.")
+        raise ValueError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
 
     batch_sampler = getattr(dataloader, "batch_sampler", None)
     if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
-        raise ValueError("Fault Tolerant Training supports only a BatchSampler.")
+        raise ValueError("Fault-tolerance supports only a BatchSampler.")
 
     if type(sampler) is DistributedSampler:
         return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True."
@@ -801,7 +799,7 @@ def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]:
 
 
 def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
-    """This function is used to validate that fault tolerant training is possible with the user data."""
+    """This function is used to validate that Fault-tolerance is possible with the user data."""
     from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
 
     if not _FaultTolerantMode.detect_current_mode().is_automatic:
@@ -823,7 +821,7 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -
     apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)
 
     if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
-        raise ValueError("Fault Tolerant Training supports only a single dataloader.")
+        raise ValueError("Fault-tolerance supports only a single dataloader.")
 
     supported = []
     messages = []
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py
index bc50dd7e4dc10..5b12be3db492c 100644
--- a/tests/utilities/test_auto_restart.py
+++ b/tests/utilities/test_auto_restart.py
@@ -846,9 +846,6 @@ class CustomException(Exception):
 
 
 class SequentialIterableDataset(IterableDataset):
-
-    _deterministic = True
-
     def __init__(self, length, *_):
         self.len = length
         self.sampler = SequentialSampler(range(self.len))
@@ -925,6 +922,9 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
         multiple_trainloader_mode=multiple_trainloader_mode,
     )
 
+    # this test would fail because fault-tolerant training doesn't support multiple datasets;
+    # it still works here because the dataset is fully deterministic and therefore
+    # there is no overlap between the seeds.
with mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None): all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes) From ad164afae0fc3dea362a51c49528070ef701399e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Nov 2021 09:28:59 +0000 Subject: [PATCH 17/22] update --- pytorch_lightning/utilities/auto_restart.py | 29 ++++++++------------- tests/utilities/test_auto_restart.py | 25 ++++++++++++------ 2 files changed, 28 insertions(+), 26 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 9cb2cf6f01a5d..eba3f803d8cd8 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -749,8 +749,7 @@ def _teardown_dataloader_get_iterators() -> None: del DataLoader._ori_get_iterator -def _validate_iterable_dataset(dataloader: DataLoader) -> Tuple[bool, str]: - +def _validate_iterable_dataset(dataloader: DataLoader): SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) dataset = dataloader.dataset @@ -775,13 +774,14 @@ def _validate_iterable_dataset(dataloader: DataLoader) -> Tuple[bool, str]: if len(sampler) > 1: raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.") - if type(sampler[0]) is DistributedSampler: - return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True." + if type(sampler[0]) is DistributedSampler and sampler.shuffle: + raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.") - return type(sampler[0]) is SequentialSampler, "" + if type(sampler[0]) is RandomSampler: + raise ValueError("Only SequentialSampler is supported.") -def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]: +def _validate_map_dataset(dataloader: DataLoader): SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) sampler = getattr(dataloader, "sampler", None) @@ -792,10 +792,11 @@ def _validate_map_dataset(dataloader: DataLoader) -> Tuple[bool, str]: if batch_sampler is not None and type(batch_sampler) is not BatchSampler: raise ValueError("Fault-tolerance supports only a BatchSampler.") - if type(sampler) is DistributedSampler: - return not sampler.shuffle, "A `DistributedSampler` sampler shuffle attribute is set to True." + if type(sampler) is DistributedSampler and sampler.shuffle: + raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.") - return type(sampler) is SequentialSampler, "" + if type(sampler) is RandomSampler: + raise ValueError("Only SequentialSampler is supported.") def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None: @@ -823,19 +824,11 @@ def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) - if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING: raise ValueError("Fault-tolerance supports only a single dataloader.") - supported = [] - messages = [] - for dataloader in dl_loaders: validator_fn = ( _validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset ) - is_supported, message = validator_fn(dataloader) - supported.append(is_supported) - messages.append(message) - - if len(dl_loaders) > 1 and sum(supported) != len(dl_loaders): - raise ValueError(f"The current combinaison of DataLoaders isn't supported. 
Messages: {messages}") + validator_fn(dataloader) def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any: diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 5b12be3db492c..86e953a90f9e1 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1186,22 +1186,23 @@ def test_validate_fault_tolerant(tmpdir): _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) - with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))]) _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) - with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle") _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data), DataLoader(range(10))] - with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) - dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))] - _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + with pytest.raises(ValueError, match="A `DistributedSampler` sampler shuffle attribute is set to True."): + dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))] + _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))] _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) @@ -1212,7 +1213,7 @@ def test_validate_fault_tolerant(tmpdir): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single dataloader."): + with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [ @@ -1220,9 +1221,17 @@ def test_validate_fault_tolerant(tmpdir): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(ValueError, match="Fault Tolerant Training supports only a single."): + with pytest.raises(ValueError, match="Fault-tolerance supports only a single."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + dataloaders = [ + DataLoader(dataset, sampler=RandomSampler(dataset)), + DataLoader(dataset, sampler=SequentialSampler(dataset)), + ] + + with pytest.raises(ValueError, match="Only SequentialSampler is supported."): + _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) + with pytest.raises(ValueError, match="RandomSampler"): class CustomRandomSampler(RandomSampler): @@ -1241,7 +1250,7 @@ class CustomBatchSampler(BatchSampler): 
dataloader = DataLoader(data, batch_sampler=batch_sampler) _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) - with pytest.raises(AttributeError, match="without a `__next__` method"): + with pytest.raises(AttributeError, match="without `__next__` method"): class CustomIterable(IterableDataset): def __iter__(self): From 146712826512d8fd5e897a781118a0cd9bbafadc Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Nov 2021 17:17:09 +0000 Subject: [PATCH 18/22] update on comments --- pytorch_lightning/utilities/auto_restart.py | 18 ++++----- tests/utilities/test_auto_restart.py | 43 ++++++++++----------- 2 files changed, 30 insertions(+), 31 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index eba3f803d8cd8..3e6bee557368a 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -749,7 +749,7 @@ def _teardown_dataloader_get_iterators() -> None: del DataLoader._ori_get_iterator -def _validate_iterable_dataset(dataloader: DataLoader): +def _validate_iterable_dataset(dataloader: DataLoader) -> None: SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) dataset = dataloader.dataset @@ -764,7 +764,7 @@ def _validate_iterable_dataset(dataloader: DataLoader): samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)} if not samplers: - raise AttributeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.") + raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.") sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS] @@ -777,26 +777,26 @@ def _validate_iterable_dataset(dataloader: DataLoader): if type(sampler[0]) is DistributedSampler and sampler.shuffle: raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.") - if type(sampler[0]) is RandomSampler: - raise ValueError("Only SequentialSampler is supported.") + elif type(sampler[0]) is not SequentialSampler: + raise TypeError("Only `SequentialSampler` is supported.") -def _validate_map_dataset(dataloader: DataLoader): +def _validate_map_dataset(dataloader: DataLoader) -> None: SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler) sampler = getattr(dataloader, "sampler", None) if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS: - raise ValueError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.") + raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.") batch_sampler = getattr(dataloader, "batch_sampler", None) if batch_sampler is not None and type(batch_sampler) is not BatchSampler: - raise ValueError("Fault-tolerance supports only a BatchSampler.") + raise TypeError("Fault-tolerance supports only a `BatchSampler`.") if type(sampler) is DistributedSampler and sampler.shuffle: raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.") - if type(sampler) is RandomSampler: - raise ValueError("Only SequentialSampler is supported.") + elif type(sampler) is RandomSampler: + raise ValueError("Only `SequentialSampler` is supported.") def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None: diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 86e953a90f9e1..29ad4b162b4a6 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -897,6 
+897,10 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_
     return model.seen_batches, model.parameters()
 
 
+# this test would fail because fault-tolerant training doesn't support multiple datasets;
+# it still works here because the dataset is fully deterministic and therefore
+# there is no overlap between the seeds.
+@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
 @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
 @pytest.mark.parametrize(
     "dataset_classes",
@@ -922,27 +926,22 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
         multiple_trainloader_mode=multiple_trainloader_mode,
     )
 
-    # this test would fail because fault-tolerant training doesn't support multiple datasets;
-    # it still works here because the dataset is fully deterministic and therefore
-    # there is no overlap between the seeds.
-    with mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None):
+    all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes)
+    all_batches = torch.stack(all_batches)
+    assert len(all_batches) == 9
 
-        all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes)
-        all_batches = torch.stack(all_batches)
-        assert len(all_batches) == 9
+    # Simulate 1st failure
+    complete_batches, _ = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
+    assert len(complete_batches) == 4
 
-        # Simulate 1st failure
-        complete_batches, _ = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
-        assert len(complete_batches) == 4
-
-        checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
-        assert os.path.exists(checkpoint_path)
+    checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
+    assert os.path.exists(checkpoint_path)
 
-        # Resume after failure
-        resumed_batches, weights1 = _run_training(
-            trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path
-        )
-        assert len(resumed_batches) == 5
+    # Resume after failure
+    resumed_batches, weights1 = _run_training(
+        trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path
+    )
+    assert len(resumed_batches) == 5
 
     # the resumed batches should match the batches of the successful training
     all_batches_resumed = torch.stack(complete_batches + resumed_batches)
@@ -1229,10 +1228,10 @@ def test_validate_fault_tolerant(tmpdir):
         DataLoader(dataset, sampler=SequentialSampler(dataset)),
     ]
 
-    with pytest.raises(ValueError, match="Only SequentialSampler is supported."):
+    with pytest.raises(ValueError, match="Only `SequentialSampler` is supported."):
         _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)
 
-    with pytest.raises(ValueError, match="RandomSampler"):
+    with pytest.raises(TypeError, match="RandomSampler"):
 
         class CustomRandomSampler(RandomSampler):
             pass
@@ -1240,7 +1239,7 @@ class CustomRandomSampler(RandomSampler):
         dataloader = DataLoader(data, sampler=CustomRandomSampler(data))
         _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)
 
-    with pytest.raises(ValueError, match="BatchSampler"):
+    with pytest.raises(TypeError, match="BatchSampler"):
 
         class CustomBatchSampler(BatchSampler):
             pass
@@ -1260,7 +1259,7 @@ def __iter__(self):
         dataloader = DataLoader(CustomIterable())
         _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)
 
-    with pytest.raises(AttributeError, match="IterableDataset without a sampler as attribute"):
+    with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"):
 
         class 
CustomIterable(IterableDataset): def __iter__(self): From 9349eecf535d11b2fbab0538cb9368c0875a89e3 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Nov 2021 18:27:21 +0100 Subject: [PATCH 19/22] Address missed comments --- pytorch_lightning/utilities/auto_restart.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 3e6bee557368a..9d26f4a6e0736 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -775,8 +775,7 @@ def _validate_iterable_dataset(dataloader: DataLoader) -> None: raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.") if type(sampler[0]) is DistributedSampler and sampler.shuffle: - raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.") - + raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.") elif type(sampler[0]) is not SequentialSampler: raise TypeError("Only `SequentialSampler` is supported.") @@ -793,19 +792,18 @@ def _validate_map_dataset(dataloader: DataLoader) -> None: raise TypeError("Fault-tolerance supports only a `BatchSampler`.") if type(sampler) is DistributedSampler and sampler.shuffle: - raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.") - + raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.") elif type(sampler) is RandomSampler: - raise ValueError("Only `SequentialSampler` is supported.") + raise TypeError("Only `SequentialSampler` is supported.") def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None: """This function is used to validate that Fault-tolerance is possible with the user data.""" - from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator - if not _FaultTolerantMode.detect_current_mode().is_automatic: return + from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator + if isinstance(dataloader, CombinedLoader): dataloaders = dataloader.loaders else: From 2dcb99fe633efba9b4c38a28ee8952baa2e2eaaf Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Nov 2021 18:27:31 +0100 Subject: [PATCH 20/22] Update test --- tests/utilities/test_auto_restart.py | 49 ++++++++++++++-------------- 1 file changed, 25 insertions(+), 24 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 29ad4b162b4a6..bceb9068d676c 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1179,31 +1179,33 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_validate_fault_tolerant(tmpdir): + def data(): + return range(10) - data = range(10) - dataloader = DataLoader(data) + def dataloader(): + return DataLoader(data()) - _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) + _validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING) with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): - dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))]) + dataloaders = CombinedLoader([dataloader(), dataloader()]) _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) with pytest.raises(ValueError, match="Fault-tolerance supports only a single 
dataloader."): - dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle") + dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle") _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) - dataloaders = [DataLoader(data), DataLoader(range(10))] + dataloaders = [dataloader(), dataloader()] with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) - with pytest.raises(ValueError, match="A `DistributedSampler` sampler shuffle attribute is set to True."): - dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))] + with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."): + dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))] _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) - dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))] + dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))] _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataset = SequentialGetItemDataset(2) @@ -1228,7 +1230,7 @@ def test_validate_fault_tolerant(tmpdir): DataLoader(dataset, sampler=SequentialSampler(dataset)), ] - with pytest.raises(ValueError, match="Only `SequentialSampler` is supported."): + with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) with pytest.raises(TypeError, match="RandomSampler"): @@ -1236,18 +1238,18 @@ def test_validate_fault_tolerant(tmpdir): class CustomRandomSampler(RandomSampler): pass - dataloader = DataLoader(data, sampler=CustomRandomSampler(data)) - _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) + dl = DataLoader(data(), sampler=CustomRandomSampler(data())) + _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) with pytest.raises(TypeError, match="BatchSampler"): class CustomBatchSampler(BatchSampler): pass - sampler = Sampler(data) + sampler = Sampler(data()) batch_sampler = CustomBatchSampler(sampler, 2, False) - dataloader = DataLoader(data, batch_sampler=batch_sampler) - _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) + dl = DataLoader(data(), batch_sampler=batch_sampler) + _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) with pytest.raises(AttributeError, match="without `__next__` method"): @@ -1256,8 +1258,8 @@ def __iter__(self): while True: yield 0 - dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) + iterable_dataloader = DataLoader(CustomIterable()) + _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"): @@ -1268,16 +1270,15 @@ def __iter__(self): def __next__(self): return torch.tensor(0) - dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) + iterable_dataloader = DataLoader(CustomIterable()) + _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) with pytest.raises(TypeError, match="RandomSampler"): 
class CustomIterable(IterableDataset): def __init__(self): super().__init__() - self.data = data - self.sampler = CustomRandomSampler(self.data) + self.sampler = CustomRandomSampler(data()) def __iter__(self): return self @@ -1285,10 +1286,10 @@ def __iter__(self): def __next__(self): return torch.tensor(0) - dataloader = DataLoader(CustomIterable()) - _validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING) + iterable_dataloader = DataLoader(CustomIterable()) + _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) - dataloaders = [DataLoader(data), DataLoader(CustomIterable())] + dataloaders = [iterable_dataloader, DataLoader(CustomIterable())] with pytest.raises(TypeError, match="RandomSampler"): _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) From 69879784422268729bb23cd6a1931c1a89b66692 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 26 Nov 2021 18:31:58 +0100 Subject: [PATCH 21/22] Simplify test by removing unnecesary extras and wrapping only exceptions --- tests/utilities/test_auto_restart.py | 75 +++++++++++----------------- 1 file changed, 30 insertions(+), 45 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index bceb9068d676c..058cd5e1c9bdc 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1187,12 +1187,12 @@ def dataloader(): _validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING) + dataloaders = CombinedLoader([dataloader(), dataloader()]) with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): - dataloaders = CombinedLoader([dataloader(), dataloader()]) _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) + dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle") with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): - dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle") _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [dataloader(), dataloader()] @@ -1201,8 +1201,8 @@ def dataloader(): _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) + dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))] with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."): - dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))] _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))] @@ -1213,7 +1213,6 @@ def dataloader(): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING) @@ -1221,7 +1220,6 @@ def dataloader(): DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)), DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)), ] - with pytest.raises(ValueError, match="Fault-tolerance supports only a single."): _validate_fault_tolerant_automatic(dataloaders, 
RunningStage.TRAINING) @@ -1233,60 +1231,47 @@ def dataloader(): with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."): _validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING) - with pytest.raises(TypeError, match="RandomSampler"): - - class CustomRandomSampler(RandomSampler): - pass + class CustomRandomSampler(RandomSampler): + pass - dl = DataLoader(data(), sampler=CustomRandomSampler(data())) + dl = DataLoader(data(), sampler=CustomRandomSampler(data())) + with pytest.raises(TypeError, match="RandomSampler"): _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) - with pytest.raises(TypeError, match="BatchSampler"): - - class CustomBatchSampler(BatchSampler): - pass + class CustomBatchSampler(BatchSampler): + pass - sampler = Sampler(data()) - batch_sampler = CustomBatchSampler(sampler, 2, False) - dl = DataLoader(data(), batch_sampler=batch_sampler) + sampler = Sampler(data()) + batch_sampler = CustomBatchSampler(sampler, 2, False) + dl = DataLoader(data(), batch_sampler=batch_sampler) + with pytest.raises(TypeError, match="BatchSampler"): _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) - with pytest.raises(AttributeError, match="without `__next__` method"): - - class CustomIterable(IterableDataset): - def __iter__(self): - while True: - yield 0 + class CustomIterable(IterableDataset): + pass - iterable_dataloader = DataLoader(CustomIterable()) + iterable_dataloader = DataLoader(CustomIterable()) + with pytest.raises(AttributeError, match="without `__next__` method"): _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) - with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"): - - class CustomIterable(IterableDataset): - def __iter__(self): - return self - - def __next__(self): - return torch.tensor(0) + class CustomIterable(IterableDataset): + def __next__(self): + return torch.tensor(0) - iterable_dataloader = DataLoader(CustomIterable()) + iterable_dataloader = DataLoader(CustomIterable()) + with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"): _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) - with pytest.raises(TypeError, match="RandomSampler"): - - class CustomIterable(IterableDataset): - def __init__(self): - super().__init__() - self.sampler = CustomRandomSampler(data()) - - def __iter__(self): - return self + class CustomIterable(IterableDataset): + def __init__(self): + super().__init__() + self.sampler = CustomRandomSampler(data()) - def __next__(self): - return torch.tensor(0) + def __next__(self): + return torch.tensor(0) - iterable_dataloader = DataLoader(CustomIterable()) + iterable_dataloader = DataLoader(CustomIterable()) + with pytest.raises(TypeError, match="RandomSampler"): _validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING) dataloaders = [iterable_dataloader, DataLoader(CustomIterable())] From 1feebeb3197c03ce4be9d0da0b58c2263db9268a Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 26 Nov 2021 19:11:46 +0000 Subject: [PATCH 22/22] update --- tests/utilities/test_auto_restart.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 29ad4b162b4a6..34df6f1b218bd 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -667,6 +667,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w return 
dataset +@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None) @pytest.mark.parametrize("use_fault_tolerant", ["0", "1"]) def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir): """This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled.""" @@ -723,9 +724,7 @@ def on_train_batch_start(self, trainer, *_) -> None: assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset) assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset) - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}), mock.patch( - "pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None - ): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}): model = TestModel() model.training_epoch_end = None trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
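For reviewers trying the branch locally, the snippet below is a minimal, illustrative sketch (not part of the patch series) of how the `_validate_fault_tolerant_automatic` check introduced above is expected to behave. It assumes the final state of the branch after PATCH 22/22 and that automatic fault tolerance is enabled through the `PL_FAULT_TOLERANT_TRAINING=1` environment variable, as in the tests above; the names `data` and `MyRandomSampler` are invented for the example.

# Illustrative sketch only: assumes the branch state after PATCH 22/22 and that
# PL_FAULT_TOLERANT_TRAINING=1 enables the automatic fault-tolerance mode.
import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"

from torch.utils.data import DataLoader
from torch.utils.data.sampler import RandomSampler

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

data = range(10)

# A map-style dataset with the default SequentialSampler passes validation.
_validate_fault_tolerant_automatic(DataLoader(data), RunningStage.TRAINING)

# More than one dataloader is rejected for the training stage.
try:
    _validate_fault_tolerant_automatic([DataLoader(data), DataLoader(data)], RunningStage.TRAINING)
except ValueError as err:
    print(err)  # "Fault-tolerance supports only a single dataloader."

# Subclassed samplers are rejected: only the exact supported sampler types
# (RandomSampler, SequentialSampler, DistributedSampler) can be captured and restored.
class MyRandomSampler(RandomSampler):
    pass

try:
    _validate_fault_tolerant_automatic(DataLoader(data, sampler=MyRandomSampler(data)), RunningStage.TRAINING)
except TypeError as err:
    print(err)  # message lists the supported sampler classes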