Support re-instantiation for custom DataLoader in Lightning (#10680)
Co-authored-by: Justus Schock <12886177+justusschock@users.noreply.github.com>
awaelchli and justusschock committed Nov 24, 2021
1 parent e51a8ee commit 30ec481
Showing 8 changed files with 107 additions and 103 deletions.
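For orientation, here is a minimal sketch of what this change enables (all class names below are hypothetical illustrations, not part of the diff): a custom `DataLoader` subclass whose constructor takes extra arguments can now be returned from a `*_dataloader()` hook, because Lightning captures the constructor arguments and can re-instantiate the loader, e.g. to replace its sampler for distributed training.

```python
import torch
from torch.utils.data import DataLoader, Dataset
import pytorch_lightning as pl

class ToyDataset(Dataset):  # hypothetical toy dataset
    def __init__(self, size, length):
        self.data = torch.randn(length, size)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

class LoaderWithExtras(DataLoader):  # hypothetical custom subclass
    def __init__(self, num_features, dataset, *args, **kwargs):
        # `num_features` is deliberately not stored as an attribute
        super().__init__(dataset, *args, **kwargs)

class MyModel(pl.LightningModule):  # hypothetical minimal module
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def train_dataloader(self):
        # Before this commit, sampler replacement failed for subclasses like
        # this one with "missing attributes are ['num_features']", because the
        # argument could not be recovered after construction. The hook is now
        # called under _replace_dataloader_init_method(), which captures it.
        return LoaderWithExtras(32, ToyDataset(32, 64), batch_size=4)
```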
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
* Add logic to reload the states across data loading components ([#10699](https://github.com/PyTorchLightning/pytorch-lightning/issues/10699))
* Cleanup some fault tolerant utilities ([#10703](https://github.com/PyTorchLightning/pytorch-lightning/issues/10703))

-
- Added support for re-instantiation of custom (subclasses of) `DataLoaders` returned in the `*_dataloader()` methods, i.e., automatic replacement of samplers now works with custom types of `DataLoader` ([#10680](https://github.com/PyTorchLightning/pytorch-lightning/issues/10680))


-
14 changes: 7 additions & 7 deletions pytorch_lightning/lite/lite.py
@@ -25,17 +25,17 @@
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler

from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_method,
)
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.plugins import DDPSpawnPlugin, DeepSpeedPlugin, PLUGIN_INPUT, TPUSpawnPlugin, TrainingTypePlugin
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
from pytorch_lightning.utilities import _StrategyType, DeviceType, move_data_to_device
from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, _update_dataloader, has_iterable_dataset
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_update_dataloader,
has_iterable_dataset,
)
from pytorch_lightning.utilities.device_parser import _parse_devices
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.seed import seed_everything
49 changes: 1 addition & 48 deletions pytorch_lightning/lite/wrappers.py
@@ -11,11 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
from contextlib import contextmanager
from itertools import chain
from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union
from typing import Any, Callable, Generator, Iterator, Optional, Union

import torch
from torch import nn as nn
@@ -110,49 +106,6 @@ def _convert_float_tensor(t: Tensor) -> Tensor:
return output


def _wrap_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
params = dict(inspect.signature(obj.__init__).parameters)
params.pop("args", None)
params.pop("kwargs", None)
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
setattr(obj, arg_name, arg_value)
init(obj, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
"""Returns a list of all classes that inherit directly or indirectly from the given class."""
subclasses = set()

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclasses.add(subclass)
recurse(subclass)

recurse(cls)
return subclasses


@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
for subclass in _get_all_subclasses(DataLoader):
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
yield
for subclass in _get_all_subclasses(DataLoader):
subclass.__init__ = subclass._old_init
del subclass._old_init


class _LiteDataLoader:
def __init__(self, dataloader: DataLoader, device: Optional[torch.device] = None) -> None:
"""The LiteDataLoader is a wrapper for the :class:`~torch.utils.data.DataLoader`. It moves the data to the
6 changes: 5 additions & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -31,6 +31,7 @@
from pytorch_lightning.utilities.auto_restart import _add_capture_metadata_collate
from pytorch_lightning.utilities.data import (
_auto_add_worker_init_fn,
_replace_dataloader_init_method,
_update_dataloader,
has_iterable_dataset,
has_len_all_ranks,
@@ -430,7 +431,10 @@ def request_dataloader(

hook = f"{stage.dataloader_prefix}_dataloader"
self.call_hook("on_" + hook, pl_module=model)
dataloader = source.dataloader()
with _replace_dataloader_init_method():
# under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
# attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
dataloader = source.dataloader()
if isinstance(dataloader, tuple):
dataloader = list(dataloader)
self.training_type_plugin.barrier("get_dataloaders")
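To make the wrapped hook call above concrete, here is a small standalone sketch (the `LoaderWithMeta` class and its argument are hypothetical): entering the same context manager that `request_dataloader` now uses reproduces the capture behavior outside the Trainer.

```python
from torch.utils.data import DataLoader
from pytorch_lightning.utilities.data import _replace_dataloader_init_method

class LoaderWithMeta(DataLoader):  # hypothetical subclass
    def __init__(self, meta, *args, **kwargs):
        # `meta` is not stored here; without the patch it would be lost
        super().__init__(*args, **kwargs)

# `request_dataloader` wraps the `*_dataloader()` hook call in this manager:
with _replace_dataloader_init_method():
    loader = LoaderWithMeta("some-tag", dataset=range(8), batch_size=2)

# The patched `__init__` recorded every constructor argument on the instance:
assert loader.meta == "some-tag"
assert loader.batch_size == 2
```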
49 changes: 48 additions & 1 deletion pytorch_lightning/utilities/data.py
@@ -11,10 +11,13 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import inspect
import os
from contextlib import contextmanager
from functools import partial
from typing import Any, Dict, Generator, Iterable, Mapping, Optional, Union
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Type, Union

import torch
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
@@ -299,6 +302,50 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)


def _wrap_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
params = dict(inspect.signature(obj.__init__).parameters)
params.pop("args", None)
params.pop("kwargs", None)
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
setattr(obj, arg_name, arg_value)
init(obj, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
"""Returns a list of all classes that inherit directly or indirectly from the given class."""
subclasses = set()

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclasses.add(subclass)
recurse(subclass)

recurse(cls)
return subclasses


@contextmanager
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
subclasses = _get_all_subclasses(DataLoader)
for subclass in subclasses:
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
yield
for subclass in subclasses:
subclass.__init__ = subclass._old_init
del subclass._old_init


def _apply_fault_tolerant_automatic_capture_dataset_wrapper(dl_kwargs: Dict) -> Dict:
dataset = dl_kwargs["dataset"]
if isinstance(dataset, IterableDataset):
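And a hedged sketch of why the capture matters downstream (the manual rebuild below is a simplified stand-in for what Lightning's `_update_dataloader` machinery does; `IndexedLoader` is hypothetical): with the constructor arguments available as attributes, the loader can be re-created with a different sampler.

```python
from torch.utils.data import DataLoader, SequentialSampler
from pytorch_lightning.utilities.data import _replace_dataloader_init_method

class IndexedLoader(DataLoader):  # hypothetical subclass
    def __init__(self, num_features, dataset, *args, **kwargs):
        super().__init__(dataset, *args, **kwargs)

with _replace_dataloader_init_method():
    loader = IndexedLoader(32, list(range(64)), batch_size=4)

# `num_features` was captured even though the subclass never stored it, so the
# loader can be rebuilt with a new sampler; a manual, simplified equivalent of
# what automatic sampler replacement does:
rebuilt = IndexedLoader(
    loader.num_features,
    loader.dataset,
    batch_size=loader.batch_size,
    sampler=SequentialSampler(loader.dataset),
)
assert isinstance(rebuilt.sampler, SequentialSampler)
```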
34 changes: 9 additions & 25 deletions tests/lite/test_lite.py
@@ -164,32 +164,16 @@ def test_setup_dataloaders_return_type():
assert lite_dataloader1.dataset is dataset1


def test_setup_dataloaders_with_custom_type():
"""Test that Lite intercepts arguments passed to custom subclasses of torch.utils.DataLoader and sets them as
attributes."""

class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
super().__init__(*args, **kwargs)

class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute1, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute1, *args, **kwargs)

class LiteWithCustomDataLoader(LightningLite):
@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
"""Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""

class Lite(LightningLite):
def run(self):
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute2 == "attribute2"
lite_dataloader = self.setup_dataloaders(dataloader)
assert lite_dataloader.attribute1 == "attribute1"
assert lite_dataloader.attribute2 == "attribute2"

LiteWithCustomDataLoader().run()
ctx_manager().__enter__.assert_called_once()

Lite().run()
ctx_manager().__exit__.assert_called_once()


def test_setup_dataloaders_raises_for_unknown_custom_args():
29 changes: 9 additions & 20 deletions tests/trainer/test_data_loading.py
@@ -28,19 +28,16 @@


@RunIf(skip_windows=True)
@pytest.mark.parametrize("mode", (1, 2, 3))
@pytest.mark.parametrize("mode", (1, 2))
def test_replace_distributed_sampler(tmpdir, mode):
class IndexedRandomDataset(RandomDataset):
def __getitem__(self, index):
return self.data[index]

class CustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
self.num_features = num_features
super().__init__(dataset, *args, **kwargs)

class FailureCustomDataLoader(DataLoader):
def __init__(self, num_features, dataset, *args, **kwargs):
# argument `num_features` unused on purpose
# it gets automatically captured by _replace_dataloader_init_method()
super().__init__(dataset, *args, **kwargs)

class CustomBatchSampler(BatchSampler):
@@ -59,11 +56,11 @@ def on_test_start(self) -> None:
dataloader = self.trainer.test_dataloaders[0]
assert isinstance(dataloader, CustomDataLoader)
batch_sampler = dataloader.batch_sampler
if self._mode == 2:
if self._mode == 1:
assert isinstance(batch_sampler, CustomBatchSampler)
# the batch_size is set on the batch sampler
assert dataloader.batch_size is None
elif self._mode == 3:
elif self._mode == 2:
assert type(batch_sampler) is BatchSampler
assert dataloader.batch_size == self._mode
assert batch_sampler.batch_size == self._mode
@@ -74,15 +71,12 @@ def on_test_start(self) -> None:
def create_dataset(self):
dataset = IndexedRandomDataset(32, 64)
if self._mode == 1:
# this case will raise an error
return FailureCustomDataLoader(32, dataset)
if self._mode == 2:
# with a custom batch sampler
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=2, drop_last=True)
batch_sampler = CustomBatchSampler(SequentialSampler(dataset), batch_size=1, drop_last=True)
return CustomDataLoader(32, dataset, batch_sampler=batch_sampler)
elif self._mode == 3:
elif self._mode == 2:
# with no batch sampler provided
return CustomDataLoader(32, dataset, batch_size=3, drop_last=True)
return CustomDataLoader(32, dataset, batch_size=2, drop_last=True)

def test_dataloader(self):
return [self.create_dataset()] * self._numbers_test_dataloaders
@@ -93,12 +87,7 @@ def test_dataloader(self):
trainer = Trainer(
default_root_dir=tmpdir, limit_test_batches=2, strategy="ddp_find_unused_parameters_false", num_processes=1
)
if mode == 1:
match = escape("missing attributes are ['num_features']")
with pytest.raises(MisconfigurationException, match=match):
trainer.test(model)
else:
trainer.test(model)
trainer.test(model)


class TestSpawnBoringModel(BoringModel):
27 changes: 27 additions & 0 deletions tests/utilities/test_data.py
@@ -4,6 +4,7 @@

from pytorch_lightning import Trainer
from pytorch_lightning.utilities.data import (
_replace_dataloader_init_method,
extract_batch_size,
get_len,
has_iterable_dataset,
@@ -112,3 +113,29 @@ def test_has_len_all_rank():
assert not has_len_all_ranks(DataLoader(RandomDataset(0, 0)), trainer.training_type_plugin, model)

assert has_len_all_ranks(DataLoader(RandomDataset(1, 1)), trainer.training_type_plugin, model)


def test_replace_dataloader_init_method():
"""Test that context manager intercepts arguments passed to custom subclasses of torch.utils.DataLoader and
sets them as attributes."""

class DataLoaderSubclass1(DataLoader):
def __init__(self, attribute1, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute1 = attribute1
super().__init__(*args, **kwargs)

class DataLoaderSubclass2(DataLoaderSubclass1):
def __init__(self, attribute1, attribute2, *args, **kwargs):
# intentionally not setting this attribute, calling super with different args
# self.attribute2 = attribute2
super().__init__(attribute1, *args, **kwargs)

with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass1("attribute1", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"

with _replace_dataloader_init_method():
dataloader = DataLoaderSubclass2("attribute1", "attribute2", dataset=range(4), batch_size=2)
assert dataloader.attribute1 == "attribute1"
assert dataloader.attribute2 == "attribute2"
