Hydra + DDP improvements
* Create different Hydra output subdirectories for processes started by DDP
* Support Hydra's experimental re-run for the spawned processes (see the sketch below)
* Raise an explicit error if multirun is used but re-run is not enabled
Reverts parts of Lightning-AI#15737
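For orientation, when the re-run callback has written a pickled job config, the spawned DDP worker is roughly re-launched as follows. This is an illustrative sketch, not code from this commit; the script name and output path are made up.

# Illustrative sketch only: a DDP worker re-launched from Hydra's pickled job
# config via --experimental-rerun (script name and path are hypothetical).
worker_cmd = [
    "python",
    "train.py",
    "--experimental-rerun",
    "multirun/2023-07-27/12-34-56/0/.hydra/config.pickle",
]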
nisheethlahoti committed Jul 27, 2023
1 parent 324d90a commit 72b1c24
Showing 2 changed files with 84 additions and 7 deletions.
src/lightning/fabric/strategies/launchers/subprocess_script.py (25 additions, 7 deletions)
@@ -141,8 +141,10 @@ def _basic_subprocess_cmd() -> Sequence[str]:
     return [sys.executable, "-m", __main__.__spec__.name] + sys.argv[1:]
 
 
-def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
+def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], Optional[str]]:
     import __main__  # local import to avoid https://github.com/Lightning-AI/lightning/issues/15218
+    from hydra.core.hydra_config import HydraConfig
+    from hydra.types import RunMode
     from hydra.utils import get_original_cwd, to_absolute_path
 
     # when user is using hydra find the absolute path
@@ -151,9 +153,25 @@ def _hydra_subprocess_cmd(local_rank: int) -> Tuple[Sequence[str], str]:
     else:
         command = [sys.executable, "-m", __main__.__spec__.name]
 
-    command += sys.argv[1:]
-
-    cwd = get_original_cwd()
-    os_cwd = f'"{os.getcwd()}"'
-    command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]
-    return command, cwd
+    # extract the hydra configuration
+    hydra_cfg = HydraConfig.get()
+
+    # the location of the hydra configuration files saved for the current job
+    hydra_output = hydra_cfg.runtime.output_dir
+    if hydra_cfg.output_subdir is not None:
+        hydra_output = os.path.join(hydra_output, hydra_cfg.output_subdir)
+
+    # check if experimental re-run capability exists
+    # otherwise use existing config.yaml which may have issues
+    pickled_config = os.path.join(hydra_output, "config.pickle")
+    if os.path.exists(pickled_config):
+        command += ["--experimental-rerun", pickled_config]
+        return command, None
+    elif hydra_cfg.mode == RunMode.RUN:
+        command += [
+            f"hydra.output_subdir=.pl_ddp_hydra_{local_rank}",
+            f"hydra.run.dir={hydra_cfg.runtime.output_dir}",
+        ]
+        return command, get_original_cwd()
+    else:
+        raise RuntimeError(f"DDP with multirun requires re-run callback, but no file found at {pickled_config}")
tests/tests_pytorch/strategies/launchers/test_subprocess_script.py (59 additions, 0 deletions)
@@ -1,5 +1,6 @@
 import subprocess
 import sys
+from pathlib import Path
 from unittest.mock import Mock
 
 import pytest
@@ -13,6 +14,7 @@
 
 if _HYDRA_WITH_RUN_PROCESS:
     from hydra.test_utils.test_utils import run_process
+    from omegaconf import OmegaConf
 
 
 # Script to run from command line
@@ -63,6 +65,63 @@ def test_ddp_with_hydra_runjob(subdir, tmpdir, monkeypatch):
         cmd += [f"hydra.output_subdir={subdir}"]
     run_process(cmd)
 
+    # Make sure config.yaml was created for additional
+    # processes.
+    logs = list(Path.cwd().glob("**/config.yaml"))
+    assert len(logs) == devices
+
+    # Make sure the parameter was set and used
+    cfg = OmegaConf.load(logs[0])
+    assert cfg.devices == devices
+
+    # Make sure PL spawned a job that is logged by Hydra
+    logs = list(Path.cwd().glob("**/*.log"))
+    assert len(logs) == 1
+
+
+yaml_file = """
+hydra:
+  callbacks:
+    save_job_info:
+      _target_: hydra.experimental.callbacks.PickleJobInfoCallback
+"""
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, standalone=True)
+@pytest.mark.skipif(not _HYDRA_WITH_RERUN, reason=str(_HYDRA_WITH_RERUN))
+@pytest.mark.parametrize("num_jobs", [1, 2])
+def test_ddp_with_hydra_multirunjob_rerun(tmpdir, num_jobs, monkeypatch):
+    monkeypatch.chdir(tmpdir)
+
+    # Save script locally
+    with open("temp.py", "w") as fn:
+        fn.write(script)
+
+    with open("config.yaml", "w") as fn:
+        fn.write(yaml_file)
+
+    # create fake multirun params based on `num_jobs`
+    fake_param = "+foo=" + ",".join(str(i) for i in range(num_jobs))
+
+    # Run CLI
+    run_process(
+        [
+            sys.executable,
+            "temp.py",
+            "-cp",
+            ".",
+            "-cn",
+            "config.yaml",
+            "+devices=2",
+            '+strategy="ddp"',
+            fake_param,
+            "--multirun",
+        ]
+    )
+
+    pickles = sorted(Path.cwd().glob("**/.hydra/config.pickle"))
+    assert len(pickles) == num_jobs
+
 
 def test_kill():
     launcher = _SubprocessScriptLauncher(Mock(), 1, 1)
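For reference, a minimal Hydra-driven training script of the kind these tests launch might look like the sketch below. It is illustrative only, not the collapsed script variable used above, and the names are assumptions; the adjacent config.yaml must contain the PickleJobInfoCallback entry for the multirun case.

# Illustrative sketch of a user script driven by overrides such as
# +devices=2 +strategy=ddp --multirun; not the tests' actual script.
import hydra
from omegaconf import DictConfig

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


@hydra.main(config_path=".", config_name="config", version_base="1.2")
def run(cfg: DictConfig) -> None:
    model = BoringModel()
    trainer = Trainer(devices=cfg.devices, strategy=cfg.strategy, max_epochs=1, limit_train_batches=2)
    trainer.fit(model)


if __name__ == "__main__":
    run()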
