Isolate distrib_run #828

Merged · 1 commit · Nov 8, 2022
7 changes: 4 additions & 3 deletions src/accelerate/commands/launch.py
@@ -56,9 +56,6 @@
 logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()])
 
 
-if is_torch_version(">=", "1.9.0"):
-    import torch.distributed.run as distrib_run
-
 logger = logging.getLogger(__name__)
 
 options_to_group = {
@@ -555,6 +552,8 @@ def simple_launcher(args):
 
 
 def multi_gpu_launcher(args):
+    if is_torch_version(">=", "1.9.0"):
+        import torch.distributed.run as distrib_run
     num_processes = getattr(args, "num_processes")
     num_machines = getattr(args, "num_machines")
     main_process_ip = getattr(args, "main_process_ip")
@@ -644,6 +643,8 @@ def multi_gpu_launcher(args):
 
 
 def deepspeed_launcher(args):
+    if is_torch_version(">=", "1.9.0"):
+        import torch.distributed.run as distrib_run
     if not is_deepspeed_available():
         raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
     num_processes = getattr(args, "num_processes")
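
Both launchers now repeat the same deferred-import guard. A minimal sketch of the pattern in isolation, assuming `accelerate.utils` re-exports `is_torch_version` (the sketch body is illustrative, not the full launcher from this PR):

# Sketch of the deferred-import pattern this PR applies (illustrative,
# not the accelerate source). Importing this module no longer touches
# torch.distributed.run; the guard runs only when a launch is requested.
from accelerate.utils import is_torch_version


def multi_gpu_launcher(args):
    if is_torch_version(">=", "1.9.0"):
        import torch.distributed.run as distrib_run  # deferred, function-local
    else:
        raise NotImplementedError("this launch path assumes torch >= 1.9.0")
    # In accelerate, args are first filtered down to torchrun-compatible
    # flags; run() is torchrun's programmatic entry point (torch >= 1.9).
    distrib_run.run(args)

The cost is a guard repeated in each launcher; the benefit is that merely importing the module no longer pulls in `torch.distributed.run` and everything it imports.
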
6 changes: 2 additions & 4 deletions src/accelerate/utils/launch.py
@@ -21,10 +21,6 @@
 from .dataclasses import DistributedType
 
 
-if is_torch_version(">=", "1.9.0"):
-    import torch.distributed.run as distrib_run
-
-
 def get_launch_prefix():
     """
     Grabs the correct launcher for starting a distributed command, such as either `torchrun`, `python -m
@@ -43,6 +39,8 @@ def _filter_args(args):
     """
     Filters out all `accelerate` specific args
     """
+    if is_torch_version(">=", "1.9.0"):
+        import torch.distributed.run as distrib_run
     distrib_args = distrib_run.get_args_parser()
     new_args, _ = distrib_args.parse_known_args()
 
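
For context on why `_filter_args` needs `distrib_run` at all: `torch.distributed.run` exposes its `argparse` parser via `get_args_parser()`, and `parse_known_args` returns the flags that parser recognizes while leaving everything else, including `accelerate`-specific flags, in the leftovers. A standalone illustration with an invented argv (the real function calls `parse_known_args()` with no arguments, i.e. against `sys.argv`):

# Standalone illustration of the filtering that _filter_args relies on;
# the flag values here are invented for the example.
import torch.distributed.run as distrib_run

parser = distrib_run.get_args_parser()
known, rest = parser.parse_known_args(
    ["--nproc_per_node", "2", "--mixed_precision=fp16", "train.py"]
)
print(known.nproc_per_node)   # '2' -- recognized torchrun flag (parsed as a string)
print(known.training_script)  # 'train.py' -- torchrun's required positional
print(rest)                   # ['--mixed_precision=fp16'] -- accelerate-only, filtered out

Note the `--flag=value` form for the unrecognized flag: with the space-separated form, `argparse` would treat the dangling value as a positional and mis-assign it, which is one reason filtering against torchrun's own parser is more robust than hand-maintaining a flag list.
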