Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fault Tolerant Manual: Add _SupportsStateDict to validate a class is stateful #10646

Merged
merged 9 commits into from Nov 22, 2021
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Expand Up @@ -13,6 +13,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).


- Fault Tolerant Manual
* Add `_SupportsStateDict` protocol to detect if classes are stateful ([#10646](https://github.com/PyTorchLightning/pytorch-lightning/issues/10646))
* Add `_FaultTolerantMode` enum used to track different supported fault tolerant modes ([#10645](https://github.com/PyTorchLightning/pytorch-lightning/issues/10645))


Expand Down
13 changes: 13 additions & 0 deletions pytorch_lightning/utilities/auto_restart.py
Expand Up @@ -22,6 +22,7 @@
import torch
from torch.utils.data import Dataset, get_worker_info, Sampler
from torch.utils.data.dataloader import _MultiProcessingDataLoaderIter, DataLoader, IterableDataset
from typing_extensions import Protocol, runtime_checkable

import pytorch_lightning as pl
from pytorch_lightning.utilities.enums import AutoRestartBatchKeys
Expand Down Expand Up @@ -570,3 +571,15 @@ def reload_dataloader_state_dict(dataloader: DataLoader, state_dict: Dict[str, A

else:
raise MisconfigurationException("This shouldn't happen. Please, open an issue on PyTorch Lightning Github.")


@runtime_checkable
class _SupportsStateDict(Protocol):

tchaton marked this conversation as resolved.
Show resolved Hide resolved
"""This class is used to detect if an object is stateful using `isinstance(obj, _SupportsStateDict)`."""

def state_dict(self) -> Dict[str, Any]:
...
tchaton marked this conversation as resolved.
Show resolved Hide resolved

def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
...
24 changes: 24 additions & 0 deletions tests/utilities/test_auto_restart.py
Expand Up @@ -40,6 +40,7 @@
_add_capture_metadata_collate,
_dataloader_load_state_dict,
_dataloader_to_state_dict,
_SupportsStateDict,
CaptureIterableDataset,
CaptureMapDataset,
FastForwardSampler,
Expand Down Expand Up @@ -1195,6 +1196,29 @@ def test_auto_restart_under_signal(on_last_batch, val_check_interval, failure_on
assert "dataloader_state_dict" in state_dict


def test_supports_state_dict_protocol():
tchaton marked this conversation as resolved.
Show resolved Hide resolved
class StatefulClass:
def state_dict(self):
pass

def load_state_dict(self, state_dict):
pass

assert isinstance(StatefulClass(), _SupportsStateDict)

class NotStatefulClass:
def state_dict(self):
pass

assert not isinstance(NotStatefulClass(), _SupportsStateDict)

class NotStateful2Class:
def load_state_dict(self, state_dict):
pass

assert not isinstance(NotStateful2Class(), _SupportsStateDict)


def test_fault_tolerant_mode_enum():
with mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "0"}):
assert _FaultTolerantMode.DISABLED == _FaultTolerantMode.detect_current_mode()
Expand Down