Fault Tolerant Manual: utilities cleanup #10703

Merged: 58 commits, Nov 24, 2021
Changes from 56 commits
Commits (58)
ca8dbbb - update (tchaton, Nov 19, 2021)
3f0d28a - update (tchaton, Nov 19, 2021)
0c47670 - update (tchaton, Nov 19, 2021)
3e5b52e - update (tchaton, Nov 19, 2021)
24c8245 - update (tchaton, Nov 19, 2021)
bcd5569 - update (tchaton, Nov 19, 2021)
8d3844d - update (tchaton, Nov 19, 2021)
1829b46 - update (tchaton, Nov 19, 2021)
a1a364a - typo (tchaton, Nov 19, 2021)
de41675 - update on comments (tchaton, Nov 22, 2021)
8178a32 - Update pytorch_lightning/utilities/auto_restart.py (kaushikb11, Nov 22, 2021)
00b9355 - update (tchaton, Nov 22, 2021)
96f0517 - update (tchaton, Nov 22, 2021)
297fd67 - Merge branch 'fault_tolerant_enum' of https://github.com/PyTorchLight… (tchaton, Nov 22, 2021)
9800cba - update (tchaton, Nov 22, 2021)
427ed03 - docstring improvement (tchaton, Nov 22, 2021)
ae712b0 - update (tchaton, Nov 22, 2021)
9a5166d - Rename and simplify (carmocca, Nov 22, 2021)
b5fa819 - Add comment (carmocca, Nov 22, 2021)
c82b2f2 - update (tchaton, Nov 22, 2021)
2ede205 - update (tchaton, Nov 22, 2021)
b16c4c0 - update (tchaton, Nov 22, 2021)
ce9c23c - update (tchaton, Nov 22, 2021)
2baddb9 - update (tchaton, Nov 22, 2021)
97548bb - update (tchaton, Nov 22, 2021)
d953ae9 - update (tchaton, Nov 22, 2021)
41ffbab - use_teardown (tchaton, Nov 22, 2021)
d04596d - Use `Protocol` (carmocca, Nov 22, 2021)
ff7b836 - Simplify test (carmocca, Nov 22, 2021)
a5698e6 - Update CHANGELOG.md (carmocca, Nov 22, 2021)
79fdacc - update (tchaton, Nov 22, 2021)
916b520 - update (tchaton, Nov 22, 2021)
4b67fbf - update (tchaton, Nov 22, 2021)
c9481e2 - update (tchaton, Nov 22, 2021)
ef29342 - update (tchaton, Nov 22, 2021)
4a1fff7 - update (tchaton, Nov 22, 2021)
cb27e30 - update (tchaton, Nov 22, 2021)
7903d24 - resolve tests (tchaton, Nov 22, 2021)
20d19a1 - update (tchaton, Nov 22, 2021)
1104cbc - update (tchaton, Nov 23, 2021)
f071f9a - change to 0 (tchaton, Nov 23, 2021)
b777dc3 - update (tchaton, Nov 23, 2021)
2da1674 - update (tchaton, Nov 23, 2021)
647bebd - merge with master (tchaton, Nov 23, 2021)
dbcfa65 - update changelog (tchaton, Nov 23, 2021)
ae18166 - [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Nov 23, 2021)
a527929 - update (tchaton, Nov 23, 2021)
5827e8b - Merge branch 'add_reloading' of https://github.com/PyTorchLightning/p… (tchaton, Nov 23, 2021)
97421c3 - update (tchaton, Nov 23, 2021)
51cf75b - update (tchaton, Nov 23, 2021)
35644b8 - update on comments (tchaton, Nov 23, 2021)
04a5c3d - Merge branch 'add_reloading' into fault_tolerant_cleanup (tchaton, Nov 23, 2021)
26d46b0 - update changelog (tchaton, Nov 23, 2021)
5337e7d - remove deadcode (tchaton, Nov 23, 2021)
30cc2ec - update (tchaton, Nov 23, 2021)
ec0bdad - update (tchaton, Nov 23, 2021)
34d86cb - update (tchaton, Nov 24, 2021)
fb6579b - update (tchaton, Nov 24, 2021)
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674))
* Add an utility to collect the states across processes ([#10639](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
* Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))

-

133 changes: 45 additions & 88 deletions pytorch_lightning/utilities/auto_restart.py
@@ -394,52 +394,6 @@ def _cycle_to_next_worker_and_reset(dataloader: DataLoader, state_dict: Dict[str
return iter_dataloader


def _dataloader_to_state_dict(
dataloader: DataLoader, iterator: Iterator, num_batches_processed: int = None
) -> List[Dict[str, Any]]:
"""Convert a dataloader to its associated state dict."""
out = {}
if iterator is not None:
out.update(_find_current_worker(iterator))

if not isinstance(dataloader.dataset, CaptureIterableDataset):
fast_forward_sampler = _find_fast_forward_samplers(dataloader)
if fast_forward_sampler is not None:
out.update(fast_forward_sampler.state_dict(num_batches_processed=num_batches_processed))
return out


def _dataloader_load_state_dict(dataloader: DataLoader, state_dict: List[Dict[str, Any]]) -> DataLoader:
"""Reload ``DataLoader`` fast-forward sampler state dict."""
fast_forward_sampler = _find_fast_forward_samplers(dataloader)

if isinstance(fast_forward_sampler, Sampler):
state_dict = {k: v for k, v in state_dict.items() if k not in ("num_workers", "previous_worker")}
fast_forward_sampler.load_state_dict(state_dict)

return dataloader


def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]:
"""Find the current DataLoader Iterator worker if multiple workers were used."""
# get the current number of workers
num_workers = getattr(iterator, "_num_workers", 0)
if isinstance(iterator, _MultiProcessingDataLoaderIter):
# fetch next worker
next_worker = (next(iterator._worker_queue_idx_cycle)) % num_workers
# get the current worker from next one
previous_worker = (next_worker - 1) % num_workers
# reset back the `worker_queue_idx` to current one, so we can keep
# going without perturbation.
while next(iterator._worker_queue_idx_cycle) != previous_worker:
pass
else:
previous_worker = None

# return the captured metadata.
return {"num_workers": num_workers, "previous_worker": previous_worker}


def _capture_metadata_collate(
samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode
) -> Any:
@@ -476,6 +430,48 @@ def _capture_metadata_collate(
return {"data": data, AutoRestartBatchKeys.PL_RESTART_META: metadata}


# TODO: Merge this code within stateful DataLoaderIter.
def _next_data_wrapper(fn, it, dl, num_batches_fetched, data_fetcher) -> Callable:
@wraps(fn)
def wrapper():
nonlocal num_batches_fetched
nonlocal it
nonlocal dl

dataset = dl.dataset
combined_batch = fn()

batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]
num_batches_fetched += 1

if isinstance(dataset, CaptureIterableDataset):
state = [
IteratorState(
num_workers=dl.num_workers,
sampler_state=iterator_state,
num_batches_fetched=num_batches_fetched,
worker_id=list(iterator_state.keys())[0],
name=sampler_iter_name,
)
for sampler_iter_name, iterator_state in state.items()
]
elif isinstance(dataset, CaptureMapDataset):
ff_sampler = _find_fast_forward_samplers(dl)
state = [
IteratorState(
num_workers=dl.num_workers,
sampler_state=ff_sampler.state_dict(num_batches_fetched),
dataset_state=state,
worker_id=list(state.keys())[0],
num_batches_fetched=num_batches_fetched,
)
]
data_fetcher._store_dataloader_iter_state(it, state)
return batch

return wrapper


def patch_dataloader_iterator(
dataloader: DataLoader,
iterator: Iterator,
@@ -506,48 +502,9 @@ def patch_dataloader_iterator(
return

assert isinstance(dataloader.dataset, (CaptureMapDataset, CaptureIterableDataset))

def _next_data_wrapper(fn, it, dl, num_batches_fetched) -> Callable:
@wraps(fn)
def wrapper():
nonlocal num_batches_fetched
nonlocal it
nonlocal dl

dataset = dl.dataset
combined_batch = fn()

batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META]
num_batches_fetched += 1

if isinstance(dataset, CaptureIterableDataset):
state = [
IteratorState(
num_workers=dataloader.num_workers,
sampler_state=iterator_state,
num_batches_fetched=num_batches_fetched,
worker_id=list(iterator_state.keys())[0],
name=sampler_iter_name,
)
for sampler_iter_name, iterator_state in state.items()
]
elif isinstance(dataset, CaptureMapDataset):
ff_sampler = _find_fast_forward_samplers(dl)
state = [
IteratorState(
num_workers=dataloader.num_workers,
sampler_state=ff_sampler.state_dict(num_batches_fetched),
dataset_state=state,
worker_id=list(state.keys())[0],
num_batches_fetched=num_batches_fetched,
)
]
data_fetcher._store_dataloader_iter_state(it, state)
return batch

return wrapper

iterator._next_data = _next_data_wrapper(iterator._next_data, iterator, dataloader, num_batches_fetched)
iterator._next_data = _next_data_wrapper(
iterator._next_data, iterator, dataloader, num_batches_fetched, data_fetcher
)


def _add_capture_metadata_collate(dataloader: DataLoader) -> None:
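The hunk above lifts `_next_data_wrapper` out of `patch_dataloader_iterator` to module level, so the `data_fetcher` is now passed in explicitly instead of being captured from the enclosing scope. A minimal sketch of the same wrapping pattern, using hypothetical stand-ins (`_RecordingFetcher`, `_wrap_next`) rather than the real Lightning classes:

```python
from functools import wraps
from typing import Any, Callable, Dict, List


class _RecordingFetcher:
    """Toy stand-in for the data fetcher: it just stores whatever state it is handed."""

    def __init__(self) -> None:
        self.states: List[Dict[str, Any]] = []

    def store_state(self, state: Dict[str, Any]) -> None:
        self.states.append(state)


def _wrap_next(fetch: Callable[[], Dict[str, Any]], fetcher: _RecordingFetcher) -> Callable[[], Any]:
    """Wrap a fetch function so per-batch metadata is recorded before the batch is returned."""
    num_batches_fetched = 0

    @wraps(fetch)
    def wrapper() -> Any:
        nonlocal num_batches_fetched
        # assumed to look like the capture collate output: {"data": ..., "meta": ...}
        combined = fetch()
        num_batches_fetched += 1
        fetcher.store_state({"num_batches_fetched": num_batches_fetched, **combined["meta"]})
        return combined["data"]

    return wrapper


if __name__ == "__main__":
    fetcher = _RecordingFetcher()
    raw = iter([{"data": i, "meta": {"index": i}} for i in range(3)])
    next_data = _wrap_next(lambda: next(raw), fetcher)
    print([next_data() for _ in range(3)])  # -> [0, 1, 2]
    print(fetcher.states)                   # -> three entries with increasing num_batches_fetched
```

The real wrapper additionally builds per-worker `IteratorState` objects and hands them to `data_fetcher._store_dataloader_iter_state`, as shown in the diff.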
32 changes: 18 additions & 14 deletions pytorch_lightning/utilities/data.py
@@ -24,8 +24,8 @@
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.seed import pl_worker_init_function
from pytorch_lightning.utilities.warnings import WarningCache

@@ -246,17 +246,8 @@ def _get_dataloader_init_kwargs(
dl_kwargs["batch_sampler"] = None
dl_kwargs["sampler"] = None

if _fault_tolerant_training():
dataset = dl_kwargs["dataset"]
if isinstance(dataset, IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif get_len(dataset) != float("inf"):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException(
"This shouldn't happen, please open an issue on Lightning Github repository."
)
if _FaultTolerantMode.detect_current_mode().is_automatic:
dl_kwargs = _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs)

return dl_kwargs

@@ -271,6 +262,7 @@ def _dataloader_init_kwargs_resolve_sampler(
Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
`FastForwardSampler`.
"""
fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
batch_sampler = getattr(dataloader, "batch_sampler")
is_predicting = mode == RunningStage.PREDICTING
# checking the batch sampler type is different than PyTorch default.
@@ -283,7 +275,7 @@
if is_predicting:
batch_sampler = IndexBatchSamplerWrapper(batch_sampler)

if _fault_tolerant_training():
if fault_tolerant_mode.is_automatic:
fast_forward_sampler = batch_sampler = FastForwardSampler(batch_sampler)
fast_forward_sampler.setup(dataloader_batch_size=1)

@@ -295,7 +287,7 @@
"drop_last": False,
}

if _fault_tolerant_training():
if fault_tolerant_mode.is_automatic:
fast_forward_sampler = sampler = FastForwardSampler(sampler)
fast_forward_sampler.setup(dataloader_batch_size=dataloader.batch_size)

@@ -305,3 +297,15 @@
def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
dataset = dl_kwargs["dataset"]
if isinstance(dataset, IterableDataset):
# wrap the `IterableDataset` into a `CaptureIterableDataset` to record sampler states.
dl_kwargs["dataset"] = CaptureIterableDataset(dataset=dl_kwargs["dataset"])
elif get_len(dataset) != float("inf"):
dl_kwargs["dataset"] = CaptureMapDataset(dataset=dl_kwargs["dataset"])
else:
raise MisconfigurationException("This shouldn't happen, please open an issue on Lightning Github repository.")
return dl_kwargs
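For readers following the refactor: `_apply_fault_tolerant_automatic_capture_dataset_wrapper` is a straight extraction of the branch that previously lived inline in `_get_dataloader_init_kwargs`. A simplified sketch of that dispatch, assuming placeholder wrapper classes instead of the real `CaptureIterableDataset` / `CaptureMapDataset`, and a plain `__len__` check instead of `get_len`:

```python
from typing import Any, Dict, Iterator

from torch.utils.data import Dataset, IterableDataset


class _IterableCapture(IterableDataset):
    """Placeholder wrapper that would record sampler state while iterating."""

    def __init__(self, dataset: IterableDataset) -> None:
        self.dataset = dataset

    def __iter__(self) -> Iterator[Any]:
        yield from self.dataset


class _MapCapture(Dataset):
    """Placeholder wrapper that would record the indices fetched from a map-style dataset."""

    def __init__(self, dataset: Dataset) -> None:
        self.dataset = dataset

    def __getitem__(self, index: int) -> Any:
        return self.dataset[index]

    def __len__(self) -> int:
        return len(self.dataset)


def _wrap_dataset_for_capture(dl_kwargs: Dict[str, Any]) -> Dict[str, Any]:
    """Dispatch on the dataset type, mirroring the helper added in this diff."""
    dataset = dl_kwargs["dataset"]
    if isinstance(dataset, IterableDataset):
        # iterable-style datasets carry their sampler state inside the dataset wrapper
        dl_kwargs["dataset"] = _IterableCapture(dataset)
    elif hasattr(dataset, "__len__"):
        # sized map-style datasets are wrapped so the fetched indices can be captured
        dl_kwargs["dataset"] = _MapCapture(dataset)
    else:
        raise ValueError("Unsized, non-iterable datasets cannot be captured for fault tolerance.")
    return dl_kwargs
```

With these placeholders, `_wrap_dataset_for_capture({"dataset": range(8)})` returns the kwargs with the dataset swapped for a `_MapCapture`, which is the behaviour the helper above provides for sized datasets.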
40 changes: 0 additions & 40 deletions tests/utilities/test_auto_restart.py
@@ -39,8 +39,6 @@
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.auto_restart import (
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
_MultiProcessingDataLoaderIterStateful,
_patch_dataloader_get_iterators,
_reload_dataloader_state_dict,
@@ -665,44 +663,6 @@ def create_iterable_dataset(batch_size, num_workers, attr_name="iter_sampler", w
return dataset


def test_dataloader_to_state_dict_and_reload():
"""
Note: Those utilities are used only with DataLoader wrapping a ``mapping`` based dataset.
"""

def create_dataloader():
dataset = range(50)
batch_size = 8
sampler = FastForwardSampler(SequentialSampler(dataset))
sampler.setup(batch_size)

return DataLoader(dataset, sampler=sampler, batch_size=batch_size)

dataloader = create_dataloader()
iter_dataloader = iter(dataloader)
_ = next(iter_dataloader)
_ = next(iter_dataloader)

state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
assert state_dict == {
"num_workers": 0,
"previous_worker": None,
0: {"current_iteration": 16},
}

dataloader = create_dataloader()
dataloader = _dataloader_load_state_dict(dataloader, state_dict)
iter_dataloader = iter(dataloader)
_ = next(iter_dataloader)

state_dict = _dataloader_to_state_dict(dataloader, iter_dataloader)
assert state_dict == {
"num_workers": 0,
"previous_worker": None,
0: {"current_iteration": 24},
}


@pytest.mark.parametrize("use_fault_tolerant", ["0", "1"])
def test_data_loading_wraps_dataset_and_samplers(use_fault_tolerant, tmpdir):
"""This test ensures the dataset and sampler are properly wrapped when fault tolerant is enabled."""