From ca8dbbbeea1e661d145dc1799684e467a7931d10 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:20:54 +0000 Subject: [PATCH 01/42] update --- pytorch_lightning/trainer/trainer.py | 5 +++++ pytorch_lightning/utilities/auto_restart.py | 19 ++++++++++++++-- pytorch_lightning/utilities/enums.py | 19 ++++++++++++++++ tests/utilities/test_auto_restart.py | 25 ++++++++++++++++++++- 4 files changed, 65 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 2f6e987635d47..27ef4002da500 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -80,8 +80,10 @@ parse_argparser, parse_env_variables, ) +from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_env_enum from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available +from pytorch_lightning.utilities.enums import FaultTolerantTrainingMode from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module @@ -470,6 +472,9 @@ def __init__( # Needed because of LightningOptimizer self._lightning_optimizers = None + # detect the fault tolerant flag + self._fault_tolerant_mode: FaultTolerantTrainingMode = _detect_fault_tolerant_env_enum() + # .validate() and .test() set this when they load a checkpoint self.validated_ckpt_path: Optional[str] = None self.tested_ckpt_path: Optional[str] = None diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ef52717636d90..623e78d9d3d6d 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps @@ -25,7 +25,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset import pytorch_lightning as pl -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys +from pytorch_lightning.utilities.enums import AutoRestartBatchKeys, FaultTolerantTrainingMode from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -571,3 +571,18 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + + +def _detect_fault_tolerant_env_enum() -> FaultTolerantTrainingMode: + value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") + if value == "0": + return FaultTolerantTrainingMode.DISABLED + elif value == "1": + return FaultTolerantTrainingMode.AUTOMATIC + elif value == "2": + return FaultTolerantTrainingMode.MANUAL + else: + raise MisconfigurationException( + "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " + "'0' (disabled), '1' (automatic) or '2' (manual)." 
+ ) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index cbb4f68bedfac..2eb38fdb21e66 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -256,3 +256,22 @@ def interactive_compatible_types() -> List["_StrategyType"]: def is_interactive_compatible(self) -> bool: """Returns whether self is interactive compatible.""" return self in _StrategyType.interactive_compatible_types() + + +class FaultTolerantTrainingMode(LightningEnum): + + DISABLED = "disabled" + AUTOMATIC = "automatic" + MANUAL = "manual" + + @property + def is_enabled(self) -> bool: + return self is not FaultTolerantTrainingMode.DISABLED + + @property + def is_automatic(self) -> bool: + return self is not FaultTolerantTrainingMode.AUTOMATIC + + @property + def is_manual(self) -> bool: + return self is not FaultTolerantTrainingMode.MANUAL diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b36a9d1d76941..bbde1053cad76 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,12 +39,13 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _detect_fault_tolerant_env_enum, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, MergedIteratorState, ) -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys +from pytorch_lightning.utilities.enums import AutoRestartBatchKeys, FaultTolerantTrainingMode from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -1192,3 +1193,25 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +def test_fault_tolerant_manual_mode_enum(): + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): + assert FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_env_enum() + trainer = Trainer() + assert not trainer._fault_tolerant_mode.is_enabled + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): + assert FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_env_enum() + assert trainer._fault_tolerant_mode.is_automatic + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}): + assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_enum() + assert trainer._fault_tolerant_mode.is_manual + + with pytest.raises( + MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" + ): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): + assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_enum() From 3f0d28a55531522a81b0ff00f246549c049ea556 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:21:09 +0000 Subject: [PATCH 02/42] update --- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/utilities/auto_restart.py | 2 +- tests/utilities/test_auto_restart.py | 10 +++++----- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 27ef4002da500..f09170d2c1365 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -80,7 +80,7 @@ parse_argparser, parse_env_variables, ) -from 
pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_env_enum +from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_env_to_enum from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.enums import FaultTolerantTrainingMode @@ -473,7 +473,7 @@ def __init__( self._lightning_optimizers = None # detect the fault tolerant flag - self._fault_tolerant_mode: FaultTolerantTrainingMode = _detect_fault_tolerant_env_enum() + self._fault_tolerant_mode: FaultTolerantTrainingMode = _detect_fault_tolerant_env_to_enum() # .validate() and .test() set this when they load a checkpoint self.validated_ckpt_path: Optional[str] = None diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 623e78d9d3d6d..68af27db20944 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -573,7 +573,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _detect_fault_tolerant_env_enum() -> FaultTolerantTrainingMode: +def _detect_fault_tolerant_env_to_enum() -> FaultTolerantTrainingMode: value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") if value == "0": return FaultTolerantTrainingMode.DISABLED diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index bbde1053cad76..d3ffe116b41e8 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,7 +39,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, - _detect_fault_tolerant_env_enum, + _detect_fault_tolerant_env_to_enum, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1198,20 +1198,20 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on def test_fault_tolerant_manual_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_env_enum() + assert FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_env_to_enum() trainer = Trainer() assert not trainer._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_env_enum() + assert FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_env_to_enum() assert trainer._fault_tolerant_mode.is_automatic with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}): - assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_enum() + assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() assert trainer._fault_tolerant_mode.is_manual with pytest.raises( MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_enum() + assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() From 0c476704115b30c623bf43efd3b9678e4e8ad556 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:23:02 +0000 Subject: [PATCH 03/42] update --- pytorch_lightning/utilities/auto_restart.py | 9 +++++---- 1 file changed, 5 
insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 68af27db20944..4b8d4f2e1babb 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -574,12 +574,13 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A def _detect_fault_tolerant_env_to_enum() -> FaultTolerantTrainingMode: - value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") - if value == "0": + """This utility detect is Fault Tolerant is activated and maps it to `FaultTolerantTrainingMode`.""" + env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") + if env_value == "0": return FaultTolerantTrainingMode.DISABLED - elif value == "1": + elif env_value == "1": return FaultTolerantTrainingMode.AUTOMATIC - elif value == "2": + elif env_value == "2": return FaultTolerantTrainingMode.MANUAL else: raise MisconfigurationException( From 3e5b52ebfd283838123b78892f60a7b4bf149802 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:24:39 +0000 Subject: [PATCH 04/42] update --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96830878d84ae..dd0e122ce3cdc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) -- +- Fault Tolerant Manual + * Add `FaultTolerantTrainingMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) - From 24c82451a64850d746accadb88d789a06d4148c9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:40:30 +0000 Subject: [PATCH 05/42] update --- pytorch_lightning/utilities/auto_restart.py | 13 +++++++++ tests/utilities/test_auto_restart.py | 30 +++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ef52717636d90..11593c779c260 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps @@ -571,3 +572,15 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") + + +def is_obj_stateful(obj: Any) -> bool: + """In order to be stateful, an object should implement a ``state_dict`` and ``load_state_dict`` method.""" + load_state_dict_fn = getattr(obj, "load_state_dict", None) + if not isinstance(load_state_dict_fn, Callable): + return False + params = inspect.signature(load_state_dict_fn).parameters + if len(params) == 0: + return False + + return isinstance(getattr(obj, "state_dict", None), Callable) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b36a9d1d76941..6717d8f09d4b5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -42,6 +42,7 @@ CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, + is_obj_stateful, MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -1192,3 +1193,32 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +def test_is_obj_stateful(): + class StatefulClass: + def state_dict(self): + pass + + def load_state_dict(self, state_dict): + pass + + obj = StatefulClass() + assert is_obj_stateful(obj) + + class NotStatefulClass: + def state_dict(self): + pass + + def load_state_dict(self): + pass + + obj = NotStatefulClass() + assert not is_obj_stateful(obj) + + class NotStatefulClass: + def load_state_dict(self): + pass + + obj = NotStatefulClass() + assert not is_obj_stateful(obj) From bcd5569ef5e404e69430f089f613ac16c50eb6a9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:42:28 +0000 Subject: [PATCH 06/42] update --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96830878d84ae..30c16a4d67637 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) -- +- Fault Tolerant Manual + * Add `is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) - From 8d3844dc782a62613956fb6fe177b68872fe9a7f Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:44:03 +0000 Subject: [PATCH 07/42] update --- CHANGELOG.md | 2 +- pytorch_lightning/utilities/auto_restart.py | 2 +- tests/utilities/test_auto_restart.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30c16a4d67637..4f59e996365a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fault Tolerant Manual - * Add `is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) + * Add `_is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) - diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 11593c779c260..0310aceb4bc49 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -574,7 +574,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def is_obj_stateful(obj: Any) -> bool: +def _is_obj_stateful(obj: Any) -> bool: """In order to be stateful, an object should implement a ``state_dict`` and ``load_state_dict`` method.""" load_state_dict_fn = getattr(obj, "load_state_dict", None) if not isinstance(load_state_dict_fn, Callable): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 6717d8f09d4b5..7d75b7605ce0a 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,10 +39,10 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _is_obj_stateful, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, - is_obj_stateful, MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -1195,7 +1195,7 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict -def test_is_obj_stateful(): +def test__is_obj_stateful(): class StatefulClass: def state_dict(self): pass @@ -1204,7 +1204,7 @@ def load_state_dict(self, state_dict): pass obj = StatefulClass() - assert is_obj_stateful(obj) + assert _is_obj_stateful(obj) class NotStatefulClass: def state_dict(self): @@ -1214,11 +1214,11 @@ def load_state_dict(self): pass obj = NotStatefulClass() - assert not is_obj_stateful(obj) + assert not _is_obj_stateful(obj) class NotStatefulClass: def load_state_dict(self): pass obj = NotStatefulClass() - assert not is_obj_stateful(obj) + assert not _is_obj_stateful(obj) From 1829b46ac783c12673a2c12573c16322867bde09 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:46:12 +0000 Subject: [PATCH 08/42] update --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/utilities/auto_restart.py | 12 ++++++------ pytorch_lightning/utilities/enums.py | 8 ++++---- tests/utilities/test_auto_restart.py | 12 +++++++----- 5 files changed, 20 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index dd0e122ce3cdc..5ff6d021c8c4c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fault Tolerant Manual - * Add `FaultTolerantTrainingMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) + * Add `_FaultTolerantTrainingMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) - diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index f09170d2c1365..bc8c2fefabc5b 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -83,7 +83,7 @@ from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_env_to_enum from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.enums import FaultTolerantTrainingMode +from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module @@ -473,7 +473,7 @@ def __init__( self._lightning_optimizers = None # detect the fault tolerant flag - self._fault_tolerant_mode: FaultTolerantTrainingMode = _detect_fault_tolerant_env_to_enum() + self._fault_tolerant_mode: _FaultTolerantTrainingMode = _detect_fault_tolerant_env_to_enum() # .validate() and .test() set this when they load a checkpoint self.validated_ckpt_path: Optional[str] = None diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 4b8d4f2e1babb..5e6e2c7fe52ff 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -25,7 +25,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset import pytorch_lightning as pl -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys, FaultTolerantTrainingMode +from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -573,15 +573,15 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") -def _detect_fault_tolerant_env_to_enum() -> FaultTolerantTrainingMode: - """This utility detect is Fault Tolerant is activated and maps it to `FaultTolerantTrainingMode`.""" +def _detect_fault_tolerant_env_to_enum() -> _FaultTolerantTrainingMode: + """This utility detect is Fault Tolerant is activated and maps it to `_FaultTolerantTrainingMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") if env_value == "0": - return FaultTolerantTrainingMode.DISABLED + return _FaultTolerantTrainingMode.DISABLED elif env_value == "1": - return FaultTolerantTrainingMode.AUTOMATIC + return _FaultTolerantTrainingMode.AUTOMATIC elif env_value == "2": - return FaultTolerantTrainingMode.MANUAL + return _FaultTolerantTrainingMode.MANUAL else: raise MisconfigurationException( "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 2eb38fdb21e66..f7491241dfc3f 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -258,7 +258,7 @@ def is_interactive_compatible(self) -> bool: return self in _StrategyType.interactive_compatible_types() -class FaultTolerantTrainingMode(LightningEnum): +class _FaultTolerantTrainingMode(LightningEnum): DISABLED = "disabled" AUTOMATIC = "automatic" @@ -266,12 +266,12 @@ class FaultTolerantTrainingMode(LightningEnum): @property def is_enabled(self) -> bool: - return self is not FaultTolerantTrainingMode.DISABLED + return self is not _FaultTolerantTrainingMode.DISABLED @property def is_automatic(self) -> bool: - return self is not FaultTolerantTrainingMode.AUTOMATIC + return self is _FaultTolerantTrainingMode.AUTOMATIC @property def is_manual(self) -> bool: - return self is not FaultTolerantTrainingMode.MANUAL + return self is _FaultTolerantTrainingMode.MANUAL diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index d3ffe116b41e8..411921fefcfa0 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -45,7 +45,7 @@ FastForwardSampler, MergedIteratorState, ) -from pytorch_lightning.utilities.enums import AutoRestartBatchKeys, FaultTolerantTrainingMode +from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -1198,20 +1198,22 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on def test_fault_tolerant_manual_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_env_to_enum() + assert _FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_env_to_enum() trainer = Trainer() assert not trainer._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_env_to_enum() + assert _FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_env_to_enum() + trainer = Trainer() assert trainer._fault_tolerant_mode.is_automatic with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}): - assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() + assert 
_FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() + trainer = Trainer() assert trainer._fault_tolerant_mode.is_manual with pytest.raises( MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - assert FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() + assert _FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() From a1a364a268a5cf9eadbfa382f4b33d541847774e Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 19:22:34 +0000 Subject: [PATCH 09/42] typo --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 5e6e2c7fe52ff..0614e4dffb8cb 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -574,7 +574,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A def _detect_fault_tolerant_env_to_enum() -> _FaultTolerantTrainingMode: - """This utility detect is Fault Tolerant is activated and maps it to `_FaultTolerantTrainingMode`.""" + """This utility detect if Fault Tolerant is activated and maps its value to `_FaultTolerantTrainingMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") if env_value == "0": return _FaultTolerantTrainingMode.DISABLED From de41675d7d4a33e1bf7d5a767b17a6c5a672763f Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 13:19:59 +0000 Subject: [PATCH 10/42] update on comments --- pytorch_lightning/trainer/trainer.py | 4 ++-- pytorch_lightning/utilities/auto_restart.py | 11 +++++------ tests/utilities/test_auto_restart.py | 10 +++++----- 3 files changed, 12 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index bc8c2fefabc5b..da5b952ce1172 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -80,7 +80,7 @@ parse_argparser, parse_env_variables, ) -from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_env_to_enum +from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_training_mode from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode @@ -473,7 +473,7 @@ def __init__( self._lightning_optimizers = None # detect the fault tolerant flag - self._fault_tolerant_mode: _FaultTolerantTrainingMode = _detect_fault_tolerant_env_to_enum() + self._fault_tolerant_mode: _FaultTolerantTrainingMode = _detect_fault_tolerant_training_mode() # .validate() and .test() set this when they load a checkpoint self.validated_ckpt_path: Optional[str] = None diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0614e4dffb8cb..2afb591cbc172 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -573,7 +573,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") -def _detect_fault_tolerant_env_to_enum() -> _FaultTolerantTrainingMode: +def _detect_fault_tolerant_training_mode() -> _FaultTolerantTrainingMode: """This utility detect if Fault Tolerant is activated and maps its value to `_FaultTolerantTrainingMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") if env_value == "0": @@ -582,8 +582,7 @@ def _detect_fault_tolerant_env_to_enum() -> _FaultTolerantTrainingMode: return _FaultTolerantTrainingMode.AUTOMATIC elif env_value == "2": return _FaultTolerantTrainingMode.MANUAL - else: - raise MisconfigurationException( - "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " - "'0' (disabled), '1' (automatic) or '2' (manual)." - ) + raise MisconfigurationException( + "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " + "'0' (disabled), '1' (automatic) or '2' (manual)." + ) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 411921fefcfa0..c1dc0afb76799 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,7 +39,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, - _detect_fault_tolerant_env_to_enum, + _detect_fault_tolerant_training_mode, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1198,17 +1198,17 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on def test_fault_tolerant_manual_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert _FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_env_to_enum() + assert _FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_training_mode() trainer = Trainer() assert not trainer._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert _FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_env_to_enum() + assert _FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_training_mode() trainer = Trainer() assert trainer._fault_tolerant_mode.is_automatic with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}): - assert _FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() + assert _FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_training_mode() trainer = Trainer() assert trainer._fault_tolerant_mode.is_manual @@ -1216,4 +1216,4 @@ def test_fault_tolerant_manual_mode_enum(): MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - assert _FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_env_to_enum() + _detect_fault_tolerant_training_mode() From 8178a3238e1eaab5016c96aed06741d62f2e8fb9 Mon Sep 17 00:00:00 2001 From: Kaushik B <45285388+kaushikb11@users.noreply.github.com> Date: Mon, 22 Nov 2021 19:04:32 +0530 Subject: [PATCH 11/42] Update pytorch_lightning/utilities/auto_restart.py --- pytorch_lightning/utilities/auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 2afb591cbc172..48ea16557ec40 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -574,7 +574,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A def 
_detect_fault_tolerant_training_mode() -> _FaultTolerantTrainingMode: - """This utility detect if Fault Tolerant is activated and maps its value to `_FaultTolerantTrainingMode`.""" + """This utility detects if Fault Tolerant is activated and maps its value to `_FaultTolerantTrainingMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") if env_value == "0": return _FaultTolerantTrainingMode.DISABLED From 00b93556608d9a082c015c95521eccf7b56275b6 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 13:55:41 +0000 Subject: [PATCH 12/42] update --- pytorch_lightning/utilities/auto_restart.py | 110 ++++++++++++++++++++ tests/utilities/test_auto_restart.py | 7 ++ 2 files changed, 117 insertions(+) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ef52717636d90..08b480b094b77 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -571,3 +571,113 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + + +class _StatefulMixin: + """This mixin is used to make PyTorch DataLoaderIter stateful.""" + + def _reset(self, loader: DataLoader, first_iter: bool = False): + super()._reset(loader, first_iter=first_iter) + self._loader = loader + self.num_batches_fetched = 0 + + def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: + # initialize the queue if it doesn't exist. + if not hasattr(self, "_sampler_state"): + self._sampler_state = [] + self._sampler_state_idx = 0 + + # store sampler state within a queue alongside its idx. + self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1 + self._sampler_state.append((sampler_state, self._sampler_state_idx)) + + def _store_sampler_state(self) -> None: + """This function is used to extract the sampler states if any.""" + sampler_state = { + k: v.state_dict() for k, v in self._loader.__dict__.items() if is_obj_stateful(v) and k != "dataset" + } + + self.__accumulate_state(sampler_state) + + def _next_index(self) -> Any: + indexes = super()._next_index() + self._store_sampler_state() + return indexes + + def _prepare_loader(self, loader): + if not isinstance(loader.collate_fn, partial): + loader.collate_fn = partial( + _capture_metadata_collate, dataset=loader.dataset, default_collate=loader.collate_fn + ) + self._loader = loader + + def __del__(self) -> None: + if isinstance(self._loader.collate_fn, partial): + self._loader.collate_fn = self._loader.collate_fn.keywords["default_collate"] + + def _next_data(self) -> Any: + combined_batch = super()._next_data() + + batch, state = combined_batch["data"], combined_batch[AutoRestartBatchKeys.PL_RESTART_META] + + self.num_batches_fetched += 1 + + sampler_state, sampler_state_idx = self._sampler_state.pop(0) + # there is no workers within the samplers + worker_id = list(state.keys())[0] + + state = [ + IteratorState( + num_workers=self._loader.num_workers, + sampler_state=sampler_state, + dataset_state=state, + worker_id=worker_id, + num_batches_fetched=self.num_batches_fetched, + ) + ] + # ensures there is an alignement between the sampler state and currently fetched batch + assert sampler_state_idx == self.num_batches_fetched + self._data_fetcher._store_dataloader_iter_state(self, state) + return batch + + +class _SingleProcessDataLoaderIterStateful(_StatefulMixin, _SingleProcessDataLoaderIter): + def __init__(self, 
loader: DataLoader): + self._prepare_loader(loader) + super().__init__(loader) + self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher + self.num_batches_fetched = 0 + + +class _MultiProcessingDataLoaderIterStateful(_StatefulMixin, _MultiProcessingDataLoaderIter): + def __init__(self, loader: DataLoader): + self._prepare_loader(loader) + super().__init__(loader) + self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher + self.num_batches_fetched = 0 + + +def _get_iterator(self) -> "_BaseDataLoaderIter": + if self.num_workers == 0: + return _SingleProcessDataLoaderIterStateful(self) + else: + if hasattr(self, "check_worker_number_rationality"): + self.check_worker_number_rationality() + return _MultiProcessingDataLoaderIterStateful(self) + + +def _patch_dataloader_get_iterators() -> None: + """This function is used to replace the DataLoader iterator by their stateful version.""" + if _fault_tolerant_training_mode().is_manual: + if not hasattr(DataLoader, "_ori_get_iterator"): + DataLoader._ori_get_iterator = DataLoader._get_iterator + DataLoader._get_iterator = _get_iterator + + +def _teardown_dataloader_get_iterators() -> None: + """This function is used to restore the DataLoader `get_iterator` with its original one.""" + # cleanup the get_iterator replacement in case of Fault Tolerant Training. + get_iterator = getattr(DataLoader, "_ori_get_iterator", None) + if get_iterator: + DataLoader._get_iterator = get_iterator + del DataLoader._ori_get_iterator \ No newline at end of file diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b36a9d1d76941..36658469065ba 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -43,6 +43,8 @@ CaptureMapDataset, FastForwardSampler, MergedIteratorState, + _patch_dataloader_get_iterators, + _teardown_dataloader_get_iterators, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException @@ -1192,3 +1194,8 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +def test_stateful_workers(): + + _patch_dataloader_get_iterators() \ No newline at end of file From 96f0517d31c1ae7414dfde9400b629056c272ac8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 13:59:49 +0000 Subject: [PATCH 13/42] update --- pytorch_lightning/trainer/states.py | 7 ++++++- pytorch_lightning/trainer/trainer.py | 5 ----- tests/utilities/test_auto_restart.py | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 7f83dd76156ab..93ede001fe64d 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional from pytorch_lightning.utilities import LightningEnum +from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_training_mode +from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode class TrainerStatus(LightningEnum): @@ -93,6 +95,9 @@ class TrainerState: fn: Optional[TrainerFn] = None stage: Optional[RunningStage] = None + # detect the fault tolerant flag + _fault_tolerant_mode: _FaultTolerantTrainingMode = field(default_factory=_detect_fault_tolerant_training_mode) + @property def finished(self) -> bool: return self.status == TrainerStatus.FINISHED diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index da5b952ce1172..2f6e987635d47 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -80,10 +80,8 @@ parse_argparser, parse_env_variables, ) -from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_training_mode from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.distributed import distributed_available -from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module @@ -472,9 +470,6 @@ def __init__( # Needed because of LightningOptimizer self._lightning_optimizers = None - # detect the fault tolerant flag - self._fault_tolerant_mode: _FaultTolerantTrainingMode = _detect_fault_tolerant_training_mode() - # .validate() and .test() set this when they load a checkpoint self.validated_ckpt_path: Optional[str] = None self.tested_ckpt_path: Optional[str] = None diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index c1dc0afb76799..eae9eee9ba285 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1200,17 +1200,17 @@ def test_fault_tolerant_manual_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): assert _FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_training_mode() trainer = Trainer() - assert not trainer._fault_tolerant_mode.is_enabled + assert not trainer.state._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): assert _FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_training_mode() trainer = Trainer() - assert trainer._fault_tolerant_mode.is_automatic + assert trainer.state._fault_tolerant_mode.is_automatic with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}): assert _FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_training_mode() trainer = Trainer() - assert trainer._fault_tolerant_mode.is_manual + assert trainer.state._fault_tolerant_mode.is_manual with pytest.raises( MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" From 9800cba04d316c10bf14a93b0b75a8aba150fea9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 14:07:56 +0000 Subject: [PATCH 14/42] update --- pytorch_lightning/trainer/states.py | 5 +++-- pytorch_lightning/utilities/auto_restart.py | 18 +----------------- pytorch_lightning/utilities/enums.py | 18 ++++++++++++++++++ 
tests/utilities/test_auto_restart.py | 11 +++++------ 4 files changed, 27 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 93ede001fe64d..1c9050bbd86c2 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -15,7 +15,6 @@ from typing import Optional from pytorch_lightning.utilities import LightningEnum -from pytorch_lightning.utilities.auto_restart import _detect_fault_tolerant_training_mode from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode @@ -96,7 +95,9 @@ class TrainerState: stage: Optional[RunningStage] = None # detect the fault tolerant flag - _fault_tolerant_mode: _FaultTolerantTrainingMode = field(default_factory=_detect_fault_tolerant_training_mode) + _fault_tolerant_mode: _FaultTolerantTrainingMode = field( + default_factory=_FaultTolerantTrainingMode._detect_fault_tolerant_training_mode + ) @property def finished(self) -> bool: diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 48ea16557ec40..228e16e4e9c8c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import os from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps @@ -25,7 +24,7 @@ from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset import pytorch_lightning as pl -from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode, AutoRestartBatchKeys +from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -571,18 +570,3 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") - - -def _detect_fault_tolerant_training_mode() -> _FaultTolerantTrainingMode: - """This utility detects if Fault Tolerant is activated and maps its value to `_FaultTolerantTrainingMode`.""" - env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0") - if env_value == "0": - return _FaultTolerantTrainingMode.DISABLED - elif env_value == "1": - return _FaultTolerantTrainingMode.AUTOMATIC - elif env_value == "2": - return _FaultTolerantTrainingMode.MANUAL - raise MisconfigurationException( - "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " - "'0' (disabled), '1' (automatic) or '2' (manual)." - ) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index f7491241dfc3f..5070fbcef0172 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Enumerated utilities.""" +import os from enum import Enum, EnumMeta from typing import Any, List, Optional, Union +from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.warnings import rank_zero_deprecation @@ -275,3 +277,19 @@ def is_automatic(self) -> bool: @property def is_manual(self) -> bool: return self is _FaultTolerantTrainingMode.MANUAL + + @classmethod + def _detect_fault_tolerant_training_mode(cls) -> "_FaultTolerantTrainingMode": + """This utility detects if Fault Tolerant is activated and maps its value to + `_FaultTolerantTrainingMode`.""" + env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower() + if env_value in ("0", "disabled"): + return _FaultTolerantTrainingMode.DISABLED + elif env_value in ("1", "automatic"): + return _FaultTolerantTrainingMode.AUTOMATIC + elif env_value in ("2", "manual"): + return _FaultTolerantTrainingMode.MANUAL + raise MisconfigurationException( + "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " + "'0' (disabled), '1' (automatic) or '2' (manual)." + ) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index eae9eee9ba285..295efa023ea41 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,7 +39,6 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, - _detect_fault_tolerant_training_mode, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1198,17 +1197,17 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on def test_fault_tolerant_manual_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert _FaultTolerantTrainingMode.DISABLED == _detect_fault_tolerant_training_mode() + assert _FaultTolerantTrainingMode.DISABLED == _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() trainer = Trainer() assert not trainer.state._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert _FaultTolerantTrainingMode.AUTOMATIC == _detect_fault_tolerant_training_mode() + assert _FaultTolerantTrainingMode.AUTOMATIC == _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() trainer = Trainer() assert trainer.state._fault_tolerant_mode.is_automatic - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}): - assert _FaultTolerantTrainingMode.MANUAL == _detect_fault_tolerant_training_mode() + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "MANUAL"}): + assert _FaultTolerantTrainingMode.MANUAL == _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() trainer = Trainer() assert trainer.state._fault_tolerant_mode.is_manual @@ -1216,4 +1215,4 @@ def test_fault_tolerant_manual_mode_enum(): MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - _detect_fault_tolerant_training_mode() + _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() From 427ed036a20f403df2955d84a9105cededfc25a5 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 14:08:52 +0000 Subject: [PATCH 15/42] docstring improvement --- pytorch_lightning/utilities/enums.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 5070fbcef0172..ef30f21b70b5c 100644 --- 
a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -280,7 +280,7 @@ def is_manual(self) -> bool: @classmethod def _detect_fault_tolerant_training_mode(cls) -> "_FaultTolerantTrainingMode": - """This utility detects if Fault Tolerant is activated and maps its value to + """This classmethod detects if `Fault Tolerant` is activated and maps its value to `_FaultTolerantTrainingMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower() if env_value in ("0", "disabled"): From ae712b0a89a9448da38dd9aef9dbb3116ded77c8 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 14:10:25 +0000 Subject: [PATCH 16/42] update --- pytorch_lightning/trainer/states.py | 2 +- pytorch_lightning/utilities/enums.py | 2 +- tests/utilities/test_auto_restart.py | 8 ++++---- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 1c9050bbd86c2..5916610d4e43e 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -96,7 +96,7 @@ class TrainerState: # detect the fault tolerant flag _fault_tolerant_mode: _FaultTolerantTrainingMode = field( - default_factory=_FaultTolerantTrainingMode._detect_fault_tolerant_training_mode + default_factory=_FaultTolerantTrainingMode.detect_fault_tolerant_training_mode ) @property diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index ef30f21b70b5c..91cb63fcf7202 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -279,7 +279,7 @@ def is_manual(self) -> bool: return self is _FaultTolerantTrainingMode.MANUAL @classmethod - def _detect_fault_tolerant_training_mode(cls) -> "_FaultTolerantTrainingMode": + def detect_fault_tolerant_training_mode(cls) -> "_FaultTolerantTrainingMode": """This classmethod detects if `Fault Tolerant` is activated and maps its value to `_FaultTolerantTrainingMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower() diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 295efa023ea41..64a51600c5cc0 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1197,17 +1197,17 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on def test_fault_tolerant_manual_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert _FaultTolerantTrainingMode.DISABLED == _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() + assert _FaultTolerantTrainingMode.DISABLED == _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() trainer = Trainer() assert not trainer.state._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert _FaultTolerantTrainingMode.AUTOMATIC == _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() + assert _FaultTolerantTrainingMode.AUTOMATIC == _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() trainer = Trainer() assert trainer.state._fault_tolerant_mode.is_automatic with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "MANUAL"}): - assert _FaultTolerantTrainingMode.MANUAL == _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() + assert _FaultTolerantTrainingMode.MANUAL == _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() trainer = Trainer() assert trainer.state._fault_tolerant_mode.is_manual @@ -1215,4 
+1215,4 @@ def test_fault_tolerant_manual_mode_enum(): MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - _FaultTolerantTrainingMode._detect_fault_tolerant_training_mode() + _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() From 9a5166dcc2dd05810467f942d18a74778ad1f38d Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Nov 2021 15:17:52 +0100 Subject: [PATCH 17/42] Rename and simplify --- CHANGELOG.md | 2 +- pytorch_lightning/trainer/states.py | 6 ++---- pytorch_lightning/utilities/enums.py | 22 ++++++++++------------ tests/utilities/test_auto_restart.py | 25 +++++++++++-------------- 4 files changed, 24 insertions(+), 31 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5ff6d021c8c4c..b8279f5d2d852 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual - * Add `_FaultTolerantTrainingMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) + * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) - diff --git a/pytorch_lightning/trainer/states.py b/pytorch_lightning/trainer/states.py index 5916610d4e43e..a81073cccc1c0 100644 --- a/pytorch_lightning/trainer/states.py +++ b/pytorch_lightning/trainer/states.py @@ -15,7 +15,7 @@ from typing import Optional from pytorch_lightning.utilities import LightningEnum -from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode +from pytorch_lightning.utilities.enums import _FaultTolerantMode class TrainerStatus(LightningEnum): @@ -95,9 +95,7 @@ class TrainerState: stage: Optional[RunningStage] = None # detect the fault tolerant flag - _fault_tolerant_mode: _FaultTolerantTrainingMode = field( - default_factory=_FaultTolerantTrainingMode.detect_fault_tolerant_training_mode - ) + _fault_tolerant_mode: _FaultTolerantMode = field(default_factory=_FaultTolerantMode.detect_current_mode) @property def finished(self) -> bool: diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index 91cb63fcf7202..bf7ae2621c8ba 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -260,7 +260,7 @@ def is_interactive_compatible(self) -> bool: return self in _StrategyType.interactive_compatible_types() -class _FaultTolerantTrainingMode(LightningEnum): +class _FaultTolerantMode(LightningEnum): DISABLED = "disabled" AUTOMATIC = "automatic" @@ -268,28 +268,26 @@ class _FaultTolerantTrainingMode(LightningEnum): @property def is_enabled(self) -> bool: - return self is not _FaultTolerantTrainingMode.DISABLED + return self is not _FaultTolerantMode.DISABLED @property def is_automatic(self) -> bool: - return self is _FaultTolerantTrainingMode.AUTOMATIC + return self is _FaultTolerantMode.AUTOMATIC @property def is_manual(self) -> bool: - return self is _FaultTolerantTrainingMode.MANUAL + return self is _FaultTolerantMode.MANUAL @classmethod - def detect_fault_tolerant_training_mode(cls) -> "_FaultTolerantTrainingMode": - """This classmethod detects if `Fault Tolerant` is activated and maps its value to - `_FaultTolerantTrainingMode`.""" + def detect_current_mode(cls) -> "_FaultTolerantMode": + """This classmethod detects if `Fault 
Tolerant` is activated and maps its value to `_FaultTolerantMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower() if env_value in ("0", "disabled"): - return _FaultTolerantTrainingMode.DISABLED + return _FaultTolerantMode.DISABLED elif env_value in ("1", "automatic"): - return _FaultTolerantTrainingMode.AUTOMATIC + return _FaultTolerantMode.AUTOMATIC elif env_value in ("2", "manual"): - return _FaultTolerantTrainingMode.MANUAL + return _FaultTolerantMode.MANUAL raise MisconfigurationException( - "The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either " - "'0' (disabled), '1' (automatic) or '2' (manual)." + "The environment flag `PL_FAULT_TOLERANT_TRAINING` should be either 'disabled', 'automatic', or 'manual'." ) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 64a51600c5cc0..b9eb97cb42ae8 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -35,6 +35,7 @@ import tests.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer +from pytorch_lightning.trainer.states import TrainerState from pytorch_lightning.utilities.auto_restart import ( _add_capture_metadata_collate, _dataloader_load_state_dict, @@ -44,7 +45,7 @@ FastForwardSampler, MergedIteratorState, ) -from pytorch_lightning.utilities.enums import _FaultTolerantTrainingMode, AutoRestartBatchKeys +from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException from pytorch_lightning.utilities.fetching import DataFetcher from pytorch_lightning.utilities.imports import _fault_tolerant_training @@ -1194,25 +1195,21 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict -def test_fault_tolerant_manual_mode_enum(): - +def test_fault_tolerant_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert _FaultTolerantTrainingMode.DISABLED == _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() - trainer = Trainer() - assert not trainer.state._fault_tolerant_mode.is_enabled + assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode() + assert not TrainerState()._fault_tolerant_mode.is_enabled with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert _FaultTolerantTrainingMode.AUTOMATIC == _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() - trainer = Trainer() - assert trainer.state._fault_tolerant_mode.is_automatic + assert _FaultTolerantMode.AUTOMATIC == _FaultTolerantMode.detect_current_mode() + assert TrainerState()._fault_tolerant_mode.is_automatic with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "MANUAL"}): - assert _FaultTolerantTrainingMode.MANUAL == _FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() - trainer = Trainer() - assert trainer.state._fault_tolerant_mode.is_manual + assert _FaultTolerantMode.MANUAL == _FaultTolerantMode.detect_current_mode() + assert TrainerState()._fault_tolerant_mode.is_manual with pytest.raises( - MisconfigurationException, match="The environnement flag `PL_FAULT_TOLERANT_TRAINING` should be either" + MisconfigurationException, match="The environment flag `PL_FAULT_TOLERANT_TRAINING` should be either" ): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - 
_FaultTolerantTrainingMode.detect_fault_tolerant_training_mode() + _FaultTolerantMode.detect_current_mode() From b5fa8192b9eb2f5e982631697a79ba66fd690df8 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Nov 2021 15:23:47 +0100 Subject: [PATCH 18/42] Add comment --- pytorch_lightning/utilities/enums.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_lightning/utilities/enums.py b/pytorch_lightning/utilities/enums.py index bf7ae2621c8ba..1d7a6e3fa5452 100644 --- a/pytorch_lightning/utilities/enums.py +++ b/pytorch_lightning/utilities/enums.py @@ -282,6 +282,7 @@ def is_manual(self) -> bool: def detect_current_mode(cls) -> "_FaultTolerantMode": """This classmethod detects if `Fault Tolerant` is activated and maps its value to `_FaultTolerantMode`.""" env_value = os.getenv("PL_FAULT_TOLERANT_TRAINING", "0").lower() + # the int values are kept for backwards compatibility, but long-term we want to keep only the strings if env_value in ("0", "disabled"): return _FaultTolerantMode.DISABLED elif env_value in ("1", "automatic"): From c82b2f2d4e387afdf2535974ae647710a5d95354 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 14:27:42 +0000 Subject: [PATCH 19/42] update --- pytorch_lightning/utilities/auto_restart.py | 27 ++++++++++++-------- pytorch_lightning/utilities/fetching.py | 12 +++++++++ tests/utilities/test_auto_restart.py | 28 ++++++++++++++++++--- 3 files changed, 54 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 08b480b094b77..ad915f1a2b63e 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -22,7 +22,13 @@ import numpy as np import torch from torch.utils.data import Dataset, get_worker_info, Sampler -from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from torch.utils.data.dataloader import ( + _BaseDataLoaderIter, + _MultiProcessingDataLoaderIter, + _SingleProcessDataLoaderIter, + DataLoader, + IterableDataset, +) import pytorch_lightning as pl from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -610,6 +616,8 @@ def _prepare_loader(self, loader): _capture_metadata_collate, dataset=loader.dataset, default_collate=loader.collate_fn ) self._loader = loader + self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher + self.num_batches_fetched = 0 def __del__(self) -> None: if isinstance(self._loader.collate_fn, partial): @@ -645,19 +653,19 @@ class _SingleProcessDataLoaderIterStateful(_StatefulMixin, _SingleProcessDataLoa def __init__(self, loader: DataLoader): self._prepare_loader(loader) super().__init__(loader) - self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher - self.num_batches_fetched = 0 class _MultiProcessingDataLoaderIterStateful(_StatefulMixin, _MultiProcessingDataLoaderIter): def __init__(self, loader: DataLoader): self._prepare_loader(loader) super().__init__(loader) - self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher - self.num_batches_fetched = 0 def _get_iterator(self) -> "_BaseDataLoaderIter": + if not hasattr(self, "_lightning_fetcher"): + raise MisconfigurationException( + "A stateful iterator should be used only when a DataFetcher has been attached to the DataLoader." 
+ ) if self.num_workers == 0: return _SingleProcessDataLoaderIterStateful(self) else: @@ -668,10 +676,9 @@ def _get_iterator(self) -> "_BaseDataLoaderIter": def _patch_dataloader_get_iterators() -> None: """This function is used to replace the DataLoader iterator by their stateful version.""" - if _fault_tolerant_training_mode().is_manual: - if not hasattr(DataLoader, "_ori_get_iterator"): - DataLoader._ori_get_iterator = DataLoader._get_iterator - DataLoader._get_iterator = _get_iterator + if not hasattr(DataLoader, "_ori_get_iterator"): + DataLoader._ori_get_iterator = DataLoader._get_iterator + DataLoader._get_iterator = _get_iterator def _teardown_dataloader_get_iterators() -> None: @@ -680,4 +687,4 @@ def _teardown_dataloader_get_iterators() -> None: get_iterator = getattr(DataLoader, "_ori_get_iterator", None) if get_iterator: DataLoader._get_iterator = get_iterator - del DataLoader._ori_get_iterator \ No newline at end of file + del DataLoader._ori_get_iterator diff --git a/pytorch_lightning/utilities/fetching.py b/pytorch_lightning/utilities/fetching.py index 9b80d2f9874c7..f5bb4be032d10 100644 --- a/pytorch_lightning/utilities/fetching.py +++ b/pytorch_lightning/utilities/fetching.py @@ -99,6 +99,8 @@ def setup( if self.profiler is not None and stage is None: raise MisconfigurationException("When providing a profiler, the stage should be provided too.") + self._attach_data_fetcher() + @staticmethod def _add_capture_metadata_collate(dataloader: Iterable) -> None: if not isinstance(dataloader, (DataLoader, CombinedLoader)): @@ -190,6 +192,16 @@ def collect_state(iterator: Iterator): return apply_to_collection(self.loader_iters, Iterator, collect_state) + def _attach_data_fetcher(self): + def _attach_data_fetcher_fn(loader: DataLoader): + if isinstance(loader, CycleIterator): + loader = loader.loader + + if isinstance(loader, DataLoader) and _fault_tolerant_training(): + loader._lightning_fetcher = self + + apply_to_collection(self.loaders, (DataLoader, CycleIterator), _attach_data_fetcher_fn) + def __iter__(self) -> Generator[Tuple[Any, bool], None, None]: if self.dataloader is None: raise MisconfigurationException("The iterate hasn't been provided. 
HINT: Did you call setup function ?.") diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 36658469065ba..a2534b1bed2c6 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,12 +39,13 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _patch_dataloader_get_iterators, + _SingleProcessDataLoaderIterStateful, + _teardown_dataloader_get_iterators, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, MergedIteratorState, - _patch_dataloader_get_iterators, - _teardown_dataloader_get_iterators, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException @@ -1196,6 +1197,27 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_stateful_workers(): - _patch_dataloader_get_iterators() \ No newline at end of file + seed_everything(42) + + _patch_dataloader_get_iterators() + assert DataLoader._ori_get_iterator is not None + + data_fetcher = DataFetcher() + dataloader = DataLoader(range(10), shuffle=True) + + with pytest.raises(MisconfigurationException, match="A stateful iterator should be used"): + iter(dataloader) + + # This would attach the `data_fetcher` to the DataLoader. + data_fetcher.setup(dataloader) + + dataloader_iter = iter(dataloader) + assert isinstance(dataloader_iter, _SingleProcessDataLoaderIterStateful) + + batch = next(dataloader_iter) + print(batch) + + _teardown_dataloader_get_iterators() \ No newline at end of file From 2baddb928c4ca091c94c4da38325391a8b5a70af Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 15:38:22 +0000 Subject: [PATCH 20/42] update --- tests/utilities/test_auto_restart.py | 96 ++++++++++++++++++---------- 1 file changed, 62 insertions(+), 34 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 7071990e803b4..e71e2de455dee 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -41,6 +41,7 @@ _dataloader_load_state_dict, _dataloader_to_state_dict, _is_obj_stateful, + _MultiProcessingDataLoaderIterStateful, _patch_dataloader_get_iterators, _SingleProcessDataLoaderIterStateful, _teardown_dataloader_get_iterators, @@ -1228,40 +1229,48 @@ def load_state_dict(self): assert not _is_obj_stateful(obj) -@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) -def test_stateful_workers(): - class StatefulRandomSampler(RandomSampler): +class StatefulRandomSampler(RandomSampler): - counter = 0 + counter = 0 - def state_dict(self): - self.counter += 1 - return {"counter": self.counter} + def state_dict(self): + self.counter += 1 + return {"counter": self.counter} - def load_state_dict(self, state_dict): - self.counter = state_dict["counter"] + def load_state_dict(self, state_dict): + self.counter = state_dict["counter"] - class FailingStatefulRandomDataset(RandomDataset): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.counter = 0 - def __getitem__(self, index): - self.counter += 1 - return super().__getitem__(index) +class FailingStatefulRandomDataset(RandomDataset): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.counter = 0 - def state_dict(self): - return {"counter": self.counter} + def 
__getitem__(self, index): + self.counter += 1 + return super().__getitem__(index) - def load_state_dict(self, state_dict): - self.counter = state_dict["counter"] + def state_dict(self): + return {"counter": self.counter} + + def load_state_dict(self, state_dict): + self.counter = state_dict["counter"] - class StatefulRandomDataset(FailingStatefulRandomDataset): - def state_dict(self): - return {0: {"counter": self.counter}} + +class StatefulRandomDataset(FailingStatefulRandomDataset): + def state_dict(self): + info = get_worker_info() + worker_id = info.id if info else 0 + return {worker_id: {"counter": self.counter}} + + +@pytest.mark.parametrize("num_workers", [0, 2]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) +def test_stateful_workers(num_workers): seed_everything(42) + _get_iterator_fn = DataLoader._get_iterator _patch_dataloader_get_iterators() assert DataLoader._ori_get_iterator is not None @@ -1282,27 +1291,46 @@ def state_dict(self): data_fetcher = DataFetcher() dataset = StatefulRandomDataset(1, 64) - dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset)) + dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers) # This would attach the `data_fetcher` to the DataLoader. data_fetcher.setup(dataloader) data_fetcher_iter = iter(data_fetcher) + worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful + assert isinstance(data_fetcher.dataloader_iter, worker_type) + next(data_fetcher_iter) - assert data_fetcher.dataloader_iter.state.state[0].dataset_state == {0: {"counter": 1}} - assert data_fetcher.dataloader_iter.state.state[0].sampler_state["sampler"] == {"counter": 1} + state = data_fetcher.dataloader_iter.state.state + assert state[0].dataset_state == {0: {"counter": 1}} + assert state[0].sampler_state["sampler"] == {"counter": 1} + next(data_fetcher_iter) - assert data_fetcher.dataloader_iter.previous_state.state[0].dataset_state == {0: {"counter": 1}} - assert data_fetcher.dataloader_iter.previous_state.state[0].sampler_state["sampler"] == {"counter": 1} - assert data_fetcher.dataloader_iter.state.state[0].dataset_state == {0: {"counter": 2}} - assert data_fetcher.dataloader_iter.state.state[0].sampler_state["sampler"] == {"counter": 2} + previous_state = data_fetcher.dataloader_iter.previous_state.state + state = data_fetcher.dataloader_iter.state.state + assert previous_state[0].dataset_state == {0: {"counter": 1}} + assert previous_state[0].sampler_state["sampler"] == {"counter": 1} + # TODO: Resolve the previous `sampler_state` associated to `worker_id: 0`. 
+ worker_id = 1 if num_workers else 0 + assert state[worker_id].sampler_state["sampler"] == {"counter": 2} + + # each worker has its own copy of the dataset + assert state[0].dataset_state == ({0: {"counter": 2}} if num_workers == 0 else {0: {"counter": 1}}) + target_previous_state = deepcopy(state) + next(data_fetcher_iter) - assert data_fetcher.dataloader_iter.previous_state.state[0].dataset_state == {0: {"counter": 2}} - assert data_fetcher.dataloader_iter.previous_state.state[0].sampler_state["sampler"] == {"counter": 2} - assert data_fetcher.dataloader_iter.state.state[0].dataset_state == {0: {"counter": 3}} - assert data_fetcher.dataloader_iter.state.state[0].sampler_state["sampler"] == {"counter": 3} + latest_worker_id = data_fetcher.dataloader_iter.state.latest_worker_id + assert latest_worker_id == 0 + previous_state = data_fetcher.dataloader_iter.previous_state.state + state = data_fetcher.dataloader_iter.state.state + + assert target_previous_state == previous_state + assert state[0].sampler_state["sampler"] == {"counter": 3} + assert state[0].dataset_state == ({0: {"counter": 3}} if num_workers == 0 else {0: {"counter": 2}}) _teardown_dataloader_get_iterators() + assert not hasattr(DataLoader, "_ori_get_iterator") + assert DataLoader._get_iterator == _get_iterator_fn def test_fault_tolerant_mode_enum(): From 97548bb18ed941c5134344d88ca5181ef5f8a708 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 15:52:11 +0000 Subject: [PATCH 21/42] update --- tests/utilities/test_auto_restart.py | 40 ++++++++++++++-------------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index e71e2de455dee..8e1c0c613e775 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1264,6 +1264,26 @@ def state_dict(self): return {worker_id: {"counter": self.counter}} +def test_fault_tolerant_mode_enum(): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): + assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode() + assert not TrainerState()._fault_tolerant_mode.is_enabled + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): + assert _FaultTolerantMode.AUTOMATIC == _FaultTolerantMode.detect_current_mode() + assert TrainerState()._fault_tolerant_mode.is_automatic + + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "MANUAL"}): + assert _FaultTolerantMode.MANUAL == _FaultTolerantMode.detect_current_mode() + assert TrainerState()._fault_tolerant_mode.is_manual + + with pytest.raises( + MisconfigurationException, match="The environment flag `PL_FAULT_TOLERANT_TRAINING` should be either" + ): + with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): + _FaultTolerantMode.detect_current_mode() + + @pytest.mark.parametrize("num_workers", [0, 2]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): @@ -1331,23 +1351,3 @@ def test_stateful_workers(num_workers): _teardown_dataloader_get_iterators() assert not hasattr(DataLoader, "_ori_get_iterator") assert DataLoader._get_iterator == _get_iterator_fn - - -def test_fault_tolerant_mode_enum(): - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): - assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode() - assert not TrainerState()._fault_tolerant_mode.is_enabled - - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}): - assert 
_FaultTolerantMode.AUTOMATIC == _FaultTolerantMode.detect_current_mode() - assert TrainerState()._fault_tolerant_mode.is_automatic - - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "MANUAL"}): - assert _FaultTolerantMode.MANUAL == _FaultTolerantMode.detect_current_mode() - assert TrainerState()._fault_tolerant_mode.is_manual - - with pytest.raises( - MisconfigurationException, match="The environment flag `PL_FAULT_TOLERANT_TRAINING` should be either" - ): - with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "3"}): - _FaultTolerantMode.detect_current_mode() From 41ffbaba56a9a6b8e8075f7a4d8363ddb91f4819 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 15:58:50 +0000 Subject: [PATCH 22/42] use_teardown --- tests/utilities/test_auto_restart.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b9c6df4cbd961..1de265592a433 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1351,3 +1351,5 @@ def test_stateful_workers(num_workers): _teardown_dataloader_get_iterators() assert not hasattr(DataLoader, "_ori_get_iterator") assert DataLoader._get_iterator == _get_iterator_fn + + data_fetcher.teardown() From d04596df734a9c10d68f6ce8a301f9ab783574aa Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Nov 2021 17:13:20 +0100 Subject: [PATCH 23/42] Use `Protocol` --- pytorch_lightning/utilities/auto_restart.py | 18 +++++++----------- tests/utilities/test_auto_restart.py | 13 +++++-------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0310aceb4bc49..4b0ccddace55c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import inspect from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps @@ -24,6 +22,7 @@ import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -574,13 +573,10 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _is_obj_stateful(obj: Any) -> bool: - """In order to be stateful, an object should implement a ``state_dict`` and ``load_state_dict`` method.""" - load_state_dict_fn = getattr(obj, "load_state_dict", None) - if not isinstance(load_state_dict_fn, Callable): - return False - params = inspect.signature(load_state_dict_fn).parameters - if len(params) == 0: - return False +@runtime_checkable +class _SupportsStateDict(Protocol): + def state_dict(self) -> Dict[str, Any]: + ... - return isinstance(getattr(obj, "state_dict", None), Callable) + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... 
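Illustrative sketch, not part of the patch series: the hunk above replaces the old `_is_obj_stateful` helper with a `runtime_checkable` `_SupportsStateDict` Protocol, so statefulness is detected structurally with `isinstance` (the object merely needs `state_dict` and `load_state_dict` methods, no inheritance). A minimal standalone example of that pattern, assuming only `typing_extensions` is available; the `DemoSampler` class below is invented for illustration:

    from typing import Any, Dict

    from typing_extensions import Protocol, runtime_checkable

    @runtime_checkable
    class SupportsStateDict(Protocol):
        def state_dict(self) -> Dict[str, Any]:
            ...

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            ...

    class DemoSampler:  # invented example; any object exposing both methods matches
        def __init__(self) -> None:
            self.counter = 0

        def state_dict(self) -> Dict[str, Any]:
            return {"counter": self.counter}

        def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
            self.counter = state_dict["counter"]

    # A runtime_checkable Protocol only verifies that the methods exist,
    # not their signatures, when used with isinstance().
    assert isinstance(DemoSampler(), SupportsStateDict)
    assert not isinstance(object(), SupportsStateDict)
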
diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 0c35ba651cdb6..6392249300961 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,7 +40,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, - _is_obj_stateful, + _SupportsStateDict, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1196,7 +1196,7 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict -def test_is_obj_stateful(): +def test_supports_state_dict_protocol(): class StatefulClass: def state_dict(self): pass @@ -1205,24 +1205,21 @@ def load_state_dict(self, state_dict): pass obj = StatefulClass() - assert _is_obj_stateful(obj) + assert isinstance(obj, _SupportsStateDict) class NotStatefulClass: def state_dict(self): pass - def load_state_dict(self): - pass - obj = NotStatefulClass() - assert not _is_obj_stateful(obj) + assert not isinstance(obj, _SupportsStateDict) class NotStateful2Class: def load_state_dict(self, state_dict): pass obj = NotStateful2Class() - assert not _is_obj_stateful(obj) + assert not isinstance(obj, _SupportsStateDict) def test_fault_tolerant_mode_enum(): From ff7b8367941e4b493c19c6f5d9fcc845241cb955 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Nov 2021 17:20:45 +0100 Subject: [PATCH 24/42] Simplify test --- tests/utilities/test_auto_restart.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 6392249300961..5152874b39469 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1204,22 +1204,19 @@ def state_dict(self): def load_state_dict(self, state_dict): pass - obj = StatefulClass() - assert isinstance(obj, _SupportsStateDict) + assert isinstance(StatefulClass(), _SupportsStateDict) class NotStatefulClass: def state_dict(self): pass - obj = NotStatefulClass() - assert not isinstance(obj, _SupportsStateDict) + assert not isinstance(NotStatefulClass(), _SupportsStateDict) class NotStateful2Class: def load_state_dict(self, state_dict): pass - obj = NotStateful2Class() - assert not isinstance(obj, _SupportsStateDict) + assert not isinstance(NotStateful2Class(), _SupportsStateDict) def test_fault_tolerant_mode_enum(): From a5698e655020b8ed0698fb765cb1c7490da3a00a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Nov 2021 17:21:14 +0100 Subject: [PATCH 25/42] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c19f254007628..169308718e569 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Fault Tolerant Manual - * Add `_is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) + * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) From 916b520cb47f365cce4b20b113f1af2987966384 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 17:56:34 +0000 Subject: [PATCH 26/42] update --- pytorch_lightning/utilities/auto_restart.py | 20 +++++++++++------- tests/utilities/test_auto_restart.py | 23 ++++++++++++++++----- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index d71718155f9ba..26527047cddb6 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -448,9 +448,9 @@ def _find_current_worker(iterator: Iterator) -> Dict[str, Optional[int]]: def _capture_metadata_collate( - samples: List, dataset: Dataset, collate: Callable, fault_tolerant_mode: _FaultTolerantMode + samples: List, dataset: Dataset, collate_fn: Callable, fault_tolerant_mode: _FaultTolerantMode ) -> Any: - """A collate function that adds the state dict of a :class:`CaptureIterableDataset` or + """A collate_fn function that adds the state dict of a :class:`CaptureIterableDataset` or :class:`CaptureMapDataset` used in the worker processes. This function gets executed within the worker processes. The structure will be: @@ -461,7 +461,7 @@ def _capture_metadata_collate( "__pl_restart_meta": {"sampler_name0": state_dict0, "sampler_name1": state_dict1}, } """ - data = collate(samples) + data = collate_fn(samples) fault_tolerant_mode if not fault_tolerant_mode.is_enabled: return data @@ -475,9 +475,11 @@ def _capture_metadata_collate( if state_dict_fn: metadata = state_dict_fn() if worker_id not in metadata: - raise MisconfigurationException( - f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys." - ) + if info and info.num_workers > 1: + raise MisconfigurationException( + f"The state_dict returned by {dataset} needs to be indexed by `worker_id` integer keys." 
+ ) + metadata = {0: metadata} if metadata is None: metadata = {worker_id: {}} @@ -563,7 +565,7 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None: dataloader.collate_fn = partial( _capture_metadata_collate, dataset=dataloader.dataset, - default_collate=dataloader.collate_fn, + collate_fn=dataloader.collate_fn, fault_tolerant_mode=_FaultTolerantMode.detect_current_mode(), ) @@ -652,10 +654,12 @@ def _prepare_loader(self, loader): self._loader = loader self._data_fetcher: "pl.utilities.fetching.AbstractDataFetcher" = loader._lightning_fetcher self.num_batches_fetched = 0 + self._sampler_state = [] + self._sampler_state_idx = 0 def __del__(self) -> None: if isinstance(self._loader.collate_fn, partial): - self._loader.collate_fn = self._loader.collate_fn.keywords["default_collate"] + self._loader.collate_fn = self._loader.collate_fn.keywords["collate_fn"] def _next_data(self) -> Any: combined_batch = super()._next_data() diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index f02c68946db91..8b574b1b118d5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1272,13 +1272,17 @@ def load_state_dict(self, state_dict): class StatefulRandomDataset(FailingStatefulRandomDataset): + + provide_workers_id = False + def state_dict(self): info = get_worker_info() - worker_id = info.id if info else 0 - return {worker_id: {"counter": self.counter}} + if info and self.provide_workers_id: + return {info.id: {"counter": self.counter}} + return {"counter": self.counter} -@pytest.mark.parametrize("num_workers", [0, 2]) +@pytest.mark.parametrize("num_workers", [2]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): @@ -1300,8 +1304,7 @@ def test_stateful_workers(num_workers): dataloader_iter = iter(dataloader) assert isinstance(dataloader_iter, _SingleProcessDataLoaderIterStateful) - with pytest.raises(MisconfigurationException, match="he state_dict returned by"): - next(dataloader_iter) + next(dataloader_iter) data_fetcher = DataFetcher() dataset = StatefulRandomDataset(1, 64) @@ -1309,6 +1312,16 @@ def test_stateful_workers(num_workers): # This would attach the `data_fetcher` to the DataLoader. 
data_fetcher.setup(dataloader) + + if num_workers == 2: + with pytest.raises(MisconfigurationException, match="The state_dict returned by"): + data_fetcher_iter = iter(data_fetcher) + + data_fetcher = DataFetcher() + dataset.provide_workers_id = True + dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers) + data_fetcher.setup(dataloader) + data_fetcher_iter = iter(data_fetcher) worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful From 4b67fbf8289dce70022f989839bb785161d5fc42 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 18:11:56 +0000 Subject: [PATCH 27/42] update --- pytorch_lightning/utilities/auto_restart.py | 7 ++-- tests/utilities/test_auto_restart.py | 38 +++++++-------------- 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 26527047cddb6..829b101ba653f 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -462,7 +462,6 @@ def _capture_metadata_collate( } """ data = collate_fn(samples) - fault_tolerant_mode if not fault_tolerant_mode.is_enabled: return data metadata = None @@ -613,7 +612,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: ... -class _StatefulMixin: +class _StatefulDataLoaderIter: """This mixin is used to make PyTorch DataLoaderIter stateful.""" def _reset(self, loader: DataLoader, first_iter: bool = False): @@ -687,13 +686,13 @@ def _next_data(self) -> Any: return batch -class _SingleProcessDataLoaderIterStateful(_StatefulMixin, _SingleProcessDataLoaderIter): +class _SingleProcessDataLoaderIterStateful(_StatefulDataLoaderIter, _SingleProcessDataLoaderIter): def __init__(self, loader: DataLoader): self._prepare_loader(loader) super().__init__(loader) -class _MultiProcessingDataLoaderIterStateful(_StatefulMixin, _MultiProcessingDataLoaderIter): +class _MultiProcessingDataLoaderIterStateful(_StatefulDataLoaderIter, _MultiProcessingDataLoaderIter): def __init__(self, loader: DataLoader): self._prepare_loader(loader) super().__init__(loader) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 8b574b1b118d5..3adf159901355 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1255,32 +1255,28 @@ def load_state_dict(self, state_dict): self.counter = state_dict["counter"] -class FailingStatefulRandomDataset(RandomDataset): +class StatefulRandomDataset(RandomDataset): + + provide_workers_id = False + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.counter = 0 + self.provide_workers_id = False def __getitem__(self, index): self.counter += 1 return super().__getitem__(index) - def state_dict(self): - return {"counter": self.counter} - - def load_state_dict(self, state_dict): - self.counter = state_dict["counter"] - - -class StatefulRandomDataset(FailingStatefulRandomDataset): - - provide_workers_id = False - def state_dict(self): info = get_worker_info() if info and self.provide_workers_id: return {info.id: {"counter": self.counter}} return {"counter": self.counter} + def load_state_dict(self, state_dict): + self.counter = state_dict["counter"] + @pytest.mark.parametrize("num_workers", [2]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) @@ -1293,7 +1289,8 @@ def test_stateful_workers(num_workers): assert DataLoader._ori_get_iterator is not None 
data_fetcher = DataFetcher() - dataloader = DataLoader(FailingStatefulRandomDataset(1, 64), shuffle=True) + dataset = StatefulRandomDataset(1, 64) + dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers) with pytest.raises(MisconfigurationException, match="A stateful iterator should be used"): iter(dataloader) @@ -1302,16 +1299,8 @@ def test_stateful_workers(num_workers): data_fetcher.setup(dataloader) dataloader_iter = iter(dataloader) - assert isinstance(dataloader_iter, _SingleProcessDataLoaderIterStateful) - - next(dataloader_iter) - - data_fetcher = DataFetcher() - dataset = StatefulRandomDataset(1, 64) - dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers) - - # This would attach the `data_fetcher` to the DataLoader. - data_fetcher.setup(dataloader) + worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful + assert isinstance(dataloader_iter, worker_type) if num_workers == 2: with pytest.raises(MisconfigurationException, match="The state_dict returned by"): @@ -1324,9 +1313,6 @@ def test_stateful_workers(num_workers): data_fetcher_iter = iter(data_fetcher) - worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful - assert isinstance(data_fetcher.dataloader_iter, worker_type) - next(data_fetcher_iter) state = data_fetcher.dataloader_iter.state.state assert state[0].dataset_state == {0: {"counter": 1}} From c9481e2fd817008b72de1ecf1982df7cd24bc23b Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 18:17:01 +0000 Subject: [PATCH 28/42] update --- tests/utilities/test_auto_restart.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 3adf159901355..6032d2fb11709 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1256,13 +1256,9 @@ def load_state_dict(self, state_dict): class StatefulRandomDataset(RandomDataset): - - provide_workers_id = False - def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.counter = 0 - self.provide_workers_id = False def __getitem__(self, index): self.counter += 1 @@ -1270,7 +1266,7 @@ def __getitem__(self, index): def state_dict(self): info = get_worker_info() - if info and self.provide_workers_id: + if info: return {info.id: {"counter": self.counter}} return {"counter": self.counter} @@ -1278,7 +1274,7 @@ def load_state_dict(self, state_dict): self.counter = state_dict["counter"] -@pytest.mark.parametrize("num_workers", [2]) +@pytest.mark.parametrize("num_workers", [0, 2]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): @@ -1298,21 +1294,12 @@ def test_stateful_workers(num_workers): # This would attach the `data_fetcher` to the DataLoader. 
data_fetcher.setup(dataloader) - dataloader_iter = iter(dataloader) + data_fetcher_iter = iter(data_fetcher) + + dataloader_iter = data_fetcher.dataloader_iter worker_type = _SingleProcessDataLoaderIterStateful if num_workers == 0 else _MultiProcessingDataLoaderIterStateful assert isinstance(dataloader_iter, worker_type) - if num_workers == 2: - with pytest.raises(MisconfigurationException, match="The state_dict returned by"): - data_fetcher_iter = iter(data_fetcher) - - data_fetcher = DataFetcher() - dataset.provide_workers_id = True - dataloader = DataLoader(dataset, sampler=StatefulRandomSampler(dataset), num_workers=num_workers) - data_fetcher.setup(dataloader) - - data_fetcher_iter = iter(data_fetcher) - next(data_fetcher_iter) state = data_fetcher.dataloader_iter.state.state assert state[0].dataset_state == {0: {"counter": 1}} From 4a1fff73df469d720a5507fdec9a297b7b61f9a9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 19:02:56 +0000 Subject: [PATCH 29/42] update --- pytorch_lightning/utilities/auto_restart.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index a425eda222bae..1eaffd83e6023 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -16,7 +16,7 @@ from functools import partial, wraps from random import getstate as python_get_rng_state from random import setstate as python_set_rng_state -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Protocol, runtime_checkable, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Tuple, Union import numpy as np import torch @@ -28,6 +28,7 @@ DataLoader, IterableDataset, ) +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys From cb27e305a5b1b546b271a9e034bbc7e767c88d83 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 19:06:18 +0000 Subject: [PATCH 30/42] update --- tests/utilities/test_auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 6032d2fb11709..7223482ca2731 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1274,7 +1274,7 @@ def load_state_dict(self, state_dict): self.counter = state_dict["counter"] -@pytest.mark.parametrize("num_workers", [0, 2]) +@pytest.mark.parametrize("num_workers", [0]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): From 7903d24a28b8ab6db5c6cbb444ca2bdf75b6e136 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 19:32:49 +0000 Subject: [PATCH 31/42] resolve tests --- pytorch_lightning/utilities/auto_restart.py | 7 ++++--- tests/utilities/test_auto_restart.py | 1 + 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 1eaffd83e6023..99f36a2ff70fb 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -463,8 +463,6 @@ def _capture_metadata_collate( } """ data = collate_fn(samples) - if not fault_tolerant_mode.is_enabled: - return data metadata = None if fault_tolerant_mode.is_automatic: metadata = dataset.state_dict() @@ -562,11 +560,14 @@ def wrapper(): 
def _add_capture_metadata_collate(dataloader: DataLoader) -> None: """Wrap default collate function to retrive captured dataset state dict when fault tolerant is enabled.""" + faut_tolerant_mode = _FaultTolerantMode.detect_current_mode() + if not faut_tolerant_mode.is_enabled: + return dataloader.collate_fn = partial( _capture_metadata_collate, dataset=dataloader.dataset, collate_fn=dataloader.collate_fn, - fault_tolerant_mode=_FaultTolerantMode.detect_current_mode(), + fault_tolerant_mode=faut_tolerant_mode, ) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 7223482ca2731..b1e1b53cb61b0 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -250,6 +250,7 @@ def __next__(self): @pytest.mark.skipif(torch.cuda.is_available(), reason="This test takes around 30 sec and should be skipped in Azure CI") @pytest.mark.parametrize("num_workers", [0, 1, 2]) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) def test_fast_forward_sampler_over_iterable_dataset(num_workers): """This test ensures ``FastForwardSampler`` and ``CaptureIterableDataset`` are properly being used to capture workers states.""" From 1104cbcf3f89d2f4e148d93045abcf6b7bb3724e Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 09:30:56 +0000 Subject: [PATCH 32/42] update --- pytorch_lightning/trainer/data_loading.py | 8 +------- pytorch_lightning/utilities/auto_restart.py | 12 +----------- tests/utilities/test_auto_restart.py | 2 +- 3 files changed, 3 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/trainer/data_loading.py b/pytorch_lightning/trainer/data_loading.py index 226c2fd3fba45..6044f1320286c 100644 --- a/pytorch_lightning/trainer/data_loading.py +++ b/pytorch_lightning/trainer/data_loading.py @@ -214,7 +214,7 @@ def reset_train_dataloader(self, model: Optional["pl.LightningModule"] = None) - # add collate_fn to collect metadata for fault tolerant training if _fault_tolerant_training(): - apply_to_collection(self.train_dataloader, DataLoader, self._add_sampler_metadata_collate) + apply_to_collection(self.train_dataloader, DataLoader, _add_capture_metadata_collate) # wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches self.train_dataloader = CombinedLoader(self.train_dataloader, self._data_connector.multiple_trainloader_mode) @@ -436,12 +436,6 @@ def request_dataloader( self.training_type_plugin.barrier("get_dataloaders") return dataloader - @staticmethod - def _add_sampler_metadata_collate(dataloader: DataLoader) -> None: - """Wrap default collate function to retrive ``FastForwardSampler`` state dict when fault tolerant is - enabled.""" - _add_capture_metadata_collate(dataloader) - @staticmethod def _resolve_overfit_batches(dataloader: Collection[DataLoader]) -> Collection[DataLoader]: all_have_sequential_sampler = True diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index c1aa2471c1e9b..4e984f7ecb2aa 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -463,7 +463,7 @@ def _capture_metadata_collate( state_dict_fn = getattr(dataset, "state_dict", None) info = get_worker_info() worker_id = info.id if info else 0 - if state_dict_fn: + if state_dict_fn is not None: metadata = state_dict_fn() if worker_id not in metadata: if info and info.num_workers > 1: @@ -626,17 +626,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: class 
_StatefulDataLoaderIter: """This mixin is used to make PyTorch DataLoaderIter stateful.""" - def _reset(self, loader: DataLoader, first_iter: bool = False): - super()._reset(loader, first_iter=first_iter) - self._loader = loader - self.num_batches_fetched = 0 - def __accumulate_state(self, sampler_state: Dict[str, Any]) -> None: - # initialize the queue if it doesn't exist. - if not hasattr(self, "_sampler_state"): - self._sampler_state = [] - self._sampler_state_idx = 0 - # store sampler state within a queue alongside its idx. self._sampler_state_idx = getattr(self, "_sampler_state_idx", 0) + 1 self._sampler_state.append((sampler_state, self._sampler_state_idx)) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 25f80ec6817a5..b27ed93ce422b 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1292,7 +1292,7 @@ def load_state_dict(self, state_dict): self.counter = state_dict["counter"] -@pytest.mark.parametrize("num_workers", [0]) +@pytest.mark.parametrize("num_workers", [0, 2]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): From f071f9abe188f7192ab5fc30de5e85c53b386f93 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 09:32:57 +0000 Subject: [PATCH 33/42] change to 0 --- tests/utilities/test_auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b27ed93ce422b..25f80ec6817a5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1292,7 +1292,7 @@ def load_state_dict(self, state_dict): self.counter = state_dict["counter"] -@pytest.mark.parametrize("num_workers", [0, 2]) +@pytest.mark.parametrize("num_workers", [0]) @mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "2"}) def test_stateful_workers(num_workers): From b777dc3912491014744525ae57c1e8091e4a893d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 11:26:20 +0000 Subject: [PATCH 34/42] update --- .../loops/epoch/evaluation_epoch_loop.py | 4 +- pytorch_lightning/trainer/supporters.py | 4 +- pytorch_lightning/utilities/auto_restart.py | 70 +++++++++++++------ tests/utilities/test_auto_restart.py | 3 +- 4 files changed, 56 insertions(+), 25 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 102603f20302b..6347882acb6cf 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -22,7 +22,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _update_dataloader_iter from pytorch_lightning.trainer.progress import BatchProgress -from pytorch_lightning.utilities.auto_restart import MergedIteratorState, reload_dataloader_state_dict +from pytorch_lightning.utilities.auto_restart import MergedIteratorState, _reload_dataloader_state_dict from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT @@ -182,7 +182,7 @@ def on_load_checkpoint(self, state_dict: Dict) -> None: def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher): if not self.trainer.sanity_checking and self._dataloader_state_dict: - 
reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict) + _reload_dataloader_state_dict(data_fetcher.dataloader, self._dataloader_state_dict) self._dataloader_state_dict = None def _num_completed_batches_reached(self) -> bool: diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 6e2e51e82bbf1..48ebf82fcbec0 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -26,7 +26,7 @@ from pytorch_lightning.utilities.auto_restart import ( MergedIteratorState, patch_dataloader_iterator, - reload_dataloader_state_dict, + _reload_dataloader_state_dict, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -403,7 +403,7 @@ def create_loader_iters(dataloader: DataLoader, state_dict: Dict) -> Iterator: if isinstance(dataloader, CycleIterator): dataloader = dataloader_to_iter_on.loader - reload_dataloader_state_dict(dataloader, state_dict) + _reload_dataloader_state_dict(dataloader, state_dict) # We finally spawned the workers if any. it = iter(dataloader_to_iter_on) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 4e984f7ecb2aa..e3996449041e2 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -564,40 +564,70 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None: ) -def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: +def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: """Utility to reload state_dict within dataloader for fault tolerance.""" - if not _fault_tolerant_training(): + fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() + + if not fault_tolerant_mode.is_enabled: return dataset = dataloader.dataset - if isinstance(dataset, CaptureMapDataset): - iterator_state = state_dict["state"][0] + if fault_tolerant_mode.is_automatic: + if isinstance(dataset, CaptureMapDataset): + iterator_state = state_dict["state"][0] - if not isinstance(iterator_state, IteratorState): - iterator_state = IteratorState.from_state_dict(iterator_state) + if not isinstance(iterator_state, IteratorState): + iterator_state = IteratorState.from_state_dict(iterator_state) - # reload sampler state - ff_sampler = _find_fast_forward_samplers(dataloader) - ff_sampler.load_state_dict(iterator_state.sampler_state) + # reload sampler state + ff_sampler = _find_fast_forward_samplers(dataloader) + ff_sampler.load_state_dict(iterator_state.sampler_state) - # reload dataset state - dataset.load_state_dict( - iterator_state.dataset_state, - latest_worker_id=state_dict["latest_worker_id"], - num_workers=iterator_state.num_workers, - ) + # reload dataset state + dataset.load_state_dict( + iterator_state.dataset_state, + latest_worker_id=state_dict["latest_worker_id"], + num_workers=iterator_state.num_workers, + ) - elif isinstance(dataset, CaptureIterableDataset): - dataset.load_state_dict( - {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()} - ) + elif isinstance(dataset, CaptureIterableDataset): + dataset.load_state_dict( + {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()} + ) + + else: + raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") + + elif fault_tolerant_mode.is_manual: + + latest_worker_id = state_dict["latest_worker_id"] + num_workers = state_dict["state"][latest_worker_id]["num_workers"] + sampler_state = state_dict["state"][latest_worker_id]["sampler_state"] + if sampler_state: + for k in sampler_state: + obj = getattr(dataloader, k) + if not isinstance(obj, _SupportsStateDict): + raise MisconfigurationException( + f"The DataLoader attribute should have a `load_state_dict` method. Found {obj}" + ) + + obj.load_state_dict(sampler_state[k]) + + if not hasattr(dataset, "load_state_dict"): + return + + dataset_state = { + worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id] + for worker_id in state_dict["state"].keys() + } + + dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") - def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed on.""" diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 25f80ec6817a5..b7b70e289b6b7 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -50,6 +50,7 @@ CaptureMapDataset, FastForwardSampler, MergedIteratorState, + _reload_dataloader_state_dict, ) from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException @@ -1350,4 +1351,4 @@ def test_stateful_workers(num_workers): assert not hasattr(DataLoader, "_ori_get_iterator") assert DataLoader._get_iterator == _get_iterator_fn - data_fetcher.teardown() + data_fetcher.teardown() \ No newline at end of file From 2da16746b63f4a7e1d9cc87c1f8a400d03e57e83 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 12:32:10 +0000 Subject: [PATCH 35/42] update --- tests/utilities/test_auto_restart.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b7b70e289b6b7..979bfcc4ee883 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -19,6 +19,7 @@ from collections.abc import Iterable from contextlib import suppress from copy import deepcopy +from dataclasses import asdict from typing import List, Optional from unittest import mock from unittest.mock import ANY @@ -42,6 +43,7 @@ _dataloader_to_state_dict, _MultiProcessingDataLoaderIterStateful, _patch_dataloader_get_iterators, + _reload_dataloader_state_dict, _rotate_worker_indices, _SingleProcessDataLoaderIterStateful, _SupportsStateDict, @@ -50,7 +52,6 @@ CaptureMapDataset, FastForwardSampler, MergedIteratorState, - _reload_dataloader_state_dict, ) from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException @@ -1290,7 +1291,7 @@ def state_dict(self): return {"counter": self.counter} def load_state_dict(self, state_dict): - self.counter = state_dict["counter"] + self.counter = state_dict[0]["counter"] @pytest.mark.parametrize("num_workers", [0]) @@ -1320,7 +1321,10 @@ def test_stateful_workers(num_workers): assert 
isinstance(dataloader_iter, worker_type) next(data_fetcher_iter) - state = data_fetcher.dataloader_iter.state.state + + reloaded_state = deepcopy(data_fetcher.dataloader_iter.state) + + state = reloaded_state.state assert state[0].dataset_state == {0: {"counter": 1}} assert state[0].sampler_state["sampler"] == {"counter": 1} @@ -1351,4 +1355,6 @@ def test_stateful_workers(num_workers): assert not hasattr(DataLoader, "_ori_get_iterator") assert DataLoader._get_iterator == _get_iterator_fn - data_fetcher.teardown() \ No newline at end of file + _reload_dataloader_state_dict(dataloader, asdict(reloaded_state)) + assert dataloader.sampler.counter == 1 + data_fetcher.teardown() From dbcfa65ca12c8f55f6eeab798c09a881d5fcfce3 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 12:35:03 +0000 Subject: [PATCH 36/42] update changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index fd9cf54a9730e..79c7ce6f6eb76 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) * Add a `_rotate_worker_indices` utility to reload the state according the latest worker ([#10647](https://github.com/PyTorchLightning/pytorch-lightning/issues/10647)) * Add stateful workers ([#10674](https://github.com/PyTorchLightning/pytorch-lightning/issues/10674)) + * Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699)) - From ae18166638bb016e581861e8374be11a3b67e3f0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 23 Nov 2021 12:36:21 +0000 Subject: [PATCH 37/42] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- pytorch_lightning/loops/epoch/evaluation_epoch_loop.py | 2 +- pytorch_lightning/trainer/supporters.py | 2 +- pytorch_lightning/utilities/auto_restart.py | 1 + 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py index 6347882acb6cf..2fc572ea252e6 100644 --- a/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py +++ b/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py @@ -22,7 +22,7 @@ from pytorch_lightning.loops.base import Loop from pytorch_lightning.loops.utilities import _update_dataloader_iter from pytorch_lightning.trainer.progress import BatchProgress -from pytorch_lightning.utilities.auto_restart import MergedIteratorState, _reload_dataloader_state_dict +from pytorch_lightning.utilities.auto_restart import _reload_dataloader_state_dict, MergedIteratorState from pytorch_lightning.utilities.fetching import AbstractDataFetcher, DataFetcher from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT diff --git a/pytorch_lightning/trainer/supporters.py b/pytorch_lightning/trainer/supporters.py index 48ebf82fcbec0..d65bc08e6689e 100644 --- a/pytorch_lightning/trainer/supporters.py +++ b/pytorch_lightning/trainer/supporters.py @@ -24,9 +24,9 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections from pytorch_lightning.utilities.auto_restart import ( + 
_reload_dataloader_state_dict, MergedIteratorState, patch_dataloader_iterator, - _reload_dataloader_state_dict, ) from pytorch_lightning.utilities.data import get_len from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index e3996449041e2..67df8e716815f 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -628,6 +628,7 @@ def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: """This function is used to rotate the worker indices based on the `latest_worker_id` the training failed on.""" From a5279295ee78c1deea093a766321e6e97a34d42d Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 12:37:00 +0000 Subject: [PATCH 38/42] update --- tests/utilities/test_auto_restart.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 44bec66827fc2..1c27d582cc6a5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1355,5 +1355,5 @@ def test_stateful_workers(num_workers): assert DataLoader._get_iterator == _get_iterator_fn _reload_dataloader_state_dict(dataloader, asdict(reloaded_state)) - assert dataloader.sampler.counter == 1 + assert dataloader.sampler.counter == dataloader.dataset.counter == 1 data_fetcher.teardown() From 51cf75b38a8cd8852915a449aa0682126e3cfd6c Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 13:00:44 +0000 Subject: [PATCH 39/42] update --- pytorch_lightning/utilities/auto_restart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 67df8e716815f..87eb1b00740ca 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -33,7 +33,6 @@ import pytorch_lightning as pl from pytorch_lightning.utilities.enums import _FaultTolerantMode, AutoRestartBatchKeys from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training class FastForwardSampler(Sampler): From 35644b8587f2fecbc1cfd5403a93d4b116083d26 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 13:06:00 +0000 Subject: [PATCH 40/42] update on comments --- pytorch_lightning/utilities/auto_restart.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 87eb1b00740ca..35060f79ed92a 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -601,6 +601,9 @@ def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, elif fault_tolerant_mode.is_manual: + # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` + # therefore, we need to reload the states manually. 
+ latest_worker_id = state_dict["latest_worker_id"] num_workers = state_dict["state"][latest_worker_id]["num_workers"] sampler_state = state_dict["state"][latest_worker_id]["sampler_state"] @@ -609,12 +612,12 @@ def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, obj = getattr(dataloader, k) if not isinstance(obj, _SupportsStateDict): raise MisconfigurationException( - f"The DataLoader attribute should have a `load_state_dict` method. Found {obj}" + f"The DataLoader attribute {k}:{obj} should have a `load_state_dict` method." ) obj.load_state_dict(sampler_state[k]) - if not hasattr(dataset, "load_state_dict"): + if not isinstance(dataset, _SupportsStateDict): return dataset_state = { @@ -668,7 +671,6 @@ def _store_sampler_state(self) -> None: for k, v in self._loader.__dict__.items() if isinstance(v, _SupportsStateDict) and k != "dataset" } - self.__accumulate_state(sampler_state) def _next_index(self) -> Any: From fe3afd8c427343a64869b7b9f65f19d82128f6a4 Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 15:25:54 +0000 Subject: [PATCH 41/42] update --- pytorch_lightning/utilities/auto_restart.py | 113 +++++++++++--------- 1 file changed, 63 insertions(+), 50 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 583a4c202f8e0..283f8125ba1f9 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -563,74 +563,87 @@ def _add_capture_metadata_collate(dataloader: DataLoader) -> None: ) -def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: - """Utility to reload state_dict within dataloader for fault tolerance.""" +def _reload_dataloader_state_dict_automatic_map_dataset(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + iterator_state = state_dict["state"][0] - fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() + if not isinstance(iterator_state, IteratorState): + iterator_state = IteratorState.from_state_dict(iterator_state) + + # reload sampler state + ff_sampler = _find_fast_forward_samplers(dataloader) + ff_sampler.load_state_dict(iterator_state.sampler_state) + + # reload dataset state + dataloader.dataset.load_state_dict( + iterator_state.dataset_state, + latest_worker_id=state_dict["latest_worker_id"], + num_workers=iterator_state.num_workers, + ) - if not fault_tolerant_mode.is_enabled: - return +def _reload_dataloader_state_dict_automatic_iterable_dataset( + dataset: CaptureIterableDataset, state_dict: Dict[str, Any] +) -> None: + dataset.load_state_dict( + {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()} + ) + + +def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: dataset = dataloader.dataset + if isinstance(dataset, CaptureMapDataset): + _reload_dataloader_state_dict_automatic_map_dataset(dataloader, state_dict) - if fault_tolerant_mode.is_automatic: - if isinstance(dataset, CaptureMapDataset): - iterator_state = state_dict["state"][0] + elif isinstance(dataset, CaptureIterableDataset): + _reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict) - if not isinstance(iterator_state, IteratorState): - iterator_state = IteratorState.from_state_dict(iterator_state) - # reload sampler state - ff_sampler = _find_fast_forward_samplers(dataloader) - ff_sampler.load_state_dict(iterator_state.sampler_state) +def 
_reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` + # therefore, we need to reload the states manually. - # reload dataset state - dataset.load_state_dict( - iterator_state.dataset_state, - latest_worker_id=state_dict["latest_worker_id"], - num_workers=iterator_state.num_workers, - ) + latest_worker_id = state_dict["latest_worker_id"] + num_workers = state_dict["state"][latest_worker_id]["num_workers"] + sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None) + if sampler_state: + # `sampler_state` keys contain all the DataLoader attribute names + # which matched `_SupportsStateDict` API interface while collecting the `state_dict`. + for dataloader_attr_name in sampler_state: + obj = getattr(dataloader, dataloader_attr_name) + if not isinstance(obj, _SupportsStateDict): + raise MisconfigurationException( + f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method." + ) - elif isinstance(dataset, CaptureIterableDataset): - dataset.load_state_dict( - {sampler_name: state[0]["sampler_state"] for sampler_name, state in state_dict["state"].items()} - ) + obj.load_state_dict(sampler_state[dataloader_attr_name]) - else: - raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + if not isinstance(dataloader.dataset, _SupportsStateDict): + return - elif fault_tolerant_mode.is_manual: + dataset_state = { + worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id] + for worker_id in state_dict["state"].keys() + } - # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` - # therefore, we need to reload the states manually. - - latest_worker_id = state_dict["latest_worker_id"] - num_workers = state_dict["state"][latest_worker_id]["num_workers"] - sampler_state = state_dict["state"][latest_worker_id].get("sampler_state", None) - if sampler_state: - # `sampler_state` keys contain all the DataLoader attribute names - # which matched `_SupportsStateDict` API interface while collecting the `state_dict`. - for dataloader_attr_name in sampler_state: - obj = getattr(dataloader, dataloader_attr_name) - if not isinstance(obj, _SupportsStateDict): - raise MisconfigurationException( - f"The DataLoader attribute {dataloader_attr_name}:{obj} should have a `load_state_dict` method." - ) + dataloader.dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) - obj.load_state_dict(sampler_state[dataloader_attr_name]) - if not isinstance(dataset, _SupportsStateDict): - return +def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: + """Utility to reload state_dict within dataloader for fault tolerance.""" - dataset_state = { - worker_id: state_dict["state"][worker_id]["dataset_state"][worker_id] - for worker_id in state_dict["state"].keys() - } + fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() + + if not fault_tolerant_mode.is_enabled: + return - dataset.load_state_dict(_rotate_worker_indices(dataset_state, latest_worker_id, num_workers)) + if fault_tolerant_mode.is_automatic: + _reload_dataloader_state_dict_automatic(dataloader, state_dict) + + elif fault_tolerant_mode.is_manual: + _reload_dataloader_state_dict_manual(dataloader, state_dict) else: - raise MisconfigurationException("This shouldn't happen. 
Please, open an issue on PyTorch Lightning Github.") + raise MisconfigurationException("This shouldn't be happening.") def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]: From 421d869f81677af9c7b6966151f315006e8eec7e Mon Sep 17 00:00:00 2001 From: tchaton Date: Tue, 23 Nov 2021 15:26:39 +0000 Subject: [PATCH 42/42] update --- pytorch_lightning/utilities/auto_restart.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 283f8125ba1f9..3fa32bc72da5e 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -597,6 +597,9 @@ def _reload_dataloader_state_dict_automatic(dataloader: DataLoader, state_dict: elif isinstance(dataset, CaptureIterableDataset): _reload_dataloader_state_dict_automatic_iterable_dataset(dataset, state_dict) + else: + raise MisconfigurationException("This shouldn't be happening. Please, open an issue.") + def _reload_dataloader_state_dict_manual(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None: # In manual mode, we don't wrap the user objects with `CaptureMapDataset` or `CaptureIterableDataset` @@ -643,7 +646,7 @@ def _reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, _reload_dataloader_state_dict_manual(dataloader, state_dict) else: - raise MisconfigurationException("This shouldn't be happening.") + raise MisconfigurationException("This shouldn't be happening. Please, open an issue.") def _rotate_worker_indices(state: Dict[int, Any], latest_worker_id: int, num_workers: int) -> Dict[int, Any]:
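
The manual fault-tolerant mode wired up across these patches relies on a duck-typed contract: any DataLoader attribute (sampler, batch sampler, dataset, and so on) that exposes both `state_dict()` and `load_state_dict()` is detected through the `_SupportsStateDict` check and has its state captured and later restored by `_reload_dataloader_state_dict_manual`. The sketch below illustrates that contract only; `CounterSampler` and `CounterDataset` are hypothetical stand-ins written for this example, not classes from the patches, and the counter-based state is deliberately minimal.

from typing import Any, Dict, Iterator

from torch.utils.data import Dataset, Sampler


class CounterSampler(Sampler):
    # Hypothetical sampler: it tracks how many indices it has already produced
    # so iteration can be fast-forwarded after a restart.
    def __init__(self, data_source) -> None:
        super().__init__(data_source)
        self.data_source = data_source
        self.counter = 0

    def __iter__(self) -> Iterator[int]:
        for index in range(self.counter, len(self.data_source)):
            self.counter += 1
            yield index

    def __len__(self) -> int:
        return len(self.data_source)

    # These two methods form the duck-typed interface the `_SupportsStateDict` check looks for.
    def state_dict(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]


class CounterDataset(Dataset):
    # Hypothetical map-style dataset exposing the same interface.
    def __init__(self, length: int) -> None:
        self.length = length
        self.counter = 0

    def __getitem__(self, index: int) -> int:
        self.counter += 1
        return index

    def __len__(self) -> int:
        return self.length

    def state_dict(self) -> Dict[str, Any]:
        return {"counter": self.counter}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        self.counter = state_dict["counter"]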
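
Building on the sketch above, the following standalone walk mirrors, in simplified form, the control flow of `_reload_dataloader_state_dict_manual` from patch 41: restore every stateful DataLoader attribute recorded for the worker that produced the latest batch, then restore the dataset. It intentionally uses plain `hasattr` checks instead of the `_SupportsStateDict` protocol and skips the `_rotate_worker_indices` step for multi-worker loaders, so it is an illustration rather than a drop-in replacement, and the captured-state layout shown is a hand-written, single-worker stand-in rather than the exact structure Lightning serializes.

from typing import Any, Dict

from torch.utils.data import DataLoader


def reload_manual_state(dataloader: DataLoader, state_dict: Dict[str, Any]) -> None:
    # Simplified, single-worker version of the manual reload path.
    latest_worker_id = state_dict["latest_worker_id"]
    worker_state = state_dict["state"][latest_worker_id]

    # Restore every DataLoader attribute whose state was captured (e.g. "sampler").
    for attr_name, attr_state in (worker_state.get("sampler_state") or {}).items():
        obj = getattr(dataloader, attr_name)
        if not (hasattr(obj, "state_dict") and hasattr(obj, "load_state_dict")):
            raise TypeError(f"The DataLoader attribute {attr_name}:{obj} should have a `load_state_dict` method.")
        obj.load_state_dict(attr_state)

    # Restore the dataset state if the dataset is stateful as well.
    if hasattr(dataloader.dataset, "load_state_dict"):
        dataloader.dataset.load_state_dict(worker_state["dataset_state"])


# Hand-written stand-in for the state a failed run would have captured:
# three samples had already been produced by worker 0 before the failure.
captured = {
    "latest_worker_id": 0,
    "state": {
        0: {
            "num_workers": 0,
            "sampler_state": {"sampler": {"counter": 3}},
            "dataset_state": {"counter": 3},
        }
    },
}

dataset = CounterDataset(10)
dataloader = DataLoader(dataset, sampler=CounterSampler(dataset), num_workers=0)
reload_manual_state(dataloader, captured)
assert dataloader.sampler.counter == dataloader.dataset.counter == 3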