Skip to content

Commit

Permalink
fix hparams assign in init (#4189)
Browse files Browse the repository at this point in the history
  • Loading branch information
Borda committed Oct 16, 2020
1 parent 130de22 commit 3fe479f
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
5 changes: 4 additions & 1 deletion pytorch_lightning/core/lightning.py
Expand Up @@ -1608,7 +1608,7 @@ def hparams(self) -> Union[AttributeDict, dict, Namespace]:
@property
def hparams_initial(self) -> AttributeDict:
if not hasattr(self, "_hparams_initial"):
self._hparams_initial = AttributeDict()
return AttributeDict()
# prevent any change
return copy.deepcopy(self._hparams_initial)

Expand All @@ -1617,6 +1617,9 @@ def hparams(self, hp: Union[dict, Namespace, Any]):
hparams_assignment_name = self.__get_hparams_assignment_variable()
self._hparams_name = hparams_assignment_name
self._set_hparams(hp)
# this resolves case when user does not uses `save_hyperparameters` and do hard assignement in init
if not hasattr(self, "_hparams_initial"):
self._hparams_initial = copy.deepcopy(self._hparams)

def __get_hparams_assignment_variable(self):
""""""
Expand Down
18 changes: 14 additions & 4 deletions tests/models/test_hparams.py
Expand Up @@ -556,17 +556,27 @@ def test_args(tmpdir):
SubClassVarArgs.load_from_checkpoint(raw_checkpoint_path)


class RuntimeParamChangeModel(BoringModel):
def __init__(self, running_arg):
class RuntimeParamChangeModelSaving(BoringModel):
def __init__(self, **kwargs):
super().__init__()
self.save_hyperparameters()


def test_init_arg_with_runtime_change(tmpdir):
model = RuntimeParamChangeModel(123)
class RuntimeParamChangeModelAssign(BoringModel):
def __init__(self, **kwargs):
super().__init__()
self.hparams = kwargs


@pytest.mark.parametrize("cls", [RuntimeParamChangeModelSaving, RuntimeParamChangeModelAssign])
def test_init_arg_with_runtime_change(tmpdir, cls):
"""Test that we save/export only the initial hparams, no other runtime change allowed"""
model = cls(running_arg=123)
assert model.hparams.running_arg == 123
model.hparams.running_arg = -1
assert model.hparams.running_arg == -1
model.hparams = Namespace(abc=42)
assert model.hparams.abc == 42

trainer = Trainer(
default_root_dir=tmpdir,
Expand Down

0 comments on commit 3fe479f

Please sign in to comment.