From 24c82451a64850d746accadb88d789a06d4148c9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:40:30 +0000 Subject: [PATCH 1/8] update --- pytorch_lightning/utilities/auto_restart.py | 13 +++++++++ tests/utilities/test_auto_restart.py | 30 +++++++++++++++++++++ 2 files changed, 43 insertions(+) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index ef52717636d90..11593c779c260 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import inspect from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps @@ -571,3 +572,15 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A else: raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") + + +def is_obj_stateful(obj: Any) -> bool: + """In order to be stateful, an object should implement a ``state_dict`` and ``load_state_dict`` method.""" + load_state_dict_fn = getattr(obj, "load_state_dict", None) + if not isinstance(load_state_dict_fn, Callable): + return False + params = inspect.signature(load_state_dict_fn).parameters + if len(params) == 0: + return False + + return isinstance(getattr(obj, "state_dict", None), Callable) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index b36a9d1d76941..6717d8f09d4b5 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -42,6 +42,7 @@ CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, + is_obj_stateful, MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -1192,3 +1193,32 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" not in state_dict else: assert "dataloader_state_dict" in state_dict + + +def test_is_obj_stateful(): + class StatefulClass: + def state_dict(self): + pass + + def load_state_dict(self, state_dict): + pass + + obj = StatefulClass() + assert is_obj_stateful(obj) + + class NotStatefulClass: + def state_dict(self): + pass + + def load_state_dict(self): + pass + + obj = NotStatefulClass() + assert not is_obj_stateful(obj) + + class NotStatefulClass: + def load_state_dict(self): + pass + + obj = NotStatefulClass() + assert not is_obj_stateful(obj) From bcd5569ef5e404e69430f089f613ac16c50eb6a9 Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:42:28 +0000 Subject: [PATCH 2/8] update --- CHANGELOG.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 96830878d84ae..30c16a4d67637 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,7 +12,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added a flag `SLURMEnvironment(auto_requeue=True|False)` to control whether Lightning handles the requeuing ([#10601](https://github.com/PyTorchLightning/pytorch-lightning/issues/10601)) -- +- Fault Tolerant Manual + * Add `is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) - From 8d3844dc782a62613956fb6fe177b68872fe9a7f Mon Sep 17 00:00:00 2001 From: tchaton Date: Fri, 19 Nov 2021 18:44:03 +0000 Subject: [PATCH 3/8] update --- CHANGELOG.md | 2 +- pytorch_lightning/utilities/auto_restart.py | 2 +- tests/utilities/test_auto_restart.py | 10 +++++----- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30c16a4d67637..4f59e996365a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual - * Add `is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) + * Add `_is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) - diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 11593c779c260..0310aceb4bc49 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -574,7 +574,7 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def is_obj_stateful(obj: Any) -> bool: +def _is_obj_stateful(obj: Any) -> bool: """In order to be stateful, an object should implement a ``state_dict`` and ``load_state_dict`` method.""" load_state_dict_fn = getattr(obj, "load_state_dict", None) if not isinstance(load_state_dict_fn, Callable): diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 6717d8f09d4b5..7d75b7605ce0a 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -39,10 +39,10 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, + _is_obj_stateful, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, - is_obj_stateful, MergedIteratorState, ) from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -1195,7 +1195,7 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict -def test_is_obj_stateful(): +def test__is_obj_stateful(): class StatefulClass: def state_dict(self): pass @@ -1204,7 +1204,7 @@ def load_state_dict(self, state_dict): pass obj = StatefulClass() - assert is_obj_stateful(obj) + assert _is_obj_stateful(obj) class NotStatefulClass: def state_dict(self): @@ -1214,11 +1214,11 @@ def load_state_dict(self): pass obj = NotStatefulClass() - assert not is_obj_stateful(obj) + assert not _is_obj_stateful(obj) class NotStatefulClass: def load_state_dict(self): pass obj = NotStatefulClass() - assert not is_obj_stateful(obj) + assert not _is_obj_stateful(obj) From d04596df734a9c10d68f6ce8a301f9ab783574aa Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Nov 2021 17:13:20 +0100 Subject: [PATCH 4/8] Use `Protocol` --- pytorch_lightning/utilities/auto_restart.py | 18 +++++++----------- tests/utilities/test_auto_restart.py | 13 +++++-------- 2 files changed, 12 insertions(+), 19 deletions(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 0310aceb4bc49..4b0ccddace55c 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -11,8 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import inspect from copy import deepcopy from dataclasses import dataclass, field from functools import partial, wraps @@ -24,6 +22,7 @@ import torch from torch.utils.data import Dataset, get_worker_info, Sampler from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset +from typing_extensions import Protocol, runtime_checkable import pytorch_lightning as pl from pytorch_lightning.utilities.enums import AutoRestartBatchKeys @@ -574,13 +573,10 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.") -def _is_obj_stateful(obj: Any) -> bool: - """In order to be stateful, an object should implement a ``state_dict`` and ``load_state_dict`` method.""" - load_state_dict_fn = getattr(obj, "load_state_dict", None) - if not isinstance(load_state_dict_fn, Callable): - return False - params = inspect.signature(load_state_dict_fn).parameters - if len(params) == 0: - return False +@runtime_checkable +class _SupportsStateDict(Protocol): + def state_dict(self) -> Dict[str, Any]: + ... - return isinstance(getattr(obj, "state_dict", None), Callable) + def load_state_dict(self, state_dict: Dict[str, Any]) -> None: + ... diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 0c35ba651cdb6..6392249300961 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -40,7 +40,7 @@ _add_capture_metadata_collate, _dataloader_load_state_dict, _dataloader_to_state_dict, - _is_obj_stateful, + _SupportsStateDict, CaptureIterableDataset, CaptureMapDataset, FastForwardSampler, @@ -1196,7 +1196,7 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on assert "dataloader_state_dict" in state_dict -def test_is_obj_stateful(): +def test_supports_state_dict_protocol(): class StatefulClass: def state_dict(self): pass @@ -1205,24 +1205,21 @@ def load_state_dict(self, state_dict): pass obj = StatefulClass() - assert _is_obj_stateful(obj) + assert isinstance(obj, _SupportsStateDict) class NotStatefulClass: def state_dict(self): pass - def load_state_dict(self): - pass - obj = NotStatefulClass() - assert not _is_obj_stateful(obj) + assert not isinstance(obj, _SupportsStateDict) class NotStateful2Class: def load_state_dict(self, state_dict): pass obj = NotStateful2Class() - assert not _is_obj_stateful(obj) + assert not isinstance(obj, _SupportsStateDict) def test_fault_tolerant_mode_enum(): From ff7b8367941e4b493c19c6f5d9fcc845241cb955 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Mon, 22 Nov 2021 17:20:45 +0100 Subject: [PATCH 5/8] Simplify test --- tests/utilities/test_auto_restart.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/tests/utilities/test_auto_restart.py b/tests/utilities/test_auto_restart.py index 6392249300961..5152874b39469 100644 --- a/tests/utilities/test_auto_restart.py +++ b/tests/utilities/test_auto_restart.py @@ -1204,22 +1204,19 @@ def state_dict(self): def load_state_dict(self, state_dict): pass - obj = StatefulClass() - assert isinstance(obj, _SupportsStateDict) + assert isinstance(StatefulClass(), _SupportsStateDict) class NotStatefulClass: def state_dict(self): pass - obj = NotStatefulClass() - assert not isinstance(obj, _SupportsStateDict) + assert not isinstance(NotStatefulClass(), _SupportsStateDict) class NotStateful2Class: def load_state_dict(self, state_dict): pass - obj = NotStateful2Class() - assert not isinstance(obj, _SupportsStateDict) + assert not isinstance(NotStateful2Class(), _SupportsStateDict) def test_fault_tolerant_mode_enum(): From a5698e655020b8ed0698fb765cb1c7490da3a00a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 22 Nov 2021 17:21:14 +0100 Subject: [PATCH 6/8] Update CHANGELOG.md --- CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c19f254007628..169308718e569 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,7 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fault Tolerant Manual - * Add `_is_obj_stateful` utility to detect if user data loading components are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) + * Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646)) * Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645)) From d0aae6e45a01d2292b23cb4fdd90dd0f99279ff7 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 22 Nov 2021 17:58:24 +0000 Subject: [PATCH 7/8] update --- pytorch_lightning/utilities/auto_restart.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index 4b0ccddace55c..d06d5f0a02c05 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -575,6 +575,9 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A @runtime_checkable class _SupportsStateDict(Protocol): + + """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" + def state_dict(self) -> Dict[str, Any]: ... From 529c2719c8546ee8ac058782b871eb7c1d675e8c Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Mon, 22 Nov 2021 18:19:49 +0000 Subject: [PATCH 8/8] Update pytorch_lightning/utilities/auto_restart.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- pytorch_lightning/utilities/auto_restart.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_lightning/utilities/auto_restart.py b/pytorch_lightning/utilities/auto_restart.py index d06d5f0a02c05..23583852f4f39 100644 --- a/pytorch_lightning/utilities/auto_restart.py +++ b/pytorch_lightning/utilities/auto_restart.py @@ -575,7 +575,6 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A @runtime_checkable class _SupportsStateDict(Protocol): - """This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`.""" def state_dict(self) -> Dict[str, Any]: