Skip to content

Commit

Permalink
Support multi-run with hydra + DDP
Browse files Browse the repository at this point in the history
  • Loading branch information
nisheethlahoti committed Jul 28, 2023
1 parent 664aa5b commit e65d344
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
20 changes: 16 additions & 4 deletions src/lightning/fabric/strategies/launchers/subprocess_script.py
Expand Up @@ -14,6 +14,7 @@
import os
import subprocess
import sys
from pathlib import Path
from typing import Any, Callable, Optional, Sequence, Tuple

from lightning_utilities.core.imports import RequirementCache
Expand Down Expand Up @@ -143,6 +144,8 @@ def _basic_subprocess_cmd() -> Sequence[str]:

def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
import __main__ # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
from hydra.core.hydra_config import HydraConfig
from hydra.types import RunMode
from hydra.utils import get_original_cwd, to_absolute_path

# when user is using hydra find the absolute path
Expand All @@ -151,9 +154,18 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
else:
command = [sys.executable, "-m", __main__.__spec__.name]

command += sys.argv[1:]

cwd = get_original_cwd()
os_cwd = f'"{os.getcwd()}"'
command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
hydra_cfg = HydraConfig.get()
run_dir = Path(hydra_cfg.runtime.output_dir)

if hydra_cfg.output_subdir is None: # config isn't saved, so re-run original command
if hydra_cfg.mode == RunMode.MULTIRUN:
raise RuntimeError(f"DDP with multirun requires saved config file")
command += sys.argv[1:]
else:
hydra_subdir = run_dir / hydra_cfg.output_subdir
command += ["-cp", str(hydra_subdir), "-cn", "config.yaml"] # Use the saved config for the new run
command += [f"hydra.output_subdir=.pl_ddp_hydra_{local_rank}"] # Log to different subdir

command += [f"hydra.run.dir={run_dir}", f"hydra.job.name=train_ddp_process_{local_rank}"]
return command, cwd
51 changes: 49 additions & 2 deletions tests/tests_pytorch/strategies/launchers/test_subprocess_script.py
@@ -1,5 +1,6 @@
import subprocess
import sys
from pathlib import Path
from unittest.mock import Mock

import pytest
Expand All @@ -13,6 +14,7 @@

if _HYDRA_WITH_RUN_PROCESS:
from hydra.test_utils.test_utils import run_process
from omegaconf import OmegaConf


# Script to run from command line
Expand Down Expand Up @@ -48,7 +50,7 @@ def task_fn(cfg):

@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
@pytest.mark.parametrize("subdir", [None, "dksa", ".hello"])
@pytest.mark.parametrize("subdir", [None, "null", "dksa", ".hello"])
def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
monkeypatch.chdir(tmpdir)

Expand All @@ -58,11 +60,56 @@ def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):

# Run CLI
devices = 2
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"']
run_dir = Path(tmpdir) / "hydra_output"
cmd = [sys.executable, "temp.py", f"+devices={devices}", '+strategy="ddp"', f"hydra.run.dir={run_dir}"]
if subdir is not None:
cmd += [f"hydra.output_subdir={subdir}"]
run_process(cmd)

# Make sure config.yaml was created for additional processes iff subdir is present.
saved_confs = list(run_dir.glob("**/config.yaml"))
assert len(saved_confs) == (0 if subdir == "null" else devices)

if saved_confs: # Make sure the parameter was set and used
cfg = OmegaConf.load(saved_confs[0])
assert cfg.devices == devices

# Make sure PL spawned jobs that are logged by Hydra
logs = list(run_dir.glob("**/*.log"))
assert len(logs) == devices


@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
@pytest.mark.skipif(not _HYDRA_WITH_RUN_PROCESS, reason=str(_HYDRA_WITH_RUN_PROCESS))
@pytest.mark.parametrize("num_jobs", [1, 2])
def test_ddp_with_hydra_multirunjob(tmpdir, num_jobs, monkeypatch):
    """Run the script as a Hydra ``--multirun`` sweep with DDP and inspect its outputs.

    The sweep is driven by a fake ``foo`` parameter taking the values
    ``0 .. num_jobs - 1``. After the run, the sweep directory must contain
    ``devices * num_jobs`` saved ``config.yaml`` files and the same number of
    ``*.log`` files, and every saved config must carry the requested
    ``devices`` value together with the ``foo`` value matching the job
    directory it was found in.
    """
    monkeypatch.chdir(tmpdir)

    # Drop the script into the current (temporary) working directory.
    Path("temp.py").write_text(script)

    devices = 2
    sweep_dir = Path(tmpdir) / "hydra_output"
    foo_values = ",".join(map(str, range(num_jobs)))  # fake multirun params
    command = [
        sys.executable,
        "temp.py",
        f"+devices={devices}",
        '+strategy="ddp"',
        f"hydra.sweep.dir={sweep_dir}",
        "--multirun",
        f"+foo={foo_values}",
    ]
    run_process(command)

    expected_outputs = devices * num_jobs

    # A config.yaml must have been saved for every process of every job.
    config_paths = list(sweep_dir.glob("**/config.yaml"))
    assert len(config_paths) == expected_outputs

    for config_path in config_paths:
        cfg = OmegaConf.load(config_path)
        # The directory two levels above config.yaml is named after the job
        # number, which this test expects to equal the swept ``foo`` value.
        job_index = int(config_path.parent.parent.name)
        assert cfg.devices == devices
        assert cfg.foo == job_index

    # One log file is expected per process of every job as well.
    log_paths = list(sweep_dir.glob("**/*.log"))
    assert len(log_paths) == expected_outputs


def test_kill():
launcher = _SubprocessScriptLauncher(Mock(), 1, 1)
Expand Down

0 comments on commit e65d344

Please sign in to comment.