Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix frozen dataclass instance error in apply_to_collection #10927

Merged
merged 13 commits into from Jan 5, 2022
3 changes: 3 additions & 0 deletions CHANGELOG.md
Expand Up @@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


- Show a better error message when frozen dataclass is used as a batch ([#10927](https://github.com/PyTorchLightning/pytorch-lightning/issues/10927))


- Save the `Loop`'s state by default in the checkpoint ([#10784](https://github.com/PyTorchLightning/pytorch-lightning/issues/10784))


Expand Down
9 changes: 8 additions & 1 deletion pytorch_lightning/utilities/apply_func.py
Expand Up @@ -23,6 +23,7 @@
import numpy as np
import torch

from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _compare_version, _TORCHTEXT_LEGACY
from pytorch_lightning.utilities.warnings import rank_zero_deprecation

Expand Down Expand Up @@ -147,7 +148,13 @@ def apply_to_collection(
)
if not field_init or (not include_none and v is None): # retain old value
v = getattr(data, field_name)
setattr(result, field_name, v)
try:
setattr(result, field_name, v)
except dataclasses.FrozenInstanceError as e:
raise MisconfigurationException(
"A frozen dataclass was passed to `apply_to_collection` but this is not allowed."
" HINT: is your batch a frozen dataclass?"
) from e
return result

# data is neither of dtype, nor a collection
Expand Down
12 changes: 12 additions & 0 deletions tests/utilities/test_apply_func.py
Expand Up @@ -22,6 +22,7 @@
import torch

from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections, move_data_to_device
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def test_recursive_application_to_collection():
Expand Down Expand Up @@ -302,6 +303,17 @@ def fn(a, b):
assert reduced is None


def test_apply_to_collection_frozen_dataclass():
@dataclasses.dataclass(frozen=True)
class Foo:
input: torch.Tensor

foo = Foo(torch.tensor(0))

with pytest.raises(MisconfigurationException, match="frozen dataclass was passed"):
apply_to_collection(foo, torch.Tensor, lambda t: t.to(torch.int))


@pytest.mark.parametrize("should_return", [False, True])
def test_wrongly_implemented_transferable_data_type(should_return):
class TensorObject:
Expand Down