Resuming StreamingDataloader with num_workers=0 fails #24

Open
tchaton opened this issue Feb 26, 2024 · 1 comment
Labels
bug Something isn't working

Comments


tchaton commented Feb 26, 2024

Bug description

Using a StreamingDataloader with num_workers=0 works, but resuming its state does not: an explicit state validation check fails because the restored `num_workers` value does not match the current one.

Using num_workers=0 may not be very meaningful for real applications, but it is useful for debugging and testing. Alternatively, if it is difficult to support, StreamingDataloader could simply force num_workers >= 1. Either way, we should do something about it: 0 is the default for the dataloader, so users might forget to set it and then run into this error, which could be confusing.
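One possible direction (a sketch only, with hypothetical helper names, not the library's actual implementation) would be to normalize `num_workers=0` to `1` before comparing it against the saved state, since a dataloader with zero workers still iterates the dataset in a single process:

```python
# Hypothetical validation helper; the real check lives in the streaming
# dataset's _validate_state_dict and may differ.
def _effective_num_workers(num_workers: int) -> int:
    # num_workers=0 means "iterate in the main process", which behaves
    # like a single worker for the purpose of resuming state.
    return max(num_workers, 1)


def check_num_workers(saved: int, current: int) -> None:
    if _effective_num_workers(saved) != _effective_num_workers(current):
        raise ValueError(
            "The provided `num_workers` state doesn't match the current one. "
            f"Found `{saved}` instead of `{current}`."
        )
```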

What version are you seeing the problem on?

master

How to reproduce the bug

import torch


def run():
    checkpoint_path = "checkpoint.pt"

    # Save a checkpoint
    train_dataloader = create_dataloader()
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)
    torch.save(train_dataloader.state_dict(), checkpoint_path)

    # Reset and attempt resume
    train_dataloader = create_dataloader()
    state = {"train_dataloader": train_dataloader}
    train_dataloader.load_state_dict(torch.load(checkpoint_path))
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)


def create_dataloader():
    from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
    from lightning.data.streaming.item_loader import TokensLoader

    train_datasets = [
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/slimpajama/train",
            item_loader=TokensLoader(block_size=4),
        ),
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=4),
        ),
    ]
    combined_dataset = CombinedStreamingDataset(datasets=train_datasets)
    train_dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=0)  # <--- BUG WHEN NUM WORKERS=0
    return train_dataloader


if __name__ == "__main__":
    run()

Error messages and logs

Traceback (most recent call last):
  File "/teamspace/studios/this_studio/repro_worker.py", line 50, in <module>
    run()
  File "/teamspace/studios/this_studio/repro_worker.py", line 25, in run
    next(train_iterator)
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 432, in __iter__
    for batch in super().__iter__():
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
    return self._get_iterator()
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 504, in _get_iterator
    return _SingleProcessDataLoaderIter(self)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 669, in __init__
    self._dataset_fetcher = _DatasetKind.create_fetcher(
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
    return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
  File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
    self.dataset_iter = iter(dataset)
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 83, in __iter__
    self._iterator = _CombinedDatasetIterator(
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in __init__
    self._dataset_iters = [iter(dataset) for dataset in datasets]
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in <listcomp>
    self._dataset_iters = [iter(dataset) for dataset in datasets]
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 146, in __iter__
    self._validate_state_dict()
  File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 328, in _validate_state_dict
    raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `1` instead of `0`.

Environment

Current environment
- PyTorch Lightning Version: master (2.2dev)

More info

No response

Moved from Lightning-AI/pytorch-lightning#19335, submitted by @awaelchli

@tchaton tchaton added enhancement New feature or request help wanted Extra attention is needed labels Feb 26, 2024
@awaelchli awaelchli added bug Something isn't working and removed enhancement New feature or request help wanted Extra attention is needed labels Feb 28, 2024
@shreyanssethi commented:

Following up on this, what is the recommended practice/solution? I was able to load my checkpoint, manually change num_workers from 0 to 1, and save it again so that it passes the check when loading the state_dict, but I wanted to know if there is a better workaround.
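For reference, the manual patch described above looks roughly like the sketch below. The exact layout of the StreamingDataLoader state dict is an assumption here; inspect the loaded dict to find where `num_workers` actually lives in your version.

```python
import torch

# Load the previously saved dataloader state.
state = torch.load("checkpoint.pt")


def patch_num_workers(obj, new_value=1):
    # Walk the (possibly nested) state dict and overwrite every
    # "num_workers" entry so it matches the dataloader it will be
    # restored into. Key names/nesting are assumptions, not guarantees.
    if isinstance(obj, dict):
        for key, value in obj.items():
            if key == "num_workers":
                obj[key] = new_value
            else:
                patch_num_workers(value, new_value)
    elif isinstance(obj, (list, tuple)):
        for item in obj:
            patch_num_workers(item, new_value)


patch_num_workers(state, new_value=1)
torch.save(state, "checkpoint.pt")
```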
