[DataLoader] Fix the world_size when distributed sharding MapDataPipe

Authored by ejguan, committed by pytorchmergebot on Jun 14, 2022
Commit 04f87f2, 1 parent: e479dae
Showing 1 changed file with 1 addition and 0 deletions: torch/utils/data/dataloader.py
@@ -256,6 +256,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 torch.utils.data.graph_settings.apply_sharding(self.dataset, ws, rank)
         elif isinstance(self.dataset, MapDataPipe):
             self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
+            ws, rank = _get_distributed_settings()
             if num_workers > 0:
                 self.worker_init_fn = functools.partial(
                     _sharding_worker_init_fn, self.worker_init_fn, ws, rank)
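The one-line addition matters because the MapDataPipe branch hands ws and rank to _sharding_worker_init_fn; before this change the branch used whatever ws and rank happened to be in scope rather than fetching the current distributed settings at that point, so multi-process sharding could run with the wrong world size. Calling _get_distributed_settings() right before the functools.partial binding mirrors the IterDataPipe branch visible above. For context, here is a minimal sketch of the two helpers involved, assuming the standard torch.distributed query APIs; the real definitions live in torch/utils/data/dataloader.py and may differ in detail:

import torch.distributed as dist
import torch.utils.data
import torch.utils.data.graph_settings


def _get_distributed_settings():
    # Assumed shape of the helper: report (world_size, rank) when a
    # distributed process group is initialized; otherwise fall back to
    # the single-process defaults (1, 0).
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size(), dist.get_rank()
    return 1, 0


def _sharding_worker_init_fn(worker_init_fn, world_size, rank_id, worker_id):
    # Assumed shape of the helper: treat every DataLoader worker on every
    # distributed rank as one shard, so the pipeline is split into
    # world_size * num_workers pieces and this worker takes exactly one.
    info = torch.utils.data.get_worker_info()
    total_workers = info.num_workers * world_size
    global_worker_id = rank_id * info.num_workers + worker_id
    torch.utils.data.graph_settings.apply_sharding(
        info.dataset, total_workers, global_worker_id)
    if worker_init_fn is not None:
        worker_init_fn(worker_id)

With this in place, the functools.partial(...) call in the diff binds everything except worker_id, which the DataLoader supplies as each worker process starts, so every worker on every rank lands on a distinct shard.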
