Skip to content

Commit

Permalink
Fault Tolerant Manual: Add stateful dataloader iter (#10674)
Browse files Browse the repository at this point in the history
  • Loading branch information
tchaton committed Nov 23, 2021
1 parent 48cf1ad commit 1702036
Show file tree
Hide file tree
Showing 5 changed files with 259 additions and 20 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
* Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))


-
Expand Down
13 changes: 2 additions & 11 deletions pytorch_lightning/trainer/data_loading.py
Expand Up @@ -15,7 +15,6 @@
import os
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Collection, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
Expand All @@ -29,7 +28,7 @@
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _capture_metadata_collate
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_update_dataloader,
Expand Down Expand Up @@ -215,7 +214,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) -

# add collate_fn to collect metadata for fault tolerant training
if _fault_tolerant_training():
apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate)
apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode)
Expand Down Expand Up @@ -437,14 +436,6 @@ def request_dataloader(
self.training_type_plugin.barrier("get_dataloaders")
return dataloader

@staticmethod
def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
"""Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is
enabled."""
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
)

@staticmethod
def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
all_have_sequential_sampler = True
Expand Down
153 changes: 144 additions & 9 deletions pytorch_lightning/utilities/auto_restart.py
Expand Up @@ -21,11 +21,17 @@
import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
_SingleProcessDataLoaderIter,
DataLoader,
IterableDataset,
)
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

Expand Down Expand Up @@ -435,8 +441,10 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]:
return {"num_workers": num_workers, "previous_worker": previous_worker}


def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict:
"""A collate function that adds the state dict of a :class:`CaptureIterableDataset` or
def _capture_metadata_collate(
samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode
) -> Any:
"""A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or
:class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker
processes. The structure will be:
Expand All @@ -447,10 +455,25 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:
"__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
}
"""
data = default_collate(samples)
if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)):
return data
metadata = dataset.state_dict()
data = collate_fn(samples)
metadata = None
if fault_tolerant_mode.is_automatic:
metadata = dataset.state_dict()
else:
state_dict_fn = getattr(dataset, "state_dict", None)
info = get_worker_info()
worker_id = info.id if info else 0
if state_dict_fn is not None:
metadata = state_dict_fn()
if worker_id not in metadata:
if info and info.num_workers > 1:
raise MisconfigurationException(
f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys."
)
metadata = {0: metadata}
if metadata is None:
metadata = {worker_id: {}}

return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}


Expand Down Expand Up @@ -480,6 +503,9 @@ def patch_dataloader_iterator(
will extract the current iteration as part of the metadata returned by a custom batch.
"""

if not _FaultTolerantMode.detect_current_mode().is_automatic:
return

assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable:
Expand Down Expand Up @@ -527,8 +553,14 @@ def wrapper():

def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
"""Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
faut_tolerant_mode = _FaultTolerantMode.detect_current_mode()
if not faut_tolerant_mode.is_enabled:
return
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
_capture_metadata_collate,
dataset=dataloader.dataset,
collate_fn=dataloader.collate_fn,
fault_tolerant_mode=faut_tolerant_mode,
)


Expand Down Expand Up @@ -589,3 +621,106 @@ def state_dict(self) -> Dict[str, Any]:

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...


class _StatefulDataLoaderIter:
"""This mixin is used to make PyTorch DataLoaderIter stateful."""

def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
# store sampler state within a queue alongside its idx.
self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state.append((sampler_state, self._sampler_state_idx))

def _store_sampler_state(self) -> None:
"""This function is used to extract the sampler states if any."""
sampler_state = {
k: v.state_dict()
for k, v in self._loader.__dict__.items()
if isinstance(v, _SupportsStateDict) and k != "dataset"
}

self.__accumulate_state(sampler_state)

def _next_index(self) -> Any:
indexes = super()._next_index()
self._store_sampler_state()
return indexes

def _prepare_loader(self, loader):
if not isinstance(loader.collate_fn, partial):
loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn)
self._loader = loader
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
self.num_batches_fetched = 0
self._sampler_state = []
self._sampler_state_idx = 0

def __del__(self) -> None:
if isinstance(self._loader.collate_fn, partial):
self._loader.collate_fn = self._loader.collate_fn.keywords["collate_fn"]

def _next_data(self) -> Any:
combined_batch = super()._next_data()

batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]

self.num_batches_fetched += 1

sampler_state, sampler_state_idx = self._sampler_state.pop(0)
# there is no workers within the samplers
worker_id = list(state.keys())[0]

state = [
IteratorState(
num_workers=self._loader.num_workers,
sampler_state=sampler_state,
dataset_state=state,
worker_id=worker_id,
num_batches_fetched=self.num_batches_fetched,
)
]
# ensures there is an alignement between the sampler state and currently fetched batch
assert sampler_state_idx == self.num_batches_fetched
self._data_fetcher._store_dataloader_iter_state(self, state)
return batch


class _SingleProcessDataLoaderIterStateful(_StatefulDataLoaderIter, _SingleProcessDataLoaderIter):
def __init__(self, loader: DataLoader):
self._prepare_loader(loader)
super().__init__(loader)


class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProcessingDataLoaderIter):
def __init__(self, loader: DataLoader):
self._prepare_loader(loader)
super().__init__(loader)


def _get_iterator(self) -> "_BaseDataLoaderIter":
if not hasattr(self, "_lightning_fetcher"):
raise MisconfigurationException(
"A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
)
if self.num_workers == 0:
return _SingleProcessDataLoaderIterStateful(self)
else:
if hasattr(self, "check_worker_number_rationality"):
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIterStateful(self)


def _patch_dataloader_get_iterators() -> None:
"""This function is used to replace the DataLoader iterator by their stateful version."""
if not hasattr(DataLoader, "_ori_get_iterator"):
DataLoader._ori_get_iterator = DataLoader._get_iterator
DataLoader._get_iterator = _get_iterator


def _teardown_dataloader_get_iterators() -> None:
"""This function is used to restore the DataLoader `get_iterator` with its original one."""
# cleanup the get_iterator replacement in case of Fault Tolerant Training.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
del DataLoader._ori_get_iterator
12 changes: 12 additions & 0 deletions pytorch_lightning/utilities/fetching.py
Expand Up @@ -99,6 +99,8 @@ def setup(
if self.profiler is not None and stage is None:
raise MisconfigurationException("When providing a profiler, the stage should be provided too.")

self._attach_data_fetcher()

@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
Expand Down Expand Up @@ -190,6 +192,16 @@ def collect_state(iterator: Iterator):

return apply_to_collection(self.loader_iters, Iterator, collect_state)

def _attach_data_fetcher(self):
def _attach_data_fetcher_fn(loader: DataLoader):
if isinstance(loader, CycleIterator):
loader = loader.loader

if isinstance(loader, DataLoader) and _fault_tolerant_training():
loader._lightning_fetcher = self

apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
if self.dataloader is None:
raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
Expand Down

0 comments on commit 1702036

Please sign in to comment.