Adds timeout argument to training_args to avoid socket timeouts in DDP (huggingface#18562)

* chore(training_args): Adds support for timeout argument.

* fix(training_args): Passes make style through changes.

* fix(training_args): Removes wrong docstring sentence.

* fix(training_args): Fixes timeout not being JSON serializable.

* fix(training_args_sm): Also updates timeout to timeout_delta.

* fix(training_args): Fixes PR according to suggestions.
gugarosa authored and oneraghavan committed Sep 26, 2022
1 parent f435cb3 commit 36fead3
Showing 2 changed files with 26 additions and 5 deletions.
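
For context on what the change enables: once this commit is in, the timeout can be raised from user code through `TrainingArguments`. The snippet below is a minimal, hypothetical usage sketch (the output directory and the value 7200 are placeholders, not part of this commit):

    # Hypothetical usage sketch: raise the DDP timeout for a job whose rank-0
    # preprocessing step can exceed the default 1800-second socket timeout.
    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="./outputs",  # placeholder path
        ddp_timeout=7200,        # seconds; wrapped in a timedelta before init_process_group
    )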
4 changes: 2 additions & 2 deletions src/transformers/sagemaker/training_args_sm.py
@@ -92,7 +92,7 @@ def _setup_devices(self) -> "torch.device":
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
 
-            torch.distributed.init_process_group(backend="smddp")
+            torch.distributed.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -111,7 +111,7 @@ def _setup_devices(self) -> "torch.device":
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1

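The `ddp_timeout_delta` value used above is a property defined on the base `TrainingArguments` class (second file below), which wraps the integer `ddp_timeout` in a `datetime.timedelta`. Outside the Trainer, the same pattern would look roughly like this standalone sketch (not code from this commit):

    # Standalone sketch of the pattern: torch.distributed.init_process_group
    # expects its timeout as a timedelta, so the integer seconds value is wrapped.
    from datetime import timedelta
    import torch.distributed as dist

    ddp_timeout = 1800  # seconds, the same default used by TrainingArguments
    dist.init_process_group(backend="nccl", timeout=timedelta(seconds=ddp_timeout))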
27 changes: 24 additions & 3 deletions src/transformers/training_args.py
@@ -18,6 +18,7 @@
 import os
 import warnings
 from dataclasses import asdict, dataclass, field
+from datetime import timedelta
 from enum import Enum
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
@@ -481,6 +482,11 @@ class TrainingArguments:
             are also available. See the [Ray documentation](
             https://docs.ray.io/en/latest/tune/api_docs/analysis.html#ray.tune.ExperimentAnalysis.get_best_trial) for
             more options.
+        ddp_timeout (`int`, *optional*, defaults to 1800):
+            The timeout for `torch.distributed.init_process_group` calls, used to avoid GPU socket timeouts when
+            performing slow operations in distributed runs. Please refer to the [PyTorch documentation]
+            (https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for more
+            information.
         use_mps_device (`bool`, *optional*, defaults to `False`):
             Whether to use Apple Silicon chip based `mps` device.
     """
@@ -971,6 +977,12 @@ class TrainingArguments:
             )
         },
     )
+    ddp_timeout: Optional[int] = field(
+        default=1800,
+        metadata={
+            "help": "Overrides the default timeout for distributed training (value should be given in seconds)."
+        },
+    )
 
     def __post_init__(self):
         # Handle --use_env option in torch.distributed.launch (local_rank not passed as an arg then).
@@ -1291,6 +1303,13 @@ def eval_batch_size(self) -> int:
         eval_batch_size = per_device_batch_size * max(1, self.n_gpu)
         return eval_batch_size
 
+    @property
+    def ddp_timeout_delta(self) -> timedelta:
+        """
+        The actual timeout for torch.distributed.init_process_group since it expects a timedelta variable.
+        """
+        return timedelta(seconds=self.ddp_timeout)
+
     @cached_property
     @torch_required
     def _setup_devices(self) -> "torch.device":
@@ -1358,7 +1377,9 @@ def _setup_devices(self) -> "torch.device":
                         f"num_cpu_threads_per_process unset, we set it at {num_cpu_threads_per_process} to improve oob"
                         " performance."
                     )
-            torch.distributed.init_process_group(backend=self.xpu_backend, rank=rank, world_size=size)
+            torch.distributed.init_process_group(
+                backend=self.xpu_backend, rank=rank, world_size=size, timeout=self.ddp_timeout_delta
+            )
         elif is_torch_tpu_available():
             device = xm.xla_device()
             self._n_gpu = 0
@@ -1369,7 +1390,7 @@ def _setup_devices(self) -> "torch.device":
         elif is_sagemaker_dp_enabled():
             import smdistributed.dataparallel.torch.torch_smddp  # noqa: F401
 
-            dist.init_process_group(backend="smddp")
+            dist.init_process_group(backend="smddp", timeout=self.ddp_timeout_delta)
             self.local_rank = int(os.getenv("SMDATAPARALLEL_LOCAL_RANK"))
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1
@@ -1431,7 +1452,7 @@ def _setup_devices(self) -> "torch.device":
             # Here, we'll use torch.distributed.
             # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
             if not torch.distributed.is_initialized():
-                torch.distributed.init_process_group(backend="nccl")
+                torch.distributed.init_process_group(backend="nccl", timeout=self.ddp_timeout_delta)
             device = torch.device("cuda", self.local_rank)
             self._n_gpu = 1

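A note on the "Fixes timeout not being JSON serializable" item in the commit message: the dataclass field stores a plain integer number of seconds and only the `ddp_timeout_delta` property produces a `timedelta`, because `TrainingArguments` exposes JSON serialization (e.g. `to_json_string`) and a raw `timedelta` would break it. A small illustration of the underlying constraint (not code from this commit):

    # Why the field stays an int: json can serialize integers,
    # but raises TypeError on a datetime.timedelta.
    import json
    from datetime import timedelta

    print(json.dumps({"ddp_timeout": 1800}))               # works
    # json.dumps({"ddp_timeout": timedelta(seconds=1800)}) # would raise TypeError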