Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve error message on TypeError during DataLoader reconstruction #10719

Merged
merged 7 commits into from Nov 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Expand Up @@ -24,7 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10639))


-
- Show a better error message when a custom `DataLoader` implementation is not well implemented and we need to reconstruct it ([#10719](https://github.com/PyTorchLightning/pytorch-lightning/issues/10719))


-
Expand Down
20 changes: 19 additions & 1 deletion pytorch_lightning/utilities/data.py
Expand Up @@ -180,7 +180,25 @@ def get_len(dataloader: DataLoader) -> Union[int, float]:
def _update_dataloader(dataloader: DataLoader, sampler: Sampler, mode: Optional[RunningStage] = None) -> DataLoader:
dl_kwargs = _get_dataloader_init_kwargs(dataloader, sampler, mode=mode)
dl_cls = type(dataloader)
dataloader = dl_cls(**dl_kwargs)
try:
dataloader = dl_cls(**dl_kwargs)
except TypeError as e:
# improve exception message due to an incorrect implementation of the `DataLoader` where multiple subclass
# `__init__` arguments map to one `DataLoader.__init__` argument
import re
awaelchli marked this conversation as resolved.
Show resolved Hide resolved

match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(e))
if not match:
# an unexpected `TypeError`, continue failure
raise
argument = match.groups()[0]
message = (
f"The {dl_cls.__name__} `DataLoader` implementation has an error where more than one `__init__` argument"
f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
)
raise MisconfigurationException(message) from e
return dataloader


Expand Down
33 changes: 33 additions & 0 deletions tests/utilities/test_data.py
Expand Up @@ -5,6 +5,7 @@
from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import (
_replace_dataloader_init_method,
_update_dataloader,
extract_batch_size,
get_len,
has_iterable_dataset,
Expand Down Expand Up @@ -115,6 +116,38 @@ def test_has_len_all_rank():
assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)


def test_update_dataloader_typerror_custom_exception():
class BadImpl(DataLoader):
def __init__(self, foo, *args, **kwargs):
self.foo = foo
# positional conflict with `dataset`
super().__init__(foo, *args, **kwargs)

dataloader = BadImpl([1, 2, 3])
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"):
_update_dataloader(dataloader, dataloader.sampler)

class BadImpl2(DataLoader):
def __init__(self, randomize, *args, **kwargs):
self.randomize = randomize
# keyword conflict with `shuffle`
super().__init__(*args, shuffle=randomize, **kwargs)

dataloader = BadImpl2(False, [])
with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`shuffle`"):
_update_dataloader(dataloader, dataloader.sampler)

class GoodImpl(DataLoader):
def __init__(self, randomize, *args, **kwargs):
# fixed implementation, kwargs are filtered
self.randomize = randomize or kwargs.pop("shuffle", False)
super().__init__(*args, shuffle=randomize, **kwargs)

dataloader = GoodImpl(False, [])
new_dataloader = _update_dataloader(dataloader, dataloader.sampler)
assert isinstance(new_dataloader, GoodImpl)


def test_replace_dataloader_init_method():
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
sets them as attributes."""
Expand Down