Bugfix/4156 filter hparams for yaml - fsspec (#4158)
* add test

* fix

* sleepy boy

* chlog

* Apply suggestions from code review

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>

Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
Borda and justusschock committed Oct 15, 2020
1 parent 72f1976 commit 4204ef7
Showing 3 changed files with 35 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -23,6 +23,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `hparams` saving - save the state when `save_hyperparameters()` is called [in `__init__`] ([#4163](https://github.com/PyTorchLightning/pytorch-lightning/pull/4163))

+- Fixed runtime failure while exporting `hparams` to yaml ([#4158](https://github.com/PyTorchLightning/pytorch-lightning/pull/4158))


## [1.0.1] - 2020-10-14
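
For context, a minimal sketch (not part of this commit) of the failure the changelog entry above refers to, assuming, as the patch below does, that `yaml.dump` raises a `TypeError` for objects it cannot serialize, such as an fsspec filesystem stored in hparams:

# Illustrative only: this is the kind of error that previously aborted the whole
# hparams export; the patch below catches exactly this TypeError per entry.
import yaml
from fsspec.implementations.local import LocalFileSystem

hparams = {"learning_rate": 0.01, "fs": LocalFileSystem()}
try:
    yaml.dump(hparams)
except TypeError as err:
    print(f"hparams export would have failed: {err}")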
16 changes: 14 additions & 2 deletions pytorch_lightning/core/saving.py
@@ -18,6 +18,7 @@
import os
from argparse import Namespace
from typing import Union, Dict, Any, Optional, Callable, MutableMapping
+from warnings import warn

import fsspec
import torch
@@ -372,10 +373,21 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
        OmegaConf.save(OmegaConf.create(hparams), fp, resolve=True)
        return

-    # saving the standard way
    assert isinstance(hparams, dict)
+    hparams_allowed = {}
+    # drop parameters which contain non-serializable datatypes, such as an fsspec filesystem
+    for k, v in hparams.items():
+        try:
+            yaml.dump(v)
+        except TypeError as err:
+            warn(f"Skipping '{k}' parameter because it is not possible to safely dump to YAML.")
+            hparams[k] = type(v).__name__
+        else:
+            hparams_allowed[k] = v
+
+    # saving the standard way
    with fs.open(config_yaml, "w", newline="") as fp:
-        yaml.dump(hparams, fp)
+        yaml.dump(hparams_allowed, fp)


def convert(val: str) -> Union[int, float, bool, str]:
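With the patch applied, `save_hparams_to_yaml` no longer aborts when an entry cannot be serialized: the offending entry is skipped with a warning and the remaining, YAML-safe entries are still written. A hedged usage sketch (the target path is illustrative, not from this commit):

# Usage sketch, assuming a writable /tmp; only the YAML-safe entries end up in the file.
from fsspec.implementations.local import LocalFileSystem

from pytorch_lightning.core.saving import load_hparams_from_yaml, save_hparams_to_yaml

hparams = {"learning_rate": 0.01, "fs": LocalFileSystem()}
save_hparams_to_yaml("/tmp/hparams.yaml", hparams)  # warns about 'fs', writes the rest
restored = load_hparams_from_yaml("/tmp/hparams.yaml")
assert restored.get("learning_rate") == 0.01  # YAML-safe entries survive the round trip
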
20 changes: 20 additions & 0 deletions tests/models/test_hparams.py
@@ -18,6 +18,7 @@
import cloudpickle
import pytest
import torch
+from fsspec.implementations.local import LocalFileSystem
from omegaconf import OmegaConf, Container
from torch.nn import functional as F
from torch.utils.data import DataLoader
@@ -579,3 +580,22 @@ def test_init_arg_with_runtime_change(tmpdir):
    path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(path_yaml)
    assert hparams.get('running_arg') == 123
+
+
+class UnsafeParamModel(BoringModel):
+    def __init__(self, my_path, any_param=123):
+        super().__init__()
+        self.save_hyperparameters()
+
+
+def test_model_with_fsspec_as_parameter(tmpdir):
+    model = UnsafeParamModel(LocalFileSystem(tmpdir))
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        limit_test_batches=2,
+        max_epochs=1,
+    )
+    trainer.fit(model)
+    trainer.test()

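The new test only verifies that `fit` and `test` run to completion. A hypothetical follow-up check (not part of this commit) could be appended to `test_model_with_fsspec_as_parameter`, reusing the same `NAME_HPARAMS_FILE` / `load_hparams_from_yaml` pattern as `test_init_arg_with_runtime_change` above, to inspect the exported YAML directly:

# Hypothetical extra assertions, not in this commit: the YAML-safe hyperparameter
# should survive the export, while the fsspec filesystem should have been filtered out.
path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
hparams = load_hparams_from_yaml(path_yaml)
assert hparams.get('any_param') == 123
assert 'my_path' not in hparams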