Fault Tolerant Manual: Add support for DDP #10638

Merged: 135 commits, Nov 25, 2021
Changes from 126 commits (135 commits total)

Commits
102e5ef
wip
tchaton Nov 11, 2021
f432ad4
update
tchaton Nov 11, 2021
d996b16
update
tchaton Nov 11, 2021
c67bc13
update
tchaton Nov 12, 2021
647f7b5
update
tchaton Nov 16, 2021
5ce3a7a
update
tchaton Nov 16, 2021
2b7dc18
update
tchaton Nov 16, 2021
fb26a19
update
tchaton Nov 16, 2021
d69f9c1
update
tchaton Nov 16, 2021
75ba813
update
tchaton Nov 16, 2021
7dba1a7
update
tchaton Nov 16, 2021
07f8a1f
update
tchaton Nov 16, 2021
a8e07d0
update
tchaton Nov 16, 2021
f5a5a80
update
tchaton Nov 16, 2021
06ffd2a
update changelog
tchaton Nov 16, 2021
3fe350d
update
tchaton Nov 16, 2021
ae7548b
update
tchaton Nov 16, 2021
eefe269
update
tchaton Nov 16, 2021
dff4809
update
tchaton Nov 16, 2021
abd772b
update
tchaton Nov 16, 2021
ede5aa1
update
tchaton Nov 17, 2021
5253f04
update
tchaton Nov 17, 2021
0800bd0
update on comments
tchaton Nov 17, 2021
fdbfab9
update
tchaton Nov 18, 2021
d7c9913
Merge branch 'master' into remove_example
tchaton Nov 18, 2021
2a93f23
update
tchaton Nov 18, 2021
06b3af8
Merge branch 'remove_example' of https://github.com/PyTorchLightning/…
tchaton Nov 18, 2021
0547dce
update
tchaton Nov 18, 2021
3cd45c9
update
tchaton Nov 18, 2021
ae3dc40
cleanup
tchaton Nov 18, 2021
e916a92
update
tchaton Nov 18, 2021
440c90c
update
tchaton Nov 18, 2021
6d21930
update
tchaton Nov 19, 2021
4203672
update
tchaton Nov 19, 2021
61ff6c2
update
tchaton Nov 19, 2021
0b97a1c
Merge branch 'lhotse_example' of https://github.com/PyTorchLightning/…
tchaton Nov 19, 2021
1db1347
update
tchaton Nov 19, 2021
0109741
update changelog
tchaton Nov 19, 2021
f714e69
update typing
tchaton Nov 19, 2021
e10cadb
update
tchaton Nov 19, 2021
713e62f
update
tchaton Nov 19, 2021
f9078e6
update
tchaton Nov 19, 2021
da36679
update
tchaton Nov 19, 2021
80ea1db
update
tchaton Nov 19, 2021
b0e69b5
update
tchaton Nov 19, 2021
ca8dbbb
update
tchaton Nov 19, 2021
3f0d28a
update
tchaton Nov 19, 2021
0c47670
update
tchaton Nov 19, 2021
3e5b52e
update
tchaton Nov 19, 2021
24c8245
update
tchaton Nov 19, 2021
bcd5569
update
tchaton Nov 19, 2021
8d3844d
update
tchaton Nov 19, 2021
1829b46
update
tchaton Nov 19, 2021
a1a364a
typo
tchaton Nov 19, 2021
de41675
update on comments
tchaton Nov 22, 2021
8178a32
Update pytorch_lightning/utilities/auto_restart.py
kaushikb11 Nov 22, 2021
00b9355
update
tchaton Nov 22, 2021
96f0517
update
tchaton Nov 22, 2021
297fd67
Merge branch 'fault_tolerant_enum' of https://github.com/PyTorchLight…
tchaton Nov 22, 2021
9800cba
update
tchaton Nov 22, 2021
427ed03
docstring improvement
tchaton Nov 22, 2021
ae712b0
update
tchaton Nov 22, 2021
9a5166d
Rename and simplify
carmocca Nov 22, 2021
b5fa819
Add comment
carmocca Nov 22, 2021
c82b2f2
update
tchaton Nov 22, 2021
2ede205
update
tchaton Nov 22, 2021
b16c4c0
update
tchaton Nov 22, 2021
ce9c23c
update
tchaton Nov 22, 2021
2baddb9
update
tchaton Nov 22, 2021
97548bb
update
tchaton Nov 22, 2021
d953ae9
update
tchaton Nov 22, 2021
41ffbab
use_teardown
tchaton Nov 22, 2021
d04596d
Use `Protocol`
carmocca Nov 22, 2021
ff7b836
Simplify test
carmocca Nov 22, 2021
a5698e6
Update CHANGELOG.md
carmocca Nov 22, 2021
79fdacc
update
tchaton Nov 22, 2021
916b520
update
tchaton Nov 22, 2021
4b67fbf
update
tchaton Nov 22, 2021
c9481e2
update
tchaton Nov 22, 2021
ef29342
update
tchaton Nov 22, 2021
4a1fff7
update
tchaton Nov 22, 2021
cb27e30
update
tchaton Nov 22, 2021
8430a9a
update
tchaton Nov 22, 2021
e583d73
update
tchaton Nov 22, 2021
3e7a47f
cleanup
tchaton Nov 22, 2021
02c33eb
update
tchaton Nov 22, 2021
7903d24
resolve tests
tchaton Nov 22, 2021
20d19a1
update
tchaton Nov 22, 2021
d217293
update
tchaton Nov 22, 2021
410c6a0
update
tchaton Nov 22, 2021
1104cbc
update
tchaton Nov 23, 2021
8c00aa5
Merge branch 'add_stateful_workers' into remove_example
tchaton Nov 23, 2021
f071f9a
change to 0
tchaton Nov 23, 2021
49fd52d
Merge branch 'add_stateful_workers' into remove_example
tchaton Nov 23, 2021
c851f33
Merge branch 'remove_example' into fault_tolerant_manual_ddp
tchaton Nov 23, 2021
b777dc3
update
tchaton Nov 23, 2021
2da1674
update
tchaton Nov 23, 2021
647bebd
merge with master
tchaton Nov 23, 2021
dbcfa65
update changelog
tchaton Nov 23, 2021
ae18166
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 23, 2021
a527929
update
tchaton Nov 23, 2021
5827e8b
Merge branch 'add_reloading' of https://github.com/PyTorchLightning/p…
tchaton Nov 23, 2021
97421c3
update
tchaton Nov 23, 2021
51cf75b
update
tchaton Nov 23, 2021
35644b8
update on comments
tchaton Nov 23, 2021
04a5c3d
Merge branch 'add_reloading' into fault_tolerant_cleanup
tchaton Nov 23, 2021
26d46b0
update changelog
tchaton Nov 23, 2021
0529f29
update
tchaton Nov 23, 2021
27fa4ce
update changelog
tchaton Nov 23, 2021
50f0b2d
update
tchaton Nov 23, 2021
5337e7d
remove deadcode
tchaton Nov 23, 2021
30cc2ec
update
tchaton Nov 23, 2021
ec0bdad
update
tchaton Nov 23, 2021
c7ee8e3
Merge branch 'fault_tolerant_cleanup' into enable_fault_tolerant_manual
tchaton Nov 23, 2021
030932f
update
tchaton Nov 23, 2021
6856d75
update
tchaton Nov 23, 2021
d8493da
Merge branch 'master' into fault_tolerant_manual_ddp
tchaton Nov 23, 2021
7b7fad7
update
tchaton Nov 23, 2021
7d85730
update
tchaton Nov 23, 2021
d736679
update
tchaton Nov 24, 2021
e36cec1
update
tchaton Nov 24, 2021
62bf7f1
Merge branch 'enable_fault_tolerant_manual' into fault_tolerant_manua…
tchaton Nov 24, 2021
591232b
update on comments
tchaton Nov 24, 2021
363fa38
update
tchaton Nov 24, 2021
b5ad0a6
update
tchaton Nov 24, 2021
fb72988
update
tchaton Nov 24, 2021
19a533e
update
tchaton Nov 25, 2021
dbf6127
update
tchaton Nov 25, 2021
4076ffc
update on comments
tchaton Nov 25, 2021
8dcc3b5
update
tchaton Nov 25, 2021
aef58f8
update
tchaton Nov 25, 2021
f93b64e
update
tchaton Nov 25, 2021
a661f90
update
tchaton Nov 25, 2021
526334a
update
tchaton Nov 25, 2021
2c5596b
update
tchaton Nov 25, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added


- Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601))


@@ -21,6 +22,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
* Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))
* Enable Fault Tolerant Manual Training ([#10707](https://github.com/PyTorchLightning/pytorch-lightning/issues/10707))
* Broadcast the `_terminate_gracefully` to all processes and add support for DDP ([#10638](https://github.com/PyTorchLightning/pytorch-lightning/issues/10638))


- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))

20 changes: 18 additions & 2 deletions pytorch_lightning/loops/epoch/evaluation_epoch_loop.py
@@ -17,13 +17,20 @@
from functools import lru_cache
from typing import Any, Dict, Iterator, Optional, Union

import torch
from deprecate import void

from pytorch_lightning.loops.base import Loop
from pytorch_lightning.loops.utilities import _update_dataloader_iter
from pytorch_lightning.trainer.progress import BatchProgress
from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState
from pytorch_lightning.utilities.auto_restart import (
    _collect_states_on_rank_zero_over_collection,
    _reload_dataloader_state_dict,
    MergedIteratorState,
)
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT

@@ -174,11 +181,20 @@ def on_save_checkpoint(self) -> Dict:
    state: Optional[MergedIteratorState] = getattr(self._data_fetcher.dataloader_iter, state_to_save, None)
    if state:
        state_dict["dataloader_state_dict"] = asdict(state)
        state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(
            state_dict["dataloader_state_dict"], device=self.trainer.training_type_plugin.root_device
        )
    return state_dict

def on_load_checkpoint(self, state_dict: Dict) -> None:
    # cache the dataloader state dict until the dataloader objects are available
    self._dataloader_state_dict = state_dict.get("dataloader_state_dict")
    # dataset states are collected across all ranks
    if _fault_tolerant_training():
        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
        if not dataloader_state_dict:
            return
        rank = torch.distributed.get_rank() if distributed_available() else 0
        self._dataloader_state_dict = dataloader_state_dict[rank]

def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher):
    if not self.trainer.sanity_checking and self._dataloader_state_dict:
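The save/load pair above implies a specific checkpoint layout: on save, the dataloader state is collected on rank 0 as a mapping from rank index to that rank's state, and on load each process indexes the mapping with its own rank. A toy, framework-free sketch of that round trip (the `current_iteration` field is made up for illustration and is not the real state schema):

    # On save, the collected entry maps rank -> that rank's dataloader state.
    saved = {"dataloader_state_dict": {0: {"current_iteration": 10}, 1: {"current_iteration": 12}}}

    def restore_for_rank(state_dict: dict, rank: int):
        # mirrors on_load_checkpoint: bail out if nothing was saved, otherwise pick this rank's entry
        dataloader_state_dict = state_dict.get("dataloader_state_dict", None)
        if not dataloader_state_dict:
            return None
        return dataloader_state_dict[rank]

    assert restore_for_rank(saved, 1) == {"current_iteration": 12}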
6 changes: 6 additions & 0 deletions pytorch_lightning/loops/epoch/training_epoch_loop.py
@@ -25,6 +25,7 @@
from pytorch_lightning.trainer.progress import BatchProgress, SchedulerProgress
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.auto_restart import _collect_states_on_rank_zero_over_collection
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import AbstractDataFetcher
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -320,9 +321,14 @@ def on_save_checkpoint(self) -> Dict:
        or self.batch_progress.current.ready == 0  # did not start
    ):
        return state_dict

    state_dict["dataloader_state_dict"] = self.trainer.train_dataloader.state_dict(
        has_completed=self._has_completed()
    )

    state_dict["dataloader_state_dict"] = _collect_states_on_rank_zero_over_collection(
        state_dict["dataloader_state_dict"], device=self.trainer.training_type_plugin.root_device
    )
    return state_dict

def on_load_checkpoint(self, state_dict: Dict) -> None:
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/connectors/data_connector.py
@@ -19,6 +19,7 @@

import pytorch_lightning as pl
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.auto_restart import _teardown_dataloader_get_iterators
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.fetching import (
    AbstractDataFetcher,
@@ -254,6 +255,7 @@ def teardown(self) -> None:
    if self.sanity_check_data_fetcher:
        self.sanity_check_data_fetcher.teardown()
        self.sanity_check_data_fetcher = None
    _teardown_dataloader_get_iterators()


@dataclass
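For context, `_teardown_dataloader_get_iterators()` undoes the monkeypatching of `DataLoader._get_iterator` that fault-tolerant training installs, so the data connector's teardown leaves `DataLoader` in its original state. A self-contained sketch of that stash-and-restore pattern, under the assumption that only the mechanics matter here (the wrapper below is purely illustrative; the real patch returns a stateful iterator):

    from torch.utils.data import DataLoader

    def _patch_get_iterator_sketch() -> None:
        # stash the original factory once so it can be restored later
        if not hasattr(DataLoader, "_ori_get_iterator"):
            DataLoader._ori_get_iterator = DataLoader._get_iterator

        def _wrapped_get_iterator(self):
            # the real wrapper would return a state-tracking iterator here
            return DataLoader._ori_get_iterator(self)

        DataLoader._get_iterator = _wrapped_get_iterator

    def _teardown_get_iterator_sketch() -> None:
        # restore the original factory and drop the stashed reference
        original = getattr(DataLoader, "_ori_get_iterator", None)
        if original is not None:
            DataLoader._get_iterator = original
            del DataLoader._ori_get_iterator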
5 changes: 5 additions & 0 deletions pytorch_lightning/trainer/supporters.py
@@ -29,6 +29,7 @@
    patch_dataloader_iterator,
)
from pytorch_lightning.utilities.data import get_len
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training

@@ -403,6 +404,10 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator:
    if isinstance(dataloader, CycleIterator):
        dataloader = dataloader_to_iter_on.loader

    # dataset states are collected across all ranks
    rank = torch.distributed.get_rank() if distributed_available() else 0
    state_dict = state_dict[rank]

    _reload_dataloader_state_dict(dataloader, state_dict)

    # We finally spawned the workers if any.
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/trainer.py
@@ -2094,7 +2094,11 @@ def _results(self) -> Optional[ResultCollection]:
    return active_loop._results

def _exit_gracefully_on_signal(self) -> None:
    if _fault_tolerant_training() and self._terminate_gracefully:
    if _fault_tolerant_training():
        # the signal should be sent to rank 0
        should_terminate_gracefully = self.training_type_plugin.broadcast(self._terminate_gracefully)
        if not should_terminate_gracefully:
            return
        caller = inspect.stack()[1]
        class_name = caller[0].f_locals["self"].__class__.__name__
        raise ExitGracefullyException(f"Exiting gracefully on {class_name}:{caller.function}")
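The change above makes rank 0's `_terminate_gracefully` flag authoritative: every process receives rank 0's value through the training type plugin's broadcast, so all DDP workers agree on whether to raise `ExitGracefullyException`. A minimal sketch of the same pattern written directly against `torch.distributed` (the helper name is made up for illustration):

    import torch.distributed as dist

    def broadcast_flag_from_rank_zero(flag: bool) -> bool:
        # single-process fallback: nothing to synchronize
        if not (dist.is_available() and dist.is_initialized()):
            return flag
        # every rank receives rank 0's value of the flag
        obj = [flag]
        dist.broadcast_object_list(obj, src=0)
        return bool(obj[0])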
14 changes: 14 additions & 0 deletions pytorch_lightning/utilities/auto_restart.py
@@ -31,6 +31,8 @@
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities.apply_func import apply_to_collection
from pytorch_lightning.utilities.distributed import _collect_states_on_rank_zero
from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -737,3 +739,15 @@ def _teardown_dataloader_get_iterators() -> None:
    if get_iterator:
        DataLoader._get_iterator = get_iterator
        del DataLoader._ori_get_iterator


def _collect_states_on_rank_zero_over_collection(state_dict: Any, device: torch.device) -> Any:
    """This utility collects the state across processes for a collection of state."""

    def fn(state: Dict):
        nonlocal device
if state.get("state", None) is not None:
            return _collect_states_on_rank_zero(state, device=device)
        return {k: apply_to_collection(v, Dict, fn) for k, v in state.items()}

    return apply_to_collection(state_dict, Dict, fn)
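`_collect_states_on_rank_zero_over_collection` walks an arbitrarily nested collection and, for every dict carrying a `"state"` entry, replaces it with a rank-indexed mapping produced by `_collect_states_on_rank_zero`. Conceptually, that per-dict collection step could be sketched as follows; this is a hedged equivalent built on `torch.distributed`, not the actual Lightning implementation:

    import torch.distributed as dist

    def collect_states_by_rank(state: dict) -> dict:
        # single-process fallback, matching the {0: state} shape asserted in the test further down
        if not (dist.is_available() and dist.is_initialized()):
            return {0: state}
        world_size = dist.get_world_size()
        gathered = [None] * world_size
        # every rank contributes its state; afterwards each rank holds the full list
        dist.all_gather_object(gathered, state)
        return {rank: s for rank, s in enumerate(gathered)}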
5 changes: 3 additions & 2 deletions pytorch_lightning/utilities/imports.py
@@ -14,7 +14,6 @@
"""General utilities."""
import importlib
import operator
import os
import platform
import sys
from importlib.util import find_spec
@@ -111,4 +110,6 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:

# experimental feature within PyTorch Lightning.
def _fault_tolerant_training() -> bool:
    return bool(int(os.getenv("PL_FAULT_TOLERANT_TRAINING", 0)))
    from pytorch_lightning.utilities.enums import _FaultTolerantMode

    return _FaultTolerantMode.detect_current_mode().is_enabled
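`_fault_tolerant_training()` now delegates to `_FaultTolerantMode.detect_current_mode()` instead of parsing the environment variable inline, which lets the same `PL_FAULT_TOLERANT_TRAINING` switch distinguish automatic from manual fault tolerance. A hedged sketch of what such an enum-based detection might look like (member names and the exact value mapping are assumptions, not copied from the PR):

    import os
    from enum import Enum

    class FaultTolerantModeSketch(Enum):
        DISABLED = "disabled"
        AUTOMATIC = "automatic"
        MANUAL = "manual"

        @property
        def is_enabled(self) -> bool:
            return self is not FaultTolerantModeSketch.DISABLED

        @classmethod
        def detect_current_mode(cls) -> "FaultTolerantModeSketch":
            env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower()
            if env_value in ("1", "automatic"):
                return cls.AUTOMATIC
            if env_value in ("2", "manual"):
                return cls.MANUAL
            return cls.DISABLED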
8 changes: 8 additions & 0 deletions tests/utilities/test_auto_restart.py
@@ -39,6 +39,7 @@
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.auto_restart import (
    _add_capture_metadata_collate,
    _collect_states_on_rank_zero_over_collection,
    _MultiProcessingDataLoaderIterStateful,
    _patch_dataloader_get_iterators,
    _reload_dataloader_state_dict,
@@ -1254,6 +1255,13 @@ def load_state_dict(self, state_dict):
    self.counter = state_dict[0]["counter"]


def test_collect_states_with_collection():
    state = {"state": 0}
    collection = [{"a": state, "b": [{"a": state}]}]
    generated = _collect_states_on_rank_zero_over_collection(collection, torch.device("cpu"))
    assert generated == [{"a": {0: state}, "b": [{"a": {0: state}}]}]


@pytest.mark.parametrize("num_workers", [0])
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"})
def test_stateful_workers(num_workers):