Fault Tolerant: Add support for fault tolerant dataloader validator #10465

Merged: 26 commits, merged on Nov 26, 2021
Changes from 20 commits
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -28,6 +28,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))


- Added a function to validate if fault tolerant training is supported. ([#10465](https://github.com/PyTorchLightning/pytorch-lightning/issues/10465))


- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -28,7 +28,7 @@
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate, _validate_fault_tolerant_automatic
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
@@ -438,6 +438,7 @@ def request_dataloader(
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.training_type_plugin.barrier("get_dataloaders")
_validate_fault_tolerant_automatic(dataloader, stage)
return dataloader

@staticmethod
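For reference, a minimal sketch (mirroring the new test added below, not part of this diff) of how the new hook behaves: with automatic fault tolerance enabled via `PL_FAULT_TOLERANT_TRAINING=1`, every dataloader returned from a `*_dataloader()` hook is run through `_validate_fault_tolerant_automatic` when the trainer requests it, and the same check can be called directly.

import os
from unittest import mock

from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

# Opt into "automatic" fault tolerance before the mode is detected.
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}):
    loader = DataLoader(range(10))  # default SequentialSampler + BatchSampler
    # No-op for a supported dataloader; raises for unsupported configurations.
    _validate_fault_tolerant_automatic(loader, RunningStage.TRAINING)
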
98 changes: 94 additions & 4 deletions pytorch_lightning/utilities/auto_restart.py
@@ -16,11 +16,19 @@
from functools import partial, wraps
from random import getstate as python_get_rng_state
from random import setstate as python_set_rng_state
from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data import (
BatchSampler,
Dataset,
DistributedSampler,
get_worker_info,
RandomSampler,
Sampler,
SequentialSampler,
)
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
@@ -370,7 +378,7 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
# get current num workers
num_workers = getattr(iter_dataloader, "_num_workers", 0)
# as `state_dict` are workers dependent, Lightning doesn't support changing
# the `num_workers` for fault tolerant training
# the `num_workers` for Fault-tolerance
if state_dict["num_workers"] != num_workers:
raise MisconfigurationException(
f"The provided `num_workers` {num_workers} doesn't match the one used "
@@ -734,13 +742,95 @@ def _patch_dataloader_get_iterators() -> None:

def _teardown_dataloader_get_iterators() -> None:
"""This function is used to restore the DataLoader `get_iterator` with its original one."""
# cleanup the get_iterator replacement in case of Fault Tolerant Training.
# cleanup the get_iterator replacement in case of Fault-tolerance.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
del DataLoader._ori_get_iterator


def _validate_iterable_dataset(dataloader: DataLoader):
SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

dataset = dataloader.dataset

if getattr(dataset, "__next__", None) is None:
raise AttributeError(
"Fault-tolerance doesn't support an `IterableDataset` without `__next__` "
"method implemented. Hint: We recommend you to move your logic from `__iter__`"
" inside and rely on a sampler to perform the sample sampling."
)

samplers = {k: v for k, v in dataset.__dict__.items() if isinstance(v, Sampler)}

if not samplers:
raise AttributeError("Fault-tolerance doesn't support an IterableDataset without a sampler as attribute.")

sampler = [v for v in samplers.values() if type(v) in SUPPORTED_SAMPLERS]

if not sampler:
raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

if len(sampler) > 1:
raise ValueError(f"A single sampler is supported within an Iterable Dataset. Found {sampler}.")

if type(sampler[0]) is DistributedSampler and sampler[0].shuffle:
raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.")

if type(sampler[0]) is RandomSampler:
raise ValueError("Only SequentialSampler is supported.")

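To make the constraints above concrete, an illustrative sketch (not part of this diff) of an `IterableDataset` that this validator accepts: the iteration logic lives in `__next__`, and the sample order comes from a `SequentialSampler` kept as an attribute so it can be discovered by the sampler scan above.

import torch
from torch.utils.data import DataLoader, IterableDataset, SequentialSampler


class StatefulIterable(IterableDataset):
    def __init__(self, data):
        super().__init__()
        self.data = list(data)
        # Kept as an attribute so the validator (and the fault-tolerant capture
        # machinery) can find it on the dataset.
        self.sampler = SequentialSampler(self.data)

    def __iter__(self):
        self.sampler_iter = iter(self.sampler)
        return self

    def __next__(self):
        index = next(self.sampler_iter)
        return torch.tensor(self.data[index])


loader = DataLoader(StatefulIterable(range(5)))
# `_validate_iterable_dataset(loader)` passes: `__next__` exists and exactly one
# supported, non-shuffling sampler attribute is found.
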

def _validate_map_dataset(dataloader: DataLoader):
SUPPORTED_SAMPLERS = (RandomSampler, SequentialSampler, DistributedSampler)

sampler = getattr(dataloader, "sampler", None)
if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
raise ValueError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")

batch_sampler = getattr(dataloader, "batch_sampler", None)
if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
raise ValueError("Fault-tolerance supports only a BatchSampler.")

if type(sampler) is DistributedSampler and sampler.shuffle:
raise ValueError("A `DistributedSampler` sampler shuffle attribute is set to True.")

if type(sampler) is RandomSampler:
raise ValueError("Only SequentialSampler is supported.")


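For the map-style path, an illustrative sketch (mirroring the new tests added below) of what this check accepts and rejects:

import pytest
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler

from pytorch_lightning.utilities.auto_restart import _validate_map_dataset

data = range(10)

# Accepted: the stock SequentialSampler/BatchSampler pair.
_validate_map_dataset(DataLoader(data))

# Accepted: a DistributedSampler that does not shuffle.
_validate_map_dataset(DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False)))

# Rejected: RandomSampler, as well as custom sampler or batch-sampler subclasses.
with pytest.raises(ValueError, match="Only SequentialSampler is supported."):
    _validate_map_dataset(DataLoader(data, sampler=RandomSampler(data)))
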
def _validate_fault_tolerant_automatic(dataloader: Iterable, stage: "pl.trainer.states.RunningStage") -> None:
"""This function is used to validate that Fault-tolerance is possible with the user data."""
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator

if not _FaultTolerantMode.detect_current_mode().is_automatic:
return

if isinstance(dataloader, CombinedLoader):
dataloaders = dataloader.loaders
else:
dataloaders = dataloader

dl_loaders = []

def flatten_dataloader(dataloader: Union[DataLoader, CycleIterator, Iterable]) -> None:
nonlocal dl_loaders
if isinstance(dataloader, CycleIterator):
dataloader = dataloader.loader
dl_loaders.append(dataloader)

apply_to_collection(dataloaders, (DataLoader, CycleIterator), flatten_dataloader)

if len(dl_loaders) > 1 and stage == pl.trainer.states.RunningStage.TRAINING:
raise ValueError("Fault-tolerance supports only a single dataloader.")

for dataloader in dl_loaders:
validator_fn = (
_validate_iterable_dataset if isinstance(dataloader.dataset, IterableDataset) else _validate_map_dataset
)
validator_fn(dataloader)


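A short sketch (mirroring the new test added below) of the top-level rule implemented here: in automatic mode only a single training dataloader is supported, while multiple evaluation dataloaders remain allowed.

import os

import pytest
from torch.utils.data import DataLoader

from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import _validate_fault_tolerant_automatic

os.environ["PL_FAULT_TOLERANT_TRAINING"] = "1"  # automatic mode

loaders = [DataLoader(range(10)), DataLoader(range(10))]

# Multiple training dataloaders, plain or wrapped in a CombinedLoader, are rejected ...
with pytest.raises(ValueError, match="single dataloader"):
    _validate_fault_tolerant_automatic(loaders, RunningStage.TRAINING)
with pytest.raises(ValueError, match="single dataloader"):
    _validate_fault_tolerant_automatic(CombinedLoader(loaders), RunningStage.TRAINING)

# ... while several evaluation dataloaders are accepted.
_validate_fault_tolerant_automatic(loaders, RunningStage.VALIDATING)
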
def _collect_states_on_rank_zero_over_collection(state_dict: Any, key: str = "state") -> Any:
"""This utility collects the state across processes for a collection of state."""

156 changes: 141 additions & 15 deletions tests/utilities/test_auto_restart.py
@@ -33,10 +33,12 @@
from torch.utils.data._utils.worker import get_worker_info
from torch.utils.data.dataloader import DataLoader, default_collate
from torch.utils.data.dataset import Dataset, IterableDataset
from torch.utils.data.sampler import Sampler

import tests.helpers.utils as tutils
from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.states import RunningStage, TrainerState
from pytorch_lightning.trainer.supporters import CombinedLoader
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_collect_states_on_rank_zero_over_collection,
@@ -47,6 +49,7 @@
_SingleProcessDataLoaderIterStateful,
_SupportsStateDict,
_teardown_dataloader_get_iterators,
_validate_fault_tolerant_automatic,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
@@ -720,7 +723,9 @@ def on_train_batch_start(self, trainer, *_) -> None:
assert not isinstance(loaders["a"][2].loader.dataset, CaptureMapDataset)
assert isinstance(loaders["b"].loader.dataset, RangeIterableDataset)

with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}):
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": use_fault_tolerant}), mock.patch(
"pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None
):
model = TestModel()
model.training_epoch_end = None
trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_train_batches=1, callbacks=Check())
@@ -917,22 +922,27 @@ def test_dataset_rng_states_restart_with_lightning(tmpdir, dataset_classes, mult
multiple_trainloader_mode=multiple_trainloader_mode,
)

all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes)
all_batches = torch.stack(all_batches)
assert len(all_batches) == 9
# this test would fail because fault-tolerant training doesn't support multiple datasets,
# so the validation is mocked out. It still works as the dataset is fully deterministic
# and therefore there is no overlap between the seeds.
with mock.patch("pytorch_lightning.trainer.data_loading._validate_fault_tolerant_automatic", lambda x, y: None):

# Simulate 1st failure
complete_batches, _ = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
assert len(complete_batches) == 4
all_batches, weights0 = _run_training(trainer_kwargs, dataset_classes)
all_batches = torch.stack(all_batches)
assert len(all_batches) == 9

checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
assert os.path.exists(checkpoint_path)
# Simulate 1st failure
complete_batches, _ = _run_training(trainer_kwargs, dataset_classes, fail_on_step=4)
assert len(complete_batches) == 4

# Resume after failure
resumed_batches, weights1 = _run_training(
trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path
)
assert len(resumed_batches) == 5
checkpoint_path = os.path.join(tmpdir, ".pl_auto_save.ckpt")
assert os.path.exists(checkpoint_path)

# Resume after failure
resumed_batches, weights1 = _run_training(
trainer_kwargs, dataset_classes, fail_on_step=-1, ckpt_path=checkpoint_path
)
assert len(resumed_batches) == 5

# the resumed batches should match the batches of the successful training
all_batches_resumed = torch.stack(complete_batches + resumed_batches)
@@ -1168,6 +1178,122 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict


@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
def test_validate_fault_tolerant(tmpdir):

data = range(10)
dataloader = DataLoader(data)

_validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)

with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))])
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
dataloaders = CombinedLoader([DataLoader(data), DataLoader(range(10))], mode="max_size_cycle")
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [DataLoader(data), DataLoader(range(10))]
with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

_validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)

with pytest.raises(ValueError, match="A `DistributedSampler` sampler shuffle attribute is set to True."):
dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=True))]
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [DataLoader(data, sampler=DistributedSampler(data, num_replicas=2, rank=0, shuffle=False))]
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataset = SequentialGetItemDataset(2)
dataloaders = [
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
]

with pytest.raises(ValueError, match="Fault-tolerance supports only a single dataloader."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)),
DataLoader(dataset, sampler=DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)),
]

with pytest.raises(ValueError, match="Fault-tolerance supports only a single."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.TRAINING)

dataloaders = [
DataLoader(dataset, sampler=RandomSampler(dataset)),
DataLoader(dataset, sampler=SequentialSampler(dataset)),
]

with pytest.raises(ValueError, match="Only SequentialSampler is supported."):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)

with pytest.raises(ValueError, match="RandomSampler"):

class CustomRandomSampler(RandomSampler):
pass

dataloader = DataLoader(data, sampler=CustomRandomSampler(data))
_validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)

with pytest.raises(ValueError, match="BatchSampler"):

class CustomBatchSampler(BatchSampler):
pass

sampler = Sampler(data)
batch_sampler = CustomBatchSampler(sampler, 2, False)
dataloader = DataLoader(data, batch_sampler=batch_sampler)
_validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)

with pytest.raises(AttributeError, match="without `__next__` method"):

class CustomIterable(IterableDataset):
def __iter__(self):
while True:
yield 0

dataloader = DataLoader(CustomIterable())
_validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)

with pytest.raises(AttributeError, match="IterableDataset without a sampler as attribute"):

class CustomIterable(IterableDataset):
def __iter__(self):
return self

def __next__(self):
return torch.tensor(0)

dataloader = DataLoader(CustomIterable())
_validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)

with pytest.raises(TypeError, match="RandomSampler"):

class CustomIterable(IterableDataset):
def __init__(self):
super().__init__()
self.data = data
self.sampler = CustomRandomSampler(self.data)

def __iter__(self):
return self

def __next__(self):
return torch.tensor(0)

dataloader = DataLoader(CustomIterable())
_validate_fault_tolerant_automatic(dataloader, RunningStage.TRAINING)

dataloaders = [DataLoader(data), DataLoader(CustomIterable())]
with pytest.raises(TypeError, match="RandomSampler"):
_validate_fault_tolerant_automatic(dataloaders, RunningStage.VALIDATING)


def test_rotate_worker_indices():
"""This test ensures `worker_id` are rotated properly depending on which one was the latest."""
state_dict = {0: 0, 1: 1}