Added getstate/setstate method for torch.save serialization (#4127)
* Added __getstate__/__setstate__ methods for torch.save serialization; added Optional typing to the results object

* Added tests to ensure torch.save does not fail

* Added flags to ensure a compatible ddp_cpu environment

* Removed the torch version check, since the minimum supported version is already 1.3; reduced epochs for speed

* Moved tests to separate file

* Updated the accelerator and moved tests to ddp_spawn to prevent ddp from hanging
SeanNaren committed Oct 13, 2020
1 parent 01402e3 commit 98eb736
Showing 5 changed files with 99 additions and 1 deletion.
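
For orientation, a minimal sketch (not part of the diff) of why these hooks matter: torch.save pickles the entire object graph, so every attribute reachable from the saved object must be picklable. The Holder class below is hypothetical.

import torch

class Holder:
    """Hypothetical container; stands in for any object handed to torch.save."""
    def __init__(self):
        self.weights = torch.zeros(3)

holder = Holder()
torch.save(holder, 'holder.pt')     # serializes the whole object graph via pickle
restored = torch.load('holder.pt')  # rebuilds it, calling __setstate__ where defined
assert torch.equal(holder.weights, restored.weights)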
14 changes: 14 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
@@ -214,6 +214,20 @@ def init_ddp_connection(
            torch_backend, rank=global_rank, world_size=world_size
        )

    def __getstate__(self):
        return {
            'trainer': self.trainer,
            'nickname': self.nickname,
            'cluster_environment': self.cluster_environment,
            'dist': self.dist
        }

    def __setstate__(self, d):
        self.trainer = d['trainer']
        self.nickname = d['nickname']
        self.cluster_environment = d['cluster_environment']
        self.dist = d['dist']


# TODO: allow user to compare with string; even internally we shall use these Enums to prevent typos...
class BackendType(Enum):
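For context, a minimal sketch of the pattern these methods implement under the standard pickle protocol (the Connection class is hypothetical): __getstate__ selects the picklable subset of an object's state, and __setstate__ restores it onto the fresh instance.

import pickle

class Connection:
    """Hypothetical stand-in for an object holding unpicklable state."""
    def __init__(self):
        self.address = 'localhost:1234'
        self._socket = None  # imagine an open socket here

    def __getstate__(self):
        # return only the picklable subset, mirroring Accelerator.__getstate__
        return {'address': self.address}

    def __setstate__(self, d):
        self.address = d['address']
        self._socket = None  # re-created lazily after unpickling

restored = pickle.loads(pickle.dumps(Connection()))
assert restored.address == 'localhost:1234'
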
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -107,7 +107,7 @@ def __init__(self, *args, **kwargs):
        # optionally can be set by user
        self._example_input_array = None
        self._datamodule = None
-       self._results: Result = None
+       self._results: Optional[Result] = None
        self._current_fx_name = ''

    def optimizers(self):
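A short aside on the annotation change above: `self._results: Result = None` assigns a None default to an attribute annotated as always being a Result, which strict type checkers reject; Optional[Result] (i.e. Union[Result, None]) permits it. A minimal sketch with a hypothetical class:

from typing import Optional

class Config:
    def __init__(self):
        # Optional[int] is Union[int, None]: the None default now type-checks
        self.max_steps: Optional[int] = None
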
7 changes: 7 additions & 0 deletions pytorch_lightning/core/step_result.py
@@ -23,6 +23,7 @@
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.metrics import Metric


class Result(Dict):
    def __init__(
        self,
@@ -89,6 +90,12 @@ def __setattr__(self, key: str, val: Union[Tensor, Any]):

        self[key] = val

    def __getstate__(self):
        return self

    def __setstate__(self, d):
        self.update(d)

    def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
        if potential_metric is not None and not isinstance(potential_metric, bool):
            assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'
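Because Result subclasses dict (and stores its values in the mapping itself), its picklable state is the mapping: __getstate__ returns self and __setstate__ copies the restored items back in. A minimal sketch of the same pattern on a plain dict subclass (AttrDict is hypothetical):

import pickle

class AttrDict(dict):
    """Hypothetical dict subclass, analogous in spirit to Result."""
    def __getstate__(self):
        return self        # the mapping itself is the state

    def __setstate__(self, d):
        self.update(d)     # copy restored items into the fresh instance

restored = pickle.loads(pickle.dumps(AttrDict(loss=0.5)))
assert restored['loss'] == 0.5
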
1 change: 1 addition & 0 deletions tests/checkpointing/test_model_checkpoint.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from distutils.version import LooseVersion
from unittest.mock import MagicMock, Mock

import yaml
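The LooseVersion import added here supports version comparisons elsewhere in this test module; a typical guard looks like the sketch below (illustrative usage, not part of the diff):

from distutils.version import LooseVersion
import torch

if LooseVersion(torch.__version__) < LooseVersion('1.5.0'):
    print('running on an older torch build')
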
76 changes: 76 additions & 0 deletions tests/checkpointing/test_torch_saving.py
@@ -0,0 +1,76 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import platform

import pytest
import torch

from pytorch_lightning import Trainer
from tests.base import EvalModelTemplate


def test_model_torch_save(tmpdir):
    """Test to ensure torch save does not fail for model and trainer."""
    model = EvalModelTemplate()
    num_epochs = 1
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=num_epochs,
    )
    temp_path = os.path.join(tmpdir, 'temp.pt')
    trainer.fit(model)

    # Ensure these do not fail
    torch.save(trainer.model, temp_path)
    torch.save(trainer, temp_path)


@pytest.mark.skipif(platform.system() == "Windows",
                    reason="Distributed training is not supported on Windows")
def test_model_torch_save_ddp_cpu(tmpdir):
    """Test to ensure torch save does not fail for model and trainer using cpu ddp."""
    model = EvalModelTemplate()
    num_epochs = 1
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=num_epochs,
        accelerator="ddp_cpu",
        num_processes=2,
    )
    temp_path = os.path.join(tmpdir, 'temp.pt')
    trainer.fit(model)

    # Ensure these do not fail
    torch.save(trainer.model, temp_path)
    torch.save(trainer, temp_path)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_model_torch_save_ddp_cuda(tmpdir):
    """Test to ensure torch save does not fail for model and trainer using gpu ddp."""
    model = EvalModelTemplate()
    num_epochs = 1
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=num_epochs,
        accelerator="ddp_spawn",
        gpus=2
    )
    temp_path = os.path.join(tmpdir, 'temp.pt')
    trainer.fit(model)

    # Ensure these do not fail
    torch.save(trainer.model, temp_path)
    torch.save(trainer, temp_path)
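
A usage sketch to round out the tests (assuming the temp_path, num_epochs, and fitted trainer from above): loading the saved trainer back exercises the new hooks end-to-end, since unpickling succeeds only if the accelerator and results objects restore cleanly.

restored_trainer = torch.load(temp_path)  # unpickles via the new __setstate__ hooks
assert restored_trainer.max_epochs == num_epochs  # configuration survives the round trip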
