From 858683a7d13eba1ff773f20fd780b78a4a3bc06b Mon Sep 17 00:00:00 2001
From: Carolyn Wang <32006339+carolynwang@users.noreply.github.com>
Date: Tue, 26 Jul 2022 13:00:24 -0700
Subject: [PATCH] patch for smddp import (#18244)

* add import

* format
---
 src/transformers/training_args.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index d6f14fcdb7e53..641f2dedf30cd 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1333,6 +1333,8 @@ def _setup_devices(self) -> "torch.device":
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
+
             dist.init_process_group(backend="smddp")
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
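
---

Note: the import is load-bearing. Importing smdistributed.dataparallel.torch.torch_smddp
has the side effect of registering "smddp" as a process-group backend with
torch.distributed, so without it the existing dist.init_process_group(backend="smddp")
call fails with an unknown-backend error. Below is a minimal standalone sketch of the
patched code path, assuming the smdistributed library is installed and the SageMaker
launcher has set the SMDATAPARALLEL_LOCAL_RANK environment variable:

    import os

    import torch
    import torch.distributed as dist

    # Side-effect import: registers the "smddp" backend with torch.distributed.
    # (Assumes a SageMaker data-parallel environment with smdistributed installed.)
    import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

    dist.init_process_group(backend="smddp")
    local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
    device = torch.device("cuda", local_rank)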