From 2bed140489493deee71f9cf27b7375992e535ebc Mon Sep 17 00:00:00 2001
From: carolynwang
Date: Thu, 21 Jul 2022 16:30:49 -0700
Subject: [PATCH 1/2] add import

---
 src/transformers/training_args.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index d6f14fcdb7e53..9c2c72db42aa8 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1333,6 +1333,8 @@ def _setup_devices(self) -> "torch.device":
             device = torch.device("cuda", local_rank)
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
+            import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
+
             dist.init_process_group(backend="smddp")
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)

From 0e86151d377b1c8e3a623eb8b8062cd1b64016ab Mon Sep 17 00:00:00 2001
From: carolynwang
Date: Mon, 25 Jul 2022 17:23:38 -0700
Subject: [PATCH 2/2] format

---
 src/transformers/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py
index 9c2c72db42aa8..641f2dedf30cd 100644
--- a/src/transformers/training_args.py
+++ b/src/transformers/training_args.py
@@ -1334,7 +1334,7 @@ def _setup_devices(self) -> "torch.device":
             self._n_gpu = 1
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
-
+
             dist.init_process_group(backend="smddp")
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
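
Note on the change: the added import of smdistributed.dataparallel.torch.torch_smddp is there for its side effect only (hence the noqa: F401), registering the "smddp" backend with torch.distributed before init_process_group(backend="smddp") is called. The sketch below is a minimal standalone illustration of that same pattern, not part of the patch; it assumes it runs inside a SageMaker data parallel training job where the SMDDP library is installed and the SMDATAPARALLEL_LOCAL_RANK environment variable is set.

    import os

    import torch
    import torch.distributed as dist

    # Imported for its side effect: registering the "smddp" process-group
    # backend with torch.distributed. Without this import,
    # init_process_group(backend="smddp") cannot resolve the backend.
    import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401

    # Initialize the process group using the SageMaker data parallel backend.
    dist.init_process_group(backend="smddp")

    # SageMaker exports this process's local rank; select the matching GPU,
    # mirroring what _setup_devices does in training_args.py.
    local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
    device = torch.device("cuda", local_rank)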