save initial arguments #4163

Merged: 4 commits, merged on Oct 15, 2020
Changes from all commits
CHANGELOG.md (2 additions & 0 deletions)
@@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `hparams` saving: the hyperparameter state is now captured at the moment `save_hyperparameters()` is called in `__init__` ([#4163](https://github.com/PyTorchLightning/pytorch-lightning/pull/4163))



## [1.0.1] - 2020-10-14
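For context, a minimal sketch of the behavior described by the entry above, assuming a `pytorch_lightning` build that includes this PR (the `LitModel` class and its `learning_rate` argument are made up for illustration): `save_hyperparameters()` now also snapshots the `__init__` arguments, exposed as `hparams_initial`, so later changes to `hparams` do not alter that snapshot.

```python
import pytorch_lightning as pl


class LitModel(pl.LightningModule):  # hypothetical example module, not part of this PR
    def __init__(self, learning_rate=0.01):
        super().__init__()
        # the hyperparameter state is captured at this call
        self.save_hyperparameters()


model = LitModel(learning_rate=0.05)
model.hparams.learning_rate = 0.1           # runtime change after __init__
print(model.hparams.learning_rate)          # 0.1  (current value)
print(model.hparams_initial.learning_rate)  # 0.05 (value captured in __init__)
```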
pytorch_lightning/core/lightning.py (13 additions & 1 deletion)
@@ -13,6 +13,7 @@
# limitations under the License.

import collections
import copy
import inspect
import os
import re
@@ -1448,9 +1449,11 @@ def save_hyperparameters(self, *args, frame=None) -> None:
        init_args = get_init_args(frame)
        assert init_args, "failed to inspect the self init"
        if not args:
            # take all arguments
            hp = init_args
            self._hparams_name = "kwargs" if hp else None
        else:
            # take only the arguments listed in `save_hyperparameters`
            isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
            if len(isx_non_str) == 1:
                hp = args[isx_non_str[0]]
@@ -1463,6 +1466,8 @@ def save_hyperparameters(self, *args, frame=None) -> None:
        # `hparams` are expected here
        if hp:
            self._set_hparams(hp)
        # make a deep copy so that later runtime changes are not reflected
        self._hparams_initial = copy.deepcopy(self._hparams)

    def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
        if isinstance(hp, Namespace):
@@ -1594,11 +1599,18 @@ def to_torchscript(
        return torchscript_module

    @property
-    def hparams(self) -> Union[AttributeDict, str]:
+    def hparams(self) -> Union[AttributeDict, dict, Namespace]:
        if not hasattr(self, "_hparams"):
            self._hparams = AttributeDict()
        return self._hparams

    @property
    def hparams_initial(self) -> AttributeDict:
        if not hasattr(self, "_hparams_initial"):
            self._hparams_initial = AttributeDict()
        # prevent any change
        return copy.deepcopy(self._hparams_initial)

    @hparams.setter
    def hparams(self, hp: Union[dict, Namespace, Any]):
        hparams_assignment_name = self.__get_hparams_assignment_variable()
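A short sketch of the copy semantics introduced above, again assuming a build with this PR (the `LitConfigModel` class is made up for illustration): `_hparams_initial` is stored as a deep copy, and the `hparams_initial` getter returns another deep copy, so neither mutating `hparams` nor mutating the returned object can change the stored snapshot.

```python
from pytorch_lightning import LightningModule


class LitConfigModel(LightningModule):  # hypothetical example module, not part of this PR
    def __init__(self, layer_sizes):
        super().__init__()
        self.save_hyperparameters()


model = LitConfigModel(layer_sizes=[64, 32])

# mutating a nested value through `hparams` does not touch the stored snapshot
model.hparams.layer_sizes.append(16)
assert model.hparams.layer_sizes == [64, 32, 16]
assert model.hparams_initial.layer_sizes == [64, 32]

# the getter returns a fresh copy as well, so callers cannot corrupt the snapshot either
snapshot = model.hparams_initial
snapshot.layer_sizes.append(99)
assert model.hparams_initial.layer_sizes == [64, 32]
```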
pytorch_lightning/trainer/training_loop.py (1 addition & 1 deletion)
@@ -129,7 +129,7 @@ def setup_training(self, model: LightningModule):
        # log hyper-parameters
        if self.trainer.logger is not None:
            # save exp to get started (this is where the first experiment logs are written)
-            self.trainer.logger.log_hyperparams(ref_model.hparams)
+            self.trainer.logger.log_hyperparams(ref_model.hparams_initial)
            self.trainer.logger.log_graph(ref_model)
            self.trainer.logger.save()

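With this one-line change, any logger attached to the `Trainer` records `hparams_initial`, i.e. the hyperparameters as captured in `__init__`, rather than whatever `hparams` holds once training starts.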
tests/models/test_hparams.py (26 additions & 1 deletion)
@@ -25,7 +25,7 @@
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict, is_picklable
-from tests.base import EvalModelTemplate, TrialMNIST
+from tests.base import EvalModelTemplate, TrialMNIST, BoringModel


class SaveHparamsModel(EvalModelTemplate):
@@ -554,3 +554,28 @@ def test_args(tmpdir):
    with pytest.raises(TypeError, match="__init__\(\) got an unexpected keyword argument 'test'"):
        SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)


class RuntimeParamChangeModel(BoringModel):
    def __init__(self, running_arg):
        super().__init__()
        self.save_hyperparameters()


def test_init_arg_with_runtime_change(tmpdir):
    model = RuntimeParamChangeModel(123)
    assert model.hparams.running_arg == 123
    model.hparams.running_arg = -1
    assert model.hparams.running_arg == -1

    trainer = Trainer(
        default_root_dir=tmpdir,
        limit_train_batches=2,
        limit_val_batches=2,
        limit_test_batches=2,
        max_epochs=1,
    )
    trainer.fit(model)

    path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
    hparams = load_hparams_from_yaml(path_yaml)
    assert hparams.get('running_arg') == 123
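The new test covers the full path: it changes `hparams` at runtime, runs a short `fit`, and then checks that the hyperparameters YAML written by the logger still contains the original `__init__` value. It can be run on its own with `pytest tests/models/test_hparams.py::test_init_arg_with_runtime_change` from the repository root.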