Fault Tolerant Manual: Add stateful dataloader iter #10674

Merged
41 commits merged on Nov 23, 2021
Changes from 39 commits

Commits (41)
ca8dbbb  update (tchaton, Nov 19, 2021)
3f0d28a  update (tchaton, Nov 19, 2021)
0c47670  update (tchaton, Nov 19, 2021)
3e5b52e  update (tchaton, Nov 19, 2021)
24c8245  update (tchaton, Nov 19, 2021)
bcd5569  update (tchaton, Nov 19, 2021)
8d3844d  update (tchaton, Nov 19, 2021)
1829b46  update (tchaton, Nov 19, 2021)
a1a364a  typo (tchaton, Nov 19, 2021)
de41675  update on comments (tchaton, Nov 22, 2021)
8178a32  Update pytorch_lightning/utilities/auto_restart.py (kaushikb11, Nov 22, 2021)
00b9355  update (tchaton, Nov 22, 2021)
96f0517  update (tchaton, Nov 22, 2021)
297fd67  Merge branch 'fault_tolerant_enum' of https://github.com/PyTorchLight… (tchaton, Nov 22, 2021)
9800cba  update (tchaton, Nov 22, 2021)
427ed03  docstring improvement (tchaton, Nov 22, 2021)
ae712b0  update (tchaton, Nov 22, 2021)
9a5166d  Rename and simplify (carmocca, Nov 22, 2021)
b5fa819  Add comment (carmocca, Nov 22, 2021)
c82b2f2  update (tchaton, Nov 22, 2021)
2ede205  update (tchaton, Nov 22, 2021)
b16c4c0  update (tchaton, Nov 22, 2021)
ce9c23c  update (tchaton, Nov 22, 2021)
2baddb9  update (tchaton, Nov 22, 2021)
97548bb  update (tchaton, Nov 22, 2021)
d953ae9  update (tchaton, Nov 22, 2021)
41ffbab  use_teardown (tchaton, Nov 22, 2021)
d04596d  Use `Protocol` (carmocca, Nov 22, 2021)
ff7b836  Simplify test (carmocca, Nov 22, 2021)
a5698e6  Update CHANGELOG.md (carmocca, Nov 22, 2021)
79fdacc  update (tchaton, Nov 22, 2021)
916b520  update (tchaton, Nov 22, 2021)
4b67fbf  update (tchaton, Nov 22, 2021)
c9481e2  update (tchaton, Nov 22, 2021)
ef29342  update (tchaton, Nov 22, 2021)
4a1fff7  update (tchaton, Nov 22, 2021)
cb27e30  update (tchaton, Nov 22, 2021)
7903d24  resolve tests (tchaton, Nov 22, 2021)
20d19a1  update (tchaton, Nov 22, 2021)
1104cbc  update (tchaton, Nov 23, 2021)
f071f9a  change to 0 (tchaton, Nov 23, 2021)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))
* Add a `_rotate_worker_indices` utility to reload the state according to the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647))
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))


-
7 changes: 2 additions & 5 deletions pytorch_lightning/trainer/data_loading.py
@@ -15,7 +15,6 @@
import os
from abc import ABC
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Collection, List, Optional, Tuple, Union

from torch.utils.data import DataLoader, RandomSampler, Sampler, SequentialSampler
@@ -29,7 +28,7 @@
from pytorch_lightning.trainer.supporters import CombinedLoader, CycleIterator
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _capture_metadata_collate
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_update_dataloader,
@@ -441,9 +440,7 @@ def request_dataloader(
def _add_sampler_metadata_collate(dataloader: DataLoader) -> None:
"""Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is
enabled."""
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
)
_add_capture_metadata_collate(dataloader)

@staticmethod
def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]:
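
As an aside, the collate-wrapping pattern that this file now delegates to `_add_capture_metadata_collate` can be sketched in isolation. The snippet below is a minimal, self-contained illustration with hypothetical names (`_RangeDataset`, `_collate_with_metadata`, `_wrap_collate`), not the library's implementation: the DataLoader's `collate_fn` is replaced by a `functools.partial` that collates as before and then attaches extra metadata to the batch.

from functools import partial
from typing import Any, Callable, Dict, List

import torch
from torch.utils.data import DataLoader, Dataset


class _RangeDataset(Dataset):
    """Tiny map-style dataset used only for this illustration."""

    def __len__(self) -> int:
        return 8

    def __getitem__(self, index: int) -> int:
        return index


def _collate_with_metadata(samples: List[Any], dataset: Dataset, collate_fn: Callable) -> Dict[str, Any]:
    # Collate exactly as the DataLoader would have, then piggyback metadata on the batch.
    return {"data": collate_fn(samples), "metadata": {"dataset_len": len(dataset), "num_samples": len(samples)}}


def _wrap_collate(dataloader: DataLoader) -> None:
    # Same shape as the helper above: keep the original collate_fn and wrap it in a partial.
    dataloader.collate_fn = partial(
        _collate_with_metadata, dataset=dataloader.dataset, collate_fn=dataloader.collate_fn
    )


loader = DataLoader(_RangeDataset(), batch_size=4)
_wrap_collate(loader)
for batch in loader:
    assert isinstance(batch["data"], torch.Tensor)
    assert batch["metadata"]["num_samples"] == 4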
163 changes: 154 additions & 9 deletions pytorch_lightning/utilities/auto_restart.py
@@ -21,11 +21,17 @@
import numpy as np
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
from torch.utils.data.dataloader import (
_BaseDataLoaderIter,
_MultiProcessingDataLoaderIter,
_SingleProcessDataLoaderIter,
DataLoader,
IterableDataset,
)
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -435,8 +441,10 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]:
return {"num_workers": num_workers, "previous_worker": previous_worker}


def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate: Callable) -> Dict:
"""A collate function that adds the state dict of a :class:`CaptureIterableDataset` or
def _capture_metadata_collate(
samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode
) -> Any:
"""A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or
:class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker
processes. The structure will be:

@@ -447,10 +455,25 @@ def _capture_metadata_collate(samples: List, dataset: Dataset, default_collate:
"__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1},
}
"""
data = default_collate(samples)
if not isinstance(dataset, (CaptureIterableDataset, CaptureMapDataset)):
return data
metadata = dataset.state_dict()
data = collate_fn(samples)
metadata = None
if fault_tolerant_mode.is_automatic:
metadata = dataset.state_dict()
else:
state_dict_fn = getattr(dataset, "state_dict", None)
info = get_worker_info()
worker_id = info.id if info else 0
if state_dict_fn:
metadata = state_dict_fn()
if worker_id not in metadata:
if info and info.num_workers > 1:
raise MisconfigurationException(
f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys."
)
metadata = {0: metadata}
if metadata is None:
metadata = {worker_id: {}}

return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}
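
(For concreteness, a batch produced by this collate wrapper in manual fault-tolerant mode with a single worker looks roughly like the sketch below; the data values and the dataset state are invented for the illustration, and the metadata key is the `"__pl_restart_meta"` string shown in the docstring above.)

# Illustrative only: the shape of one batch returned by the wrapped collate_fn in
# manual mode. Real state dicts come from `dataset.state_dict()`, keyed by worker_id.
example_batch = {
    "data": [0, 1, 2, 3],  # whatever the original collate_fn produced
    "__pl_restart_meta": {0: {"current_index": 4}},  # worker_id -> dataset state
}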


@@ -480,6 +503,9 @@ def patch_dataloader_iterator(
will extract the current iteration as part of the metadata returned by a custom batch.
"""

if not _FaultTolerantMode.detect_current_mode().is_automatic:
return

assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable:
@@ -527,8 +553,14 @@ def wrapper():

def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
"""Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled."""
faut_tolerant_mode = _FaultTolerantMode.detect_current_mode()
if not faut_tolerant_mode.is_enabled:
return
dataloader.collate_fn = partial(
_capture_metadata_collate, dataset=dataloader.dataset, default_collate=dataloader.collate_fn
_capture_metadata_collate,
dataset=dataloader.dataset,
collate_fn=dataloader.collate_fn,
fault_tolerant_mode=fault_tolerant_mode,
)


@@ -589,3 +621,116 @@ def state_dict(self) -> Dict[str, Any]:

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...


class _StatefulDataLoaderIter:
"""This mixin is used to make PyTorch DataLoaderIter stateful."""

def _reset(self, loader: DataLoader, first_iter: bool = False):
super()._reset(loader, first_iter=first_iter)
self._loader = loader
self.num_batches_fetched = 0

def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None:
# initialize the queue if it doesn't exist.
if not hasattr(self, "_sampler_state"):
self._sampler_state = []
self._sampler_state_idx = 0

# store sampler state within a queue alongside its idx.
self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1
self._sampler_state.append((sampler_state, self._sampler_state_idx))

def _store_sampler_state(self) -> None:
"""This function is used to extract the sampler states if any."""
sampler_state = {
k: v.state_dict()
for k, v in self._loader.__dict__.items()
if isinstance(v, _SupportsStateDict) and k != "dataset"
}

self.__accumulate_state(sampler_state)

def _next_index(self) -> Any:
indexes = super()._next_index()
self._store_sampler_state()
return indexes

def _prepare_loader(self, loader):
if not isinstance(loader.collate_fn, partial):
loader.collate_fn = partial(_capture_metadata_collate, dataset=loader.dataset, collate_fn=loader.collate_fn)
self._loader = loader
self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher
self.num_batches_fetched = 0
self._sampler_state = []
self._sampler_state_idx = 0

def __del__(self) -> None:
if isinstance(self._loader.collate_fn, partial):
self._loader.collate_fn = self._loader.collate_fn.keywords["collate_fn"]

def _next_data(self) -> Any:
combined_batch = super()._next_data()

batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]

self.num_batches_fetched += 1

sampler_state, sampler_state_idx = self._sampler_state.pop(0)
# there are no workers within the samplers, so take the worker_id from the dataset state keys
worker_id = list(state.keys())[0]

state = [
IteratorState(
num_workers=self._loader.num_workers,
sampler_state=sampler_state,
dataset_state=state,
worker_id=worker_id,
num_batches_fetched=self.num_batches_fetched,
)
]
# ensures there is an alignment between the sampler state and the currently fetched batch
assert sampler_state_idx == self.num_batches_fetched
self._data_fetcher._store_dataloader_iter_state(self, state)
return batch


class _SingleProcessDataLoaderIterStateful(_StatefulDataLoaderIter, _SingleProcessDataLoaderIter):
def __init__(self, loader: DataLoader):
self._prepare_loader(loader)
super().__init__(loader)


class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProcessingDataLoaderIter):
def __init__(self, loader: DataLoader):
self._prepare_loader(loader)
super().__init__(loader)


def _get_iterator(self) -> "_BaseDataLoaderIter":
if not hasattr(self, "_lightning_fetcher"):
raise MisconfigurationException(
"A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader."
)
if self.num_workers == 0:
return _SingleProcessDataLoaderIterStateful(self)
else:
if hasattr(self, "check_worker_number_rationality"):
self.check_worker_number_rationality()
return _MultiProcessingDataLoaderIterStateful(self)


def _patch_dataloader_get_iterators() -> None:
"""This function is used to replace the DataLoader iterator by their stateful version."""
if not hasattr(DataLoader, "_ori_get_iterator"):
DataLoader._ori_get_iterator = DataLoader._get_iterator
DataLoader._get_iterator = _get_iterator


def _teardown_dataloader_get_iterators() -> None:
"""This function is used to restore the DataLoader `get_iterator` with its original one."""
# cleanup the get_iterator replacement in case of Fault Tolerant Training.
get_iterator = getattr(DataLoader, "_ori_get_iterator", None)
if get_iterator:
DataLoader._get_iterator = get_iterator
del DataLoader._ori_get_iterator
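
As an illustration of this patch-and-restore idiom, the sketch below applies the same idea with a deliberately simplified stateful iterator that only counts fetched batches. Every name in it (`_CountingIterMixin`, `_counting_get_iterator`, and the single-process-only assumption) is hypothetical and much smaller than the `_StatefulDataLoaderIter` machinery above, but the mixin-before-base class ordering and the monkey-patching of `DataLoader._get_iterator` mirror the approach.

from typing import Any

from torch.utils.data import DataLoader
from torch.utils.data.dataloader import _BaseDataLoaderIter, _SingleProcessDataLoaderIter


class _CountingIterMixin:
    """Minimal stand-in for a stateful iterator mixin: it only tracks fetched batches."""

    def __init__(self, loader: DataLoader) -> None:
        super().__init__(loader)
        self.num_batches_fetched = 0

    def _next_data(self) -> Any:
        data = super()._next_data()
        self.num_batches_fetched += 1
        return data


class _CountingSingleProcessIter(_CountingIterMixin, _SingleProcessDataLoaderIter):
    pass


def _counting_get_iterator(self: DataLoader) -> _BaseDataLoaderIter:
    # Simplified: this sketch only handles the num_workers == 0 path.
    assert self.num_workers == 0, "the sketch only covers the single-process iterator"
    return _CountingSingleProcessIter(self)


def _patch_get_iterator() -> None:
    # Keep a handle on the original method so it can be restored later.
    if not hasattr(DataLoader, "_original_get_iterator"):
        DataLoader._original_get_iterator = DataLoader._get_iterator
        DataLoader._get_iterator = _counting_get_iterator


def _teardown_get_iterator() -> None:
    original = getattr(DataLoader, "_original_get_iterator", None)
    if original is not None:
        DataLoader._get_iterator = original
        del DataLoader._original_get_iterator


_patch_get_iterator()
loader = DataLoader(list(range(8)), batch_size=2)
it = iter(loader)
batches = list(it)  # 4 batches of 2 elements each
assert len(batches) == 4 and it.num_batches_fetched == 4
_teardown_get_iterator()

Keeping the original method on the class (here `_original_get_iterator`) is what makes the teardown safe to call even when the patch was never applied.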
12 changes: 12 additions & 0 deletions pytorch_lightning/utilities/fetching.py
@@ -99,6 +99,8 @@ def setup(
if self.profiler is not None and stage is None:
raise MisconfigurationException("When providing a profiler, the stage should be provided too.")

self._attach_data_fetcher()

@staticmethod
def _add_capture_metadata_collate(dataloader: Iterable) -> None:
if not isinstance(dataloader, (DataLoader, CombinedLoader)):
@@ -190,6 +192,16 @@ def collect_state(iterator: Iterator):

return apply_to_collection(self.loader_iters, Iterator, collect_state)

def _attach_data_fetcher(self):
def _attach_data_fetcher_fn(loader: DataLoader):
if isinstance(loader, CycleIterator):
loader = loader.loader

if isinstance(loader, DataLoader) and _fault_tolerant_training():
loader._lightning_fetcher = self

apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn)

def __iter__(self) -> Generator[Tuple[Any, bool], None, None]:
if self.dataloader is None:
raise MisconfigurationException("The iterate hasn't been provided. HINT: Did you call setup function ?.")
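
To close the loop, `_attach_data_fetcher` gives every wrapped `DataLoader` a back-reference to its fetcher, which is exactly what the stateful `_get_iterator` above checks for through the `_lightning_fetcher` attribute. Below is a rough, self-contained sketch of that attach-and-check handshake, using a dummy fetcher class and a hand-rolled traversal in place of `apply_to_collection`; it is only meant to show the shape of the hand-off.

from typing import Any, Iterable, Union

from torch.utils.data import DataLoader


class _DummyFetcher:
    """Stand-in for a data fetcher that wants a handle on the dataloaders it drives."""


def _attach_fetcher(loaders: Union[DataLoader, Iterable[Any]], fetcher: _DummyFetcher) -> None:
    # Walk a flat or nested collection and tag every DataLoader with the fetcher.
    if isinstance(loaders, DataLoader):
        loaders._lightning_fetcher = fetcher
    elif isinstance(loaders, (list, tuple)):
        for loader in loaders:
            _attach_fetcher(loader, fetcher)
    elif isinstance(loaders, dict):
        for loader in loaders.values():
            _attach_fetcher(loader, fetcher)


train_loader = DataLoader(list(range(4)), batch_size=2)
val_loader = DataLoader(list(range(4)), batch_size=2)
fetcher = _DummyFetcher()
_attach_fetcher({"train": train_loader, "val": [val_loader]}, fetcher)

# Mirrors the guard in `_get_iterator` above: a stateful iterator should only be
# created once a fetcher has been attached to the DataLoader.
assert getattr(train_loader, "_lightning_fetcher", None) is fetcher
assert getattr(val_loader, "_lightning_fetcher", None) is fetcher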