diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index d6f14fcdb7e538..641f2dedf30cd5 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1333,6 +1333,8 @@ def _setup_devices(self) -> "torch.device":
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
+            dist.init_process_group(backend="smddp")
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
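
Context for the change above: importing smdistributed.dataparallel.torch.torch_smddp registers "smddp" as a torch.distributed backend as a side effect, which is why the import is kept despite being "unused" (hence the noqa: F401). A minimal standalone sketch of the same pattern, assuming it runs inside a SageMaker training job where the smdistributed package and the SMDATAPARALLEL_LOCAL_RANK environment variable are available:

    import os

    import torch
    import torch.distributed as dist

    # Side-effect import: registers the "smddp" backend with torch.distributed.
    import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

    # Initialize the process group against the freshly registered backend.
    dist.init_process_group(backend="smddp")

    # SageMaker exposes the per-node rank of this worker via an env var;
    # use it to pin the process to its own GPU.
    local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
    device = torch.device("cuda", local_rank)

In the patched _setup_devices this replaces the older smdistributed-specific initialization path with the standard torch.distributed API, so downstream code can treat SageMaker data parallelism like any other process-group backend.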