Using a StreamingDataLoader with num_workers=0 works, but resuming its state does not: an explicit length check on the state fails.
Using num_workers=0 is perhaps not very meaningful for real applications, but it is useful for debugging and testing. Alternatively, if supporting it is difficult, StreamingDataLoader could simply enforce num_workers>=1. Either way, we should do something about it: 0 is the default for the dataloader, so users who forget to set it will run into this error, which could confuse them.
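If enforcing num_workers>=1 is the chosen fix, a minimal sketch of such a guard could look like the following (this is an illustration only; the helper name `normalize_num_workers` and its placement are assumptions, not part of the actual StreamingDataLoader implementation):

```python
def normalize_num_workers(num_workers: int) -> int:
    """Clamp num_workers so that 0 (the DataLoader default) behaves like 1.

    Rather than failing the state check on resume, the loader could
    silently treat num_workers=0 as a single worker process.
    """
    if num_workers < 0:
        raise ValueError(f"num_workers must be >= 0, got {num_workers}")
    return max(num_workers, 1)
```

Applied in the StreamingDataLoader constructor, this would make the saved and restored `num_workers` values consistent regardless of whether the user passed 0 or 1.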
What version are you seeing the problem on?
master
How to reproduce the bug
import torch


def run():
    checkpoint_path = "checkpoint.pt"

    # Save a checkpoint
    train_dataloader = create_dataloader()
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)
    torch.save(train_dataloader.state_dict(), checkpoint_path)

    # Reset and attempt resume
    train_dataloader = create_dataloader()
    state = {"train_dataloader": train_dataloader}
    train_dataloader.load_state_dict(torch.load(checkpoint_path))
    train_iterator = iter(train_dataloader)
    next(train_iterator)
    next(train_iterator)


def create_dataloader():
    from lightning.data import StreamingDataset, CombinedStreamingDataset, StreamingDataLoader
    from lightning.data.streaming.item_loader import TokensLoader

    train_datasets = [
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/slimpajama/train",
            item_loader=TokensLoader(block_size=4),
        ),
        StreamingDataset(
            input_dir="/teamspace/s3_connections/tinyllama-template/starcoder",
            item_loader=TokensLoader(block_size=4),
        ),
    ]
    combined_dataset = CombinedStreamingDataset(datasets=train_datasets)
    train_dataloader = StreamingDataLoader(combined_dataset, batch_size=4, num_workers=0)  # <--- BUG WHEN num_workers=0
    return train_dataloader


if __name__ == "__main__":
    run()
Error messages and logs
Traceback (most recent call last):
File "/teamspace/studios/this_studio/repro_worker.py", line 50, in <module>
run()
File "/teamspace/studios/this_studio/repro_worker.py", line 25, in run
next(train_iterator)
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 432, in __iter__
for batch in super().__iter__():
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 438, in __iter__
return self._get_iterator()
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataloader.py", line 504, in _get_iterator
return _SingleProcessDataLoaderIter(self)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 669, in __init__
self._dataset_fetcher = _DatasetKind.create_fetcher(
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 79, in create_fetcher
return _utils.fetch._IterableDatasetFetcher(dataset, auto_collation, collate_fn, drop_last)
File "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 21, in __init__
self.dataset_iter = iter(dataset)
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 83, in __iter__
self._iterator = _CombinedDatasetIterator(
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in __init__
self._dataset_iters = [iter(dataset) for dataset in datasets]
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/combined.py", line 126, in <listcomp>
self._dataset_iters = [iter(dataset) for dataset in datasets]
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 146, in __iter__
self._validate_state_dict()
File "/teamspace/studios/this_studio/lightning/src/lightning/data/streaming/dataset.py", line 328, in _validate_state_dict
raise ValueError(
ValueError: The provided `num_workers` state doesn't match the current one. Found `1` instead of `0`.
Environment
Current environment
#- Lightning Component (e.g. Trainer, LightningModule, LightningApp, LightningWork, LightningFlow):
#- PyTorch Lightning Version (e.g., 1.5.0): master (2.2dev)
#- Lightning App Version (e.g., 0.5.2):
#- PyTorch Version (e.g., 2.0):
#- Python version (e.g., 3.9):
#- OS (e.g., Linux):
#- CUDA/cuDNN version:
#- GPU models and configuration:
#- How you installed Lightning (`conda`, `pip`, source):
#- Running environment of LightningApp (e.g. local, cloud):
Following up on this, what is the recommended practice/solution? I was able to load my checkpoint, manually adjust it from num_workers=0 to 1, and save it again so that it passes the check when the state_dict is loaded, but I wanted to know if there's a better workaround.
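The manual adjustment described above can be scripted. Below is a minimal sketch of a helper that rewrites every `num_workers` entry in a (possibly nested) state dict; the exact layout of the StreamingDataLoader state dict is not documented here, so the function walks dicts and lists generically rather than assuming specific keys:

```python
def set_num_workers(state, value):
    """Return a copy of a nested state structure with every
    'num_workers' entry rewritten to `value`.

    Works on arbitrarily nested dicts and lists; all other
    values are returned unchanged.
    """
    if isinstance(state, dict):
        return {
            k: value if k == "num_workers" else set_num_workers(v, value)
            for k, v in state.items()
        }
    if isinstance(state, list):
        return [set_num_workers(v, value) for v in state]
    return state
```

Usage would be something like `torch.save(set_num_workers(torch.load("checkpoint.pt"), 1), "checkpoint.pt")` before calling `load_state_dict`. This only sidesteps the validation check; whether the restored sample positions are then interpreted correctly is a separate question.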
Moved from Lightning-AI/pytorch-lightning#19335, submitted by @awaelchli