From 04f87f2ab9328675d0c9b5b790df700371776a94 Mon Sep 17 00:00:00 2001
From: erjia
Date: Tue, 14 Jun 2022 19:03:55 +0000
Subject: [PATCH] [DataLoader] Fix the world_size when distributed sharding
 MapDataPipe (#79524)

Fixes #79449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79524
Approved by: https://github.com/NivekT, https://github.com/VitalyFedyunin
---
 torch/utils/data/dataloader.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py
index 692ebadf66d8..766adff324d4 100644
--- a/torch/utils/data/dataloader.py
+++ b/torch/utils/data/dataloader.py
@@ -256,6 +256,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1,
                 torch.utils.data.graph_settings.apply_sharding(self.dataset, ws, rank)
         elif isinstance(self.dataset, MapDataPipe):
             self.dataset = _MapDataPipeSerializationWrapper(self.dataset)
+            ws, rank = _get_distributed_settings()
             if num_workers > 0:
                 self.worker_init_fn = functools.partial(
                     _sharding_worker_init_fn, self.worker_init_fn, ws, rank)
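
Note: for context, below is a minimal sketch of the helper the added line calls. It mirrors the behavior of _get_distributed_settings as defined in the same file (torch/utils/data/dataloader.py); treat it as illustrative rather than the exact implementation:

    import torch.distributed as dist

    def _get_distributed_settings():
        # Shard across all ranks only when a distributed process group
        # has been initialized; otherwise treat this process as the
        # single, unsharded replica (world_size=1, rank=0).
        if dist.is_available() and dist.is_initialized():
            return dist.get_world_size(), dist.get_rank()
        return 1, 0

With the added line, the MapDataPipe branch queries these settings itself, so _sharding_worker_init_fn receives the actual distributed world size and rank instead of relying on ws and rank assigned only in the sibling IterDataPipe branch, which never runs when the dataset is a MapDataPipe.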