Skip to content

Commit

Permalink
save initial arguments (#4163)
Browse files Browse the repository at this point in the history
* save initial arguments

* typing

* chlog

* .
  • Loading branch information
Borda committed Oct 15, 2020
1 parent 4290c9e commit f064682
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `hparams` saving - save the state when `save_hyperparameters()` is called [in `__init__`] ([#4163](https://github.com/PyTorchLightning/pytorch-lightning/pull/4163))



## [1.0.1] - 2020-10-14
Expand Down
14 changes: 13 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.

import collections
import copy
import inspect
import os
import re
Expand Down Expand Up @@ -1448,9 +1449,11 @@ def save_hyperparameters(self, *args, frame=None) -> None:
init_args = get_init_args(frame)
assert init_args, "failed to inspect the self init"
if not args:
# take all arguments
hp = init_args
self._hparams_name = "kwargs" if hp else None
else:
# take only listed arguments in `save_hparams`
isx_non_str = [i for i, arg in enumerate(args) if not isinstance(arg, str)]
if len(isx_non_str) == 1:
hp = args[isx_non_str[0]]
Expand All @@ -1463,6 +1466,8 @@ def save_hyperparameters(self, *args, frame=None) -> None:
# `hparams` are expected here
if hp:
self._set_hparams(hp)
# make deep copy so there is not other runtime changes reflected
self._hparams_initial = copy.deepcopy(self._hparams)

def _set_hparams(self, hp: Union[dict, Namespace, str]) -> None:
if isinstance(hp, Namespace):
Expand Down Expand Up @@ -1594,11 +1599,18 @@ def to_torchscript(
return torchscript_module

@property
def hparams(self) -> Union[AttributeDict, str]:
def hparams(self) -> Union[AttributeDict, dict, Namespace]:
if not hasattr(self, "_hparams"):
self._hparams = AttributeDict()
return self._hparams

@property
def hparams_initial(self) -> AttributeDict:
if not hasattr(self, "_hparams_initial"):
self._hparams_initial = AttributeDict()
# prevent any change
return copy.deepcopy(self._hparams_initial)

@hparams.setter
def hparams(self, hp: Union[dict, Namespace, Any]):
hparams_assignment_name = self.__get_hparams_assignment_variable()
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/training_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def setup_training(self, model: LightningModule):
# log hyper-parameters
if self.trainer.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
self.trainer.logger.log_hyperparams(ref_model.hparams)
self.trainer.logger.log_hyperparams(ref_model.hparams_initial)
self.trainer.logger.log_graph(ref_model)
self.trainer.logger.save()

Expand Down
27 changes: 26 additions & 1 deletion tests/models/test_hparams.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from pytorch_lightning import Trainer, LightningModule
from pytorch_lightning.core.saving import save_hparams_to_yaml, load_hparams_from_yaml
from pytorch_lightning.utilities import AttributeDict, is_picklable
from tests.base import EvalModelTemplate, TrialMNIST
from tests.base import EvalModelTemplate, TrialMNIST, BoringModel


class SaveHparamsModel(EvalModelTemplate):
Expand Down Expand Up @@ -554,3 +554,28 @@ def test_args(tmpdir):
with pytest.raises(TypeError, match="__init__\(\) got an unexpected keyword argument 'test'"):
SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)


class RuntimeParamChangeModel(BoringModel):
def __init__(self, running_arg):
super().__init__()
self.save_hyperparameters()


def test_init_arg_with_runtime_change(tmpdir):
model = RuntimeParamChangeModel(123)
assert model.hparams.running_arg == 123
model.hparams.running_arg = -1
assert model.hparams.running_arg == -1

trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=2,
limit_val_batches=2,
limit_test_batches=2,
max_epochs=1,
)
trainer.fit(model)

path_yaml = os.path.join(trainer.logger.log_dir, trainer.logger.NAME_HPARAMS_FILE)
hparams = load_hparams_from_yaml(path_yaml)
assert hparams.get('running_arg') == 123

0 comments on commit f064682

Please sign in to comment.