Skip to content

Commit

Permalink
Fix ImportErrors on Multinode if package not present (#15963)
Browse files Browse the repository at this point in the history
  • Loading branch information
justusschock committed Dec 8, 2022
1 parent 23b12ee commit cbd4dd6
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 10 deletions.
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -69,6 +69,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))

- Fixed `ImportError` on Multinode if package not present ([#15963](https://github.com/Lightning-AI/lightning/pull/15963))

- Fixed multiprocessing breakpoint ([#15950](https://github.com/Lightning-AI/lightning/pull/15950))

- Fixed detection of a Lightning App running in debug mode ([#15951](https://github.com/Lightning-AI/lightning/pull/15951))
Expand Down
13 changes: 8 additions & 5 deletions src/lightning_app/components/multi_node/lite.py
Expand Up @@ -37,11 +37,14 @@ def run(
mps_accelerators = []

for pkg_name in ("lightning.lite", "lightning_" + "lite"):
pkg = importlib.import_module(pkg_name)
lites.append(pkg.LightningLite)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
try:
pkg = importlib.import_module(pkg_name)
lites.append(pkg.LightningLite)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
except (ImportError, ModuleNotFoundError):
continue

# Used to configure PyTorch progress group
os.environ["MASTER_ADDR"] = main_address
Expand Down
13 changes: 8 additions & 5 deletions src/lightning_app/components/multi_node/trainer.py
Expand Up @@ -37,11 +37,14 @@ def run(
mps_accelerators = []

for pkg_name in ("lightning.pytorch", "pytorch_" + "lightning"):
pkg = importlib.import_module(pkg_name)
trainers.append(pkg.Trainer)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
try:
pkg = importlib.import_module(pkg_name)
trainers.append(pkg.Trainer)
strategies.append(pkg.strategies.DDPSpawnShardedStrategy)
strategies.append(pkg.strategies.DDPSpawnStrategy)
mps_accelerators.append(pkg.accelerators.MPSAccelerator)
except (ImportError, ModuleNotFoundError):
continue

# Used to configure PyTorch progress group
os.environ["MASTER_ADDR"] = main_address
Expand Down

0 comments on commit cbd4dd6

Please sign in to comment.