Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added getstate/setstate method for torch.save serialization #4127

Merged
merged 6 commits into from
Oct 13, 2020
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,20 @@ def init_ddp_connection(
torch_backend, rank=global_rank, world_size=world_size
)

def __getstate__(self):
return {
'trainer': self.trainer,
'nickname': self.nickname,
'cluster_environment': self.cluster_environment,
'dist': self.dist
}

def __setstate__(self, d):
self.trainer = d['trainer']
self.nickname = d['nickname']
self.cluster_environment = d['cluster_environment']
self.dist = d['dist']


# TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
class BackendType(Enum):
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def __init__(self, *args, **kwargs):
# optionally can be set by user
self._example_input_array = None
self._datamodule = None
self._results: Result = None
self._results: Optional[Result] = None
self._current_fx_name = ''

def optimizers(self):
Expand Down
7 changes: 7 additions & 0 deletions pytorch_lightning/core/step_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.metrics import Metric


class Result(Dict):
def __init__(
self,
Expand Down Expand Up @@ -89,6 +90,12 @@ def __setattr__(self, key: str, val: Union[Tensor, Any]):

self[key] = val

def __getstate__(self):
return self

def __setstate__(self, d):
self.update(d)

def _assert_tensor_metric(self, name: str, potential_metric: Union[bool, Tensor, None, Any]):
if potential_metric is not None and not isinstance(potential_metric, bool):
assert isinstance(potential_metric, Tensor), f'{name} must be a torch.Tensor'
Expand Down
59 changes: 59 additions & 0 deletions tests/checkpointing/test_model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from distutils.version import LooseVersion
from unittest.mock import MagicMock, Mock

import yaml
Expand Down Expand Up @@ -512,3 +513,61 @@ def validation_epoch_end(self, outputs):

# check that last one is also the best one
assert trainer.dev_debugger.checkpoint_callback_history[-1]['epoch'] == len(monitor) - 1


def test_model_torch_save(tmpdir):
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
"""Test to ensure torch save does not fail for model and trainer."""
model = EvalModelTemplate()
num_epochs = 4
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=num_epochs,
)
temp_path = os.path.join(tmpdir, 'temp.pt')
trainer.fit(model)

# Ensure these do not fail
torch.save(trainer.model, temp_path)
torch.save(trainer, temp_path)


@pytest.mark.skipif(platform.system() == "Windows",
reason="Distributed training is not supported on Windows")
@pytest.mark.skipif((platform.system() == "Darwin" and
LooseVersion(torch.__version__) < LooseVersion("1.3.0")),
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
reason="Distributed training is not supported on MacOS before Torch 1.3.0")
SeanNaren marked this conversation as resolved.
Show resolved Hide resolved
def test_model_torch_save_ddp_cpu(tmpdir):
"""Test to ensure torch save does not fail for model and trainer using cpu ddp."""
model = EvalModelTemplate()
num_epochs = 4
rohitgr7 marked this conversation as resolved.
Show resolved Hide resolved
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=num_epochs,
distributed_backend="ddp_cpu",
num_processes=2,
)
temp_path = os.path.join(tmpdir, 'temp.pt')
trainer.fit(model)

# Ensure these do not fail
torch.save(trainer.model, temp_path)
torch.save(trainer, temp_path)


@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
def test_model_torch_save_ddp_cuda(tmpdir):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems like this guy is hanging...

"""Test to ensure torch save does not fail for model and trainer using gpu ddp."""
model = EvalModelTemplate()
num_epochs = 4
trainer = Trainer(
default_root_dir=tmpdir,
max_epochs=num_epochs,
distributed_backend="ddp",
gpus=2
)
temp_path = os.path.join(tmpdir, 'temp.pt')
trainer.fit(model)

# Ensure these do not fail
torch.save(trainer.model, temp_path)
torch.save(trainer, temp_path)