diff --git a/torch/utils/data/dataloader.py b/torch/utils/data/dataloader.py index 692ebadf66d8..766adff324d4 100644 --- a/torch/utils/data/dataloader.py +++ b/torch/utils/data/dataloader.py @@ -256,6 +256,7 @@ def __init__(self, dataset: Dataset[T_co], batch_size: Optional[int] = 1, torch.utils.data.graph_settings.apply_sharding(self.dataset, ws, rank) elif isinstance(self.dataset, MapDataPipe): self.dataset = _MapDataPipeSerializationWrapper(self.dataset) + ws, rank = _get_distributed_settings() if num_workers > 0: self.worker_init_fn = functools.partial( _sharding_worker_init_fn, self.worker_init_fn, ws, rank)