Fault Tolerant: Add support for fault tolerant dataloader validator (#…
tchaton committed Nov 26, 2021
1 parent 8893072 commit e94aff1
Showing 4 changed files with 208 additions and 6 deletions.
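For orientation, here is a minimal sketch of how the validator introduced by this commit can be exercised directly. It is not part of the changed files; it only assumes automatic fault-tolerant mode is enabled through the `PL_FAULT_TOLERANT_TRAINING=1` environment variable, as the tests in this commit do. The commit wires `_validate_fault_tolerant_automatic` into `request_dataloader` so that unsupported dataloader setups fail fast.

import os

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # enable automatic fault-tolerant mode

from torch.utils.data import DataLoader, RandomSampler

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

data = range(10)

# A vanilla DataLoader (SequentialSampler plus the default BatchSampler) is accepted.
_validate_fault_tolerant_automatic(DataLoader(data), RunningStage.TRAINING)

# A map-style dataset driven by a RandomSampler is rejected.
try:
    _validate_fault_tolerant_automatic(
        DataLoader(data, sampler=RandomSampler(data)), RunningStage.TRAINING
    )
except TypeError as err:
    print(err)  # Only `SequentialSampler` is supported.
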
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))


- Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465))


- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -28,7 +28,7 @@
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
@@ -441,6 +441,7 @@ def request_dataloader(
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.training_type_plugin.barrier("get_dataloaders")
_validate_fault_tolerant_automatic(dataloader, stage)
return dataloader

@staticmethod
96 changes: 92 additions & 4 deletions pytorch_lightning/utilities/auto_restart.py
@@ -16,11 +16,19 @@
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data import (
BatchSampler,
Dataset,
DistributedSampler,
get_worker_info,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
@@ -370,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
# get current num workers
num_workers = getattr(iter_dataloader, "_num_workers", 0)
# as `state_dict` are workers dependent, Lightning doesn't support changing
# the `num_workers` for fault tolerant training
# the `num_workers` for Fault-tolerance
if state_dict["num_workers"] != num_workers:
raise MisconfigurationException(
f"The provided `num_workers` {num_workers} doesn't match the one used "
@@ -734,13 +742,93 @@ def _patch_dataloader_get_iterators() -> None:

def _teardown_dataloader_get_iterators() -> None:
"""This function is used to restore the DataLoader `get_iterator` with its original one."""
# cleanup the get_iterator replacement in case of Fault Tolerant Training.
# cleanup the get_iterator replacement in case of Fault-tolerance.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
del DataLoader._ori_get_iterator


def _validate_iterable_dataset(dataloader: DataLoader) -> None:
SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

dataset = dataloader.dataset

if getattr(dataset, "__next__", None) is None:
raise AttributeError(
"Fault-tolerance doesn't support an `IterableDataset` without `__next__` "
"method implemented. Hint: We recommend you to move your logic from `__iter__`"
" inside and rely on a sampler to perform the sample sampling."
)

samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}

if not samplers:
raise TypeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")

sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]

if not sampler:
raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

if len(sampler) > 1:
raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.")

if type(sampler[0]) is DistributedSampler and sampler[0].shuffle:
raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
elif type(sampler[0]) is not SequentialSampler:
raise TypeError("Only `SequentialSampler` is supported.")


def _validate_map_dataset(dataloader: DataLoader) -> None:
SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

sampler = getattr(dataloader, "sampler", None)
if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

batch_sampler = getattr(dataloader, "batch_sampler", None)
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
raise TypeError("Fault-tolerance supports only a `BatchSampler`.")

if type(sampler) is DistributedSampler and sampler.shuffle:
raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
elif type(sampler) is RandomSampler:
raise TypeError("Only `SequentialSampler` is supported.")


def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
"""This function is used to validate that Fault-tolerance is possible with the user data."""
if not _FaultTolerantMode.detect_current_mode().is_automatic:
return

from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

if isinstance(dataloader, CombinedLoader):
dataloaders = dataloader.loaders
else:
dataloaders = dataloader

dl_loaders = []

def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
nonlocal dl_loaders
if isinstance(dataloader, CycleIterator):
dataloader = dataloader.loader
dl_loaders.append(dataloader)

apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)

if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
raise ValueError("Fault-tolerance supports only a single dataloader.")

for dataloader in dl_loaders:
validator_fn = (
_validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
)
validator_fn(dataloader)


def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
"""This utility collects the state across processes for a collection of state."""

112 changes: 111 additions & 1 deletion tests/utilities/test_auto_restart.py
@@ -34,10 +34,12 @@
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.sampler import Sampler

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_collect_states_on_rank_zero_over_collection,
@@ -48,6 +50,7 @@
_SingleProcessDataLoaderIterStateful,
_SupportsStateDict,
_teardown_dataloader_get_iterators,
_validate_fault_tolerant_automatic,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
@@ -665,6 +668,7 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w
return dataset


@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
"""This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""
@@ -893,6 +897,10 @@ def _run_training(trainer_kwargs, dataset_classes, fail_on_step: int = -1, ckpt_
return model.seen_batches, model.parameters()


# this test would fail because fault-tolerant training doesn't support multiple datasets,
# hence the validator is mocked out below. The test still works because the dataset is
# fully deterministic and therefore there is no overlap between the seeds.
@mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None)
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
@pytest.mark.parametrize(
"dataset_classes",
@@ -1180,6 +1188,108 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_validate_fault_tolerant(tmpdir):
def data():
return range(10)

def dataloader():
return DataLoader(data())

_validate_fault_tolerant_automatic(dataloader(), RunningStage.TRAINING)

dataloaders = CombinedLoader([dataloader(), dataloader()])
with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = CombinedLoader([dataloader(), dataloader()], mode="max_size_cycle")
with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [dataloader(), dataloader()]
with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

_validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)

dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=True))]
with pytest.raises(TypeError, match="A `DistributedSampler` sampler shuffle attribute is set to True."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [DataLoader(data(), sampler=DistributedSampler(data(), num_replicas=2, rank=0, shuffle=False))]
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataset = SequentialGetItemDataset(2)
dataloaders = [
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
]
with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)),
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
]
with pytest.raises(ValueError, match="Fault-tolerance supports only a single."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [
DataLoader(dataset, sampler=RandomSampler(dataset)),
DataLoader(dataset, sampler=SequentialSampler(dataset)),
]

with pytest.raises(TypeError, match="Only `SequentialSampler` is supported."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)

class CustomRandomSampler(RandomSampler):
pass

dl = DataLoader(data(), sampler=CustomRandomSampler(data()))
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

class CustomBatchSampler(BatchSampler):
pass

sampler = Sampler(data())
batch_sampler = CustomBatchSampler(sampler, 2, False)
dl = DataLoader(data(), batch_sampler=batch_sampler)
with pytest.raises(TypeError, match="BatchSampler"):
_validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)

class CustomIterable(IterableDataset):
pass

iterable_dataloader = DataLoader(CustomIterable())
with pytest.raises(AttributeError, match="without `__next__` method"):
_validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)

class CustomIterable(IterableDataset):
def __next__(self):
return torch.tensor(0)

iterable_dataloader = DataLoader(CustomIterable())
with pytest.raises(TypeError, match="IterableDataset without a sampler as attribute"):
_validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)

class CustomIterable(IterableDataset):
def __init__(self):
super().__init__()
self.sampler = CustomRandomSampler(data())

def __next__(self):
return torch.tensor(0)

iterable_dataloader = DataLoader(CustomIterable())
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(iterable_dataloader, RunningStage.TRAINING)

dataloaders = [iterable_dataloader, DataLoader(CustomIterable())]
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)


def test_rotate_worker_indices():
"""This test ensures `worker_id` are rotated properly depending on which one was the latest."""
state_dict = {0: 0, 1: 1}

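As an aside on the `IterableDataset` rules enforced by the new `_validate_iterable_dataset`: the dataset must implement `__next__` and carry a supported sampler as an instance attribute, and that sampler effectively has to be a plain `SequentialSampler`. The sketch below shows a hypothetical minimal dataset shaped to satisfy those checks; the class name and structure are illustrative only and are not part of this commit or of the Lightning API.

import os

from torch.utils.data import DataLoader, IterableDataset, SequentialSampler

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic


class FaultTolerantFriendlyIterable(IterableDataset):
    """Illustrative dataset that passes `_validate_iterable_dataset`."""

    def __init__(self, data):
        super().__init__()
        self.data = list(data)
        # the sampler must be stored on the dataset so the validator can discover it
        # among the instance attributes; the fault-tolerant machinery relies on such
        # attributes to fast-forward the sampler on restart
        self.sampler = SequentialSampler(self.data)

    def __iter__(self):
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        # `__next__` is required; the sampler decides which index is read next
        index = next(self.sampler_iter)
        return self.data[index]


loader = DataLoader(FaultTolerantFriendlyIterable(range(10)), batch_size=2)

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # enable automatic fault-tolerant mode
_validate_fault_tolerant_automatic(loader, RunningStage.TRAINING)  # passes validation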