diff --git a/CHANGELOG.md b/CHANGELOG.md index e678be0e965fe..169308718e569 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual + * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 228e16e4e9c8c..23583852f4f39 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -22,6 +22,7 @@ import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -570,3 +571,14 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + + +@runtime_checkable +class _SupportsStateDict(Protocol): + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" + + def state_dict(self) -> Dict[str, Any]: + ... + + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b9eb97cb42ae8..5152874b39469 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,6 +40,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _SupportsStateDict, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1195,6 +1196,29 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict +def test_supports_state_dict_protocol(): + class StatefulClass: + def state_dict(self): + pass + + def load_state_dict(self, state_dict): + pass + + assert isinstance(StatefulClass(), _SupportsStateDict) + + class NotStatefulClass: + def state_dict(self): + pass + + assert not isinstance(NotStatefulClass(), _SupportsStateDict) + + class NotStateful2Class: + def load_state_dict(self, state_dict): + pass + + assert not isinstance(NotStateful2Class(), _SupportsStateDict) + + def test_fault_tolerant_mode_enum(): with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}): assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()