
DDP with Hydra multirun doesn't work when dirpath in checkpoint callback is specified #11300

Open
ashleve opened this issue Jan 3, 2022 · 16 comments · Fixed by #11617
Labels
argparse (removed) - Related to argument parsing (argparse, Hydra, ...); bug - Something isn't working; priority: 1 - Medium priority task; strategy: ddp - DistributedDataParallel
Comments

@ashleve
Contributor

ashleve commented Jan 3, 2022

πŸ› Bug

Running DDP with Hydra multirun ends with a "Killed" error message when launching the second task:

Epoch 0    ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0/939 0:00:00 • -:--:-- 0.00it/s [W reducer.cpp:1158] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
[W reducer.cpp:1158] Warning: find_unused_parameters=True was specified in DDP constructor, but did not find any unused parameters in the forward pass. This flag results in an extra traversal of the autograd graph every iteration,  which can adversely affect performance. If your model indeed never has any unused parameters in the forward pass, consider turning this flag off. Note that this warning may be a false positive if your model has flow control causing later iterations to have unused parameters. (function operator())
Epoch 0    ━━━━━━━━━━━━━━━━ 939/939 0:00:13 •        70.53it/s loss: 0.142
                                    0:00:00                    v_num:           
[2022-01-03 15:21:38,513][src.train][INFO] - Starting testing!
[2022-01-03 15:21:38,514][pytorch_lightning.utilities.distributed][INFO] - Restoring states from the checkpoint path at /home/user/lightning-hydra-template/logs/multiruns/2022-01-03/15-21-17/0/checkpoints/epoch_000.ckpt
[2022-01-03 15:21:38,535][pytorch_lightning.accelerators.gpu][INFO] - LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]
[2022-01-03 15:21:41,523][HYDRA]        #1 : trainer.max_epochs=1 datamodule.batch_size=64 trainer.gpus=2 +trainer.strategy=ddp
Killed

I experience this ONLY when passing the dirpath parameter to the checkpoint callback:

ModelCheckpoint(dirpath="checkpoints/")

Tested with Lightning v1.5.7. I believe this issue wasn't present in one of the previous releases.

This probably has something to do with the way Hydra changes the working directory for each new run - the directory for storing checkpoints also gets changed. If I remember correctly, there was a workaround implemented in Lightning which made DDP possible despite that.
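
For reference, a minimal sketch of the kind of setup that triggers this (a simplification, not the template's code; the tiny in-file model and the config-less @hydra.main are just to keep it short):

# sketch of a Hydra entry point matching the setup described above
import hydra
import pytorch_lightning as pl
import torch
from omegaconf import DictConfig
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, TensorDataset


class TinyModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch[0]).sum()

    def test_step(self, batch, batch_idx):
        pass

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


@hydra.main(config_path=None)
def main(cfg: DictConfig) -> None:
    loader = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=16)
    ckpt = ModelCheckpoint(dirpath="checkpoints/")  # the relative dirpath is the trigger
    trainer = pl.Trainer(gpus=2, strategy="ddp", max_epochs=1, callbacks=[ckpt])
    model = TinyModel()
    trainer.fit(model, loader)
    trainer.test(model, dataloaders=loader, ckpt_path="best")


if __name__ == "__main__":
    main()

Launching it with something like python sketch.py -m +x=0,1 mirrors the multirun command from the repro repo linked below.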

cc @tchaton @rohitgr7 @justusschock @kaushikb11 @awaelchli @akihironitta

@ashleve ashleve added the bug Something isn't working label Jan 3, 2022
@tchaton
Contributor

tchaton commented Jan 4, 2022

Hey @ashleve,

The code in Lightning to enable DDP with Hydra is there: https://github.com/PyTorchLightning/pytorch-lightning/blob/7fa1aebcc99297e4d7eb8dcf2deb22e6da814edf/pytorch_lightning/strategies/ddp.py#L231.

Would you be interested in investigating this behavior with sweep and creating a bugfix PR if you manage to solve it?

Otherwise, would you mind providing a reproducible script with the BoringModel + Hydra?

Alternatively, did you try using Trainer(strategy="ddp_spawn")?

Best,
T.C

@tchaton tchaton added hydra strategy: ddp DistributedDataParallel priority: 0 High priority task labels Jan 4, 2022
@ashleve
Contributor Author

ashleve commented Jan 4, 2022

@tchaton Hey, here's a minimal example:
https://github.com/ashleve/lit-hydra-ddp-multirun-bug
Run multirun with python main.py -m x=0,1

I was not able to find an easy fix, but here's what I found:

  1. The process is killed only when using trainer.test(); I suspect the cause might be an incorrect ckpt path (see the sketch after this list).
  2. The hydra logging folder gets multiplied for each process in ddp:
    [screenshot: the Hydra multirun log directory containing two time-stamped folders]
    Here you can see 2 folders with names generated based on time: 16-27-40, 16-27-42. Both of those were generated by a single multirun, but there should be only one main folder with multiple subfolders named by job number: 0, 1, 2, ...
    It seems like each DDP process causes Hydra to spawn an extra multirun.
  3. Not using the dirpath in the checkpoint callback makes trainer.test() execute without issues, but the multiple folders still remain.
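
As a tiny illustration of point 1 (a sketch, not code from the repro repo): once Hydra has changed the working directory, a relative dirpath resolves under whatever per-job directory the current process happens to be in:

# sketch: print where a relative "checkpoints/" ends up once Hydra has changed the cwd
import os

import hydra
from hydra.utils import get_original_cwd
from omegaconf import DictConfig


@hydra.main(config_path=None)
def main(cfg: DictConfig) -> None:
    print("original cwd:", get_original_cwd())
    print("job cwd     :", os.getcwd())                      # e.g. .../multiruns/<date>/<time>/<job_num>
    print("dirpath     :", os.path.abspath("checkpoints/"))  # resolves under the job cwd


if __name__ == "__main__":
    main()

Since each DDP process appears to start its own Hydra run (point 2), the ranks don't even agree on that resolved path.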

@awaelchli awaelchli added this to the 1.5.x milestone Jan 5, 2022
@jgbos
Contributor

jgbos commented Jan 16, 2022

I brought this up in #2727. What has mostly been working for me is to change the lines @tchaton mentioned above to:

# context: this sits inside Lightning's ddp.py, so `command` and `local_rank` come from the surrounding code;
# it also needs `from hydra.core.hydra_config import HydraConfig` and `from hydra.utils import get_original_cwd`
if _HYDRA_AVAILABLE:
    if HydraConfig.initialized():
        cwd = get_original_cwd()
        os_cwd = f'"{os.getcwd()}"'  # quoting is needed to handle characters like `=` in the directory name
        command = command[:2]
        command += ["-cp", str(Path.cwd().relative_to(Path(cwd)) / ".hydra"), "-cn", "config.yaml"]
        command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]

@tchaton
Contributor

tchaton commented Jan 16, 2022

Dear @jgbos,

Would you be willing to make a PR to resolve this bug?

@jgbos
Contributor

jgbos commented Jan 18, 2022

@tchaton I'll think about it. I'm running into some issues with multirun that I need to debug, which brings me to the actual challenge of this PR... is there a good way to test Hydra+PL? It's still a bit hacky to assume the config.yaml file is in the .hydra directory (which is the default behavior). Maybe do a search with Path.glob for the config file, as sketched below? Again, I'll think about it.
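
For reference, the glob idea could look something like this (a rough sketch; find_saved_config is a made-up helper, not anything that exists in PL or Hydra):

import os
from pathlib import Path


def find_saved_config(search_root=None):
    """Rough sketch: locate the config.yaml that Hydra saved for the current job."""
    root = Path(search_root or os.getcwd())
    # by default Hydra writes the composed config to <job dir>/.hydra/config.yaml
    matches = sorted(root.glob("**/config.yaml"))
    if not matches:
        raise FileNotFoundError(f"no saved Hydra config found under {root}")
    return matches[0]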

@jgbos
Contributor

jgbos commented Jan 18, 2022

FYI, if someone wants to fix this before I get to it, I think this code should work (needs testing though):

            if _HYDRA_AVAILABLE:
                if HydraConfig.initialized():
                    orig_cwd = get_original_cwd()
                    cwd = os.getcwd()
                    os_cwd = f'"{cwd}"'  # this is needed to handle characters like `=` in the directory name

                    hydra_cfg = HydraConfig.get()
                    hydra_output = os.path.join(os.path.relpath(cwd, orig_cwd), hydra_cfg.output_subdir)

                    command = command_no_args  # command[:2] or command[:3] if using "python -m script.main"
                    command += ["-cp", hydra_output, "-cn", "config.yaml"]
                    command += [f"hydra.run.dir={os_cwd}", f"hydra.job.name=train_ddp_process_{local_rank}"]

@tchaton
Contributor

tchaton commented Jan 20, 2022

Hey @jgbos,

Great question. I think the simplest approach is to create a simple test with the config file not in the right place and see if you can recover from it.

Best,
T.C

@jgbos
Contributor

jgbos commented Jan 20, 2022

@tchaton True, I'm also thinking about PL+Hydra properties to test. There are going to be a few edge cases my solution doesn't support. I would rather PL support something like python -m pytorch_lightning.main pickled_data.pk for launching its subprocesses. That way the subprocess would work independently of how a user launches a job. Not sure how that would be implemented though. Tough problem to solve nicely.
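
Very roughly, the idea would be something like the sketch below (all the names here are made up; nothing like this exists in Lightning today):

# sketch of launching a DDP worker from a pickled payload instead of sys.argv
# (my_pkg.replay_main is a hypothetical entry point)
import pickle
import subprocess
import sys


def launch_worker(fn, *args, payload_path="ddp_payload.pk", **kwargs):
    """Pickle the work to do, then start a child process that replays it (fn/args must be picklable)."""
    with open(payload_path, "wb") as f:
        pickle.dump({"fn": fn, "args": args, "kwargs": kwargs}, f)
    return subprocess.Popen([sys.executable, "-m", "my_pkg.replay_main", payload_path])


# my_pkg/replay_main.py would just be:
#
#     import pickle, sys
#     with open(sys.argv[1], "rb") as f:
#         payload = pickle.load(f)
#     payload["fn"](*payload["args"], **payload["kwargs"])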

@tchaton
Contributor

tchaton commented Jan 21, 2022

Yes, this needs to be revisited :) Feel free to open a PR with your solution, so we can iterate on your findings.

@carmocca carmocca added 3rd party Related to a 3rd-party argparse (removed) Related to argument parsing (argparse, Hydra, ...) and removed hydra 3rd party Related to a 3rd-party labels Jan 29, 2022
@jieru-hu

Hi @ashleve - thanks for creating the minimal repro! That was really helpful.

Sounds like there are two issues here:

  1. Hydra changes the working dir, and as a result the checkpoint cannot be found.
  2. hydra.sweep.dir gets created twice somehow in DDP mode.

As for 1, in Hydra 1.2 (the version we are currently working on), we added an option to not change the current working dir. If you run your application with hydra.job.chdir=False, it should work. We've recently put out a dev release of Hydra 1.2. You can install it with pip install hydra-core --pre --upgrade in case you want to give that a try.
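
With the repro from the earlier comment, that would be for example:

python main.py -m x=0,1 hydra.job.chdir=False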

@rohitgr7 rohitgr7 assigned rohitgr7 and unassigned rohitgr7 Feb 28, 2022
@carmocca
Member

carmocca commented Mar 1, 2022

@jgbos Will #11617 fix this issue?

@jgbos
Contributor

jgbos commented Mar 1, 2022

@carmocca Yes, it should fix this issue.

@Borda Borda modified the milestones: 1.5.x, 1.6 Mar 21, 2022
@carmocca carmocca modified the milestones: 1.6, future Mar 22, 2022
@OZOOOOOH

Hope this issue will be fixed

@jgbos
Contributor

jgbos commented May 11, 2022

@OZOOOOOH sorry, I've been too swamped with other projects to finish the PR. You can check out a current solution we have here: https://mit-ll-responsible-ai.github.io/responsible-ai-toolbox/how_to/hydraddp.html.

@Borda Borda added priority: 1 Medium priority task and removed priority: 0 High priority task labels Aug 8, 2022
@carmocca carmocca modified the milestones: pl:future, pl:1.8 Sep 27, 2022
@awaelchli
Member

awaelchli commented Nov 21, 2022

We had to revert part of this change in #15737. I'm reopening the issue so that we can continue to look for a solution that does not involve changing the current working directory when launching the processes.

@awaelchli awaelchli reopened this Nov 21, 2022
@jgbos
Contributor

jgbos commented Nov 22, 2022

I've been working to come up with a solution for this (and more generally). A robust solution would remove Lightning's need to depend on sys.argv when launching subprocesses. I think the following, using the submitit package, is possibly a robust solution; could someone else give it a try and let me know how it goes?

from typing import Optional

import pytorch_lightning as pl  # used below but missing from the original snippet
import submitit
from submitit.core.core import Executor


class SubmititTrainer:
    def __init__(self, executor: Optional[Executor] = None, devices: int = 1, **kwargs):
        """PyTorch Lightning Trainer Wrapped by Submitit

        This class does not inherit `Trainer` because I want to support DDP in the notebook.

        Parameters
        ----------
        executor: Executor | None (default: None)
            The submitit executor, if `None` it uses a default local executor.
        devices: int = 1
            The devices for `pl.Trainer`
        **kwargs: Any
            Arguments for `pl.Trainer`
        """
        self.kwargs = kwargs
        self.devices = devices

        if executor is None:
            self.executor = submitit.AutoExecutor("submitit_log", cluster="local")
            self.executor.update_parameters(gpus_per_node=devices)
        else:
            self.executor = executor  # otherwise a user-supplied executor would never be stored

    def _init_trainer(self):
        return pl.Trainer(devices=self.devices, **self.kwargs)

    def _fit(self, *args, **kwargs):
        trainer = self._init_trainer()
        return trainer.fit(*args, **kwargs)

    def _test(self, *args, **kwargs):
        trainer = self._init_trainer()
        return trainer.test(*args, **kwargs)

    def _predict(self, *args, **kwargs):
        trainer = self._init_trainer()
        return trainer.predict(*args, **kwargs)

    def fit(self, *args, **kwargs):
        return self.executor.submit(self._fit, *args, **kwargs)

    def test(self, *args, **kwargs):
        return self.executor.submit(self._test, *args, **kwargs)

    def predict(self, *args, **kwargs):
        return self.executor.submit(self._predict, *args, **kwargs)


from rai_toolbox.mushin.testing.lightning import SimpleLightningModule, SimpleDataModule


dm = SimpleDataModule()
model = SimpleLightningModule()
trainer = SubmititTrainer(strategy="ddp", accelerator="gpu", devices=2, max_epochs=1)

job = trainer.test(model=model, datamodule=dm)
print(f"Started Local Job: {job.job_id}")
