Commit

squash all
awaelchli committed Oct 4, 2022
1 parent c059db4 commit ec164b9
Showing 7 changed files with 145 additions and 3 deletions.
15 changes: 15 additions & 0 deletions examples/01_lite_launch/launcher_cli.py
@@ -0,0 +1,15 @@
# Run this with
#
# python -m lightning_lite.cli examples/01_lite_launch/launcher_cli.py --devices 2 --precision bf16

import torch.distributed

from lightning_lite import LightningLite


if __name__ == "__main__":
    lite = LightningLite()
    print("launched", lite.global_rank)
    assert torch.distributed.is_initialized()
    lite.barrier()
    print("end")
16 changes: 16 additions & 0 deletions examples/01_lite_launch/launcher_old.py
@@ -0,0 +1,16 @@
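# Existing launch pattern: subclass LightningLite and put the distributed code in run();
# the wrapped run() call goes through the strategy's launcher (see _run_impl in lite.py below).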
import torch.distributed

from lightning_lite import LightningLite


class Lite(LightningLite):
    def run(self):
        print("launched", self.global_rank)
        assert torch.distributed.is_initialized()
        self.barrier()


if __name__ == "__main__":
    lite = Lite(accelerator="cpu", devices=2, strategy="ddp")
    lite.run()
    print("after run", lite.global_rank)
12 changes: 12 additions & 0 deletions examples/01_lite_launch/launcher_script_v1.py
@@ -0,0 +1,12 @@
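# Launch without a run() override: create LightningLite, call launch() with no function,
# and continue below; each process resumes here with the distributed environment set up.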
import torch.distributed

from lightning_lite import LightningLite


if __name__ == "__main__":
    lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp")
    lite.launch()
    print("launched", lite.global_rank)
    assert torch.distributed.is_initialized()
    lite.barrier()
    print("end")
15 changes: 15 additions & 0 deletions examples/01_lite_launch/launcher_script_v2.py
@@ -0,0 +1,15 @@
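# Launch with a function: lite.launch(run) sets up the strategy environment and then
# executes run(lite) on every process.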
import torch.distributed

from lightning_lite import LightningLite


def run(lite):
print("launched", lite.global_rank)
assert torch.distributed.is_initialized()
lite.barrier()
print("end")


if __name__ == "__main__":
    lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp")
    lite.launch(run)
14 changes: 14 additions & 0 deletions examples/01_lite_launch/launcher_spawn.py
@@ -0,0 +1,14 @@
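# Spawn-style launch: with the "ddp_notebook" strategy, launch(run) starts the worker
# processes and runs run(lite) in them; the main process continues once they have joined.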
import torch.distributed

from lightning_lite import LightningLite


def run(lite):
print("launched", lite.global_rank)
assert torch.distributed.is_initialized()


if __name__ == "__main__":
    lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp_notebook")
    lite.launch(run)
    print("main process joins", lite.global_rank)
44 changes: 44 additions & 0 deletions src/lightning_lite/cli.py
@@ -0,0 +1,44 @@
import os
from argparse import ArgumentParser
import torch.distributed.run as torchrun


def main():
    parser = ArgumentParser()
    parser.add_argument("script", type=str)
    parser.add_argument("--accelerator", type=str, default="cpu", choices=("cpu", "cuda", "mps", "tpu", "auto"))
    # TODO: note for some accelerators/strategies, torchrun won't make sense (e.g. dp)
    # TODO: should we include spawn?
    parser.add_argument("--strategy", type=str, default=None, choices=(None, "ddp", "dp", "deepspeed"))
    parser.add_argument("--devices", type=str, default="1")
    parser.add_argument("--num-nodes", type=int, default=1)
    parser.add_argument("--node-rank", type=int, default=0)
    parser.add_argument("--main-address", type=str, default="127.0.0.1")
    parser.add_argument("--main-port", type=int, default=29400)
    parser.add_argument("--precision", type=str, default="32", choices=("32", "16", "bf16"))
    args = parser.parse_args()

    # Hand the selection over to LightningLite via environment variables
    # (picked up in LightningLite.__init__, see lite.py below).
    os.environ["LT_ACCELERATOR"] = str(args.accelerator)
    if args.strategy:
        os.environ["LT_STRATEGY"] = str(args.strategy)
    os.environ["LT_DEVICES"] = str(args.devices)
    os.environ["LT_NUM_NODES"] = str(args.num_nodes)
    os.environ["LT_PRECISION"] = str(args.precision)

    num_devices = int(args.devices)  # TODO: count them

    # Build the torchrun command line that launches the user script.
    torchrun_args = []
    torchrun_args.extend(["--nproc_per_node", str(num_devices)])
    torchrun_args.extend(["--nnodes", str(args.num_nodes)])
    torchrun_args.extend(["--node_rank", str(args.node_rank)])
    torchrun_args.extend(["--master_addr", args.main_address])
    torchrun_args.extend(["--master_port", str(args.main_port)])
    torchrun_args.append(args.script)

    os.environ.setdefault("OMP_NUM_THREADS", str(max(1, os.cpu_count() // num_devices)))

    torchrun.main(torchrun_args)


if __name__ == "__main__":
    main()
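For orientation: with the parser defaults above, the invocation documented at the top of launcher_cli.py (--devices 2 --precision bf16) boils down to roughly the following sketch, derived from main() above.

# Rough equivalent of: python -m lightning_lite.cli examples/01_lite_launch/launcher_cli.py --devices 2 --precision bf16
import os

import torch.distributed.run as torchrun

os.environ["LT_ACCELERATOR"] = "cpu"  # parser default
os.environ["LT_DEVICES"] = "2"
os.environ["LT_NUM_NODES"] = "1"
os.environ["LT_PRECISION"] = "bf16"
os.environ.setdefault("OMP_NUM_THREADS", str(max(1, os.cpu_count() // 2)))

torchrun.main([
    "--nproc_per_node", "2",
    "--nnodes", "1",
    "--node_rank", "0",
    "--master_addr", "127.0.0.1",
    "--master_port", "29400",
    "examples/01_lite_launch/launcher_cli.py",
])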
32 changes: 29 additions & 3 deletions src/lightning_lite/lite.py
@@ -77,6 +77,13 @@ def __init__(
        precision: _PRECISION_INPUT = 32,
        plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
    ) -> None:
        # Settings exported by `lightning_lite.cli` take precedence over the constructor arguments.
        accelerator = os.getenv("LT_ACCELERATOR", accelerator)
        strategy = os.getenv("LT_STRATEGY", strategy)
        devices = os.getenv("LT_DEVICES", devices)
        num_nodes = os.getenv("LT_NUM_NODES", num_nodes)
        precision = os.getenv("LT_PRECISION", precision)
        precision = int(precision) if precision in ("16", "32") else precision

        self._connector = _Connector(
            accelerator=accelerator,
            strategy=strategy,
@@ -93,6 +100,9 @@ def __init__(
        # wrap the run method so we can inject setup logic or spawn processes for the user
        setattr(self, "run", partial(self._run_impl, self.run))

        if "LT_ACCELERATOR" in os.environ:
            self._strategy.setup_environment()

    @property
    def device(self) -> torch.device:
        """The current device this process runs on.
@@ -126,7 +136,7 @@ def is_global_zero(self) -> bool:
"""Wether this rank is rank zero."""
return self._strategy.is_global_zero

    @abstractmethod
    # TODO(lite): Error/warn when run overridden but launcher is used
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """All the code inside this run method gets accelerated by Lite.
@@ -367,6 +377,15 @@ def load(self, filepath: Union[str, Path]) -> Any:
"""
return self._strategy.load_checkpoint(filepath)

    def launch(self, function: Optional[Callable] = None, *args: Any, **kwargs: Any) -> Any:
        function = _do_nothing if function is None else function
        function = partial(self._function_with_strategy_setup, function)
        args = [self, *args]
        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(function, *args, **kwargs)
        else:
            return function(*args, **kwargs)

    @staticmethod
    def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
        """Helper function to seed everything without explicitly importing Lightning.
@@ -380,9 +399,8 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
        return seed_everything(seed=seed, workers=workers)

    def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        # wrap the real run method with setup logic
        # TODO: skip launcher if already launched externally!
        run_method = partial(self._run_with_setup, run_method)

        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(run_method, *args, **kwargs)
        else:
@@ -396,6 +414,11 @@ def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        ), _replace_dunder_methods(BatchSampler):
            return run_method(*args, **kwargs)

    def _function_with_strategy_setup(self, function: Callable, *args: Any, **kwargs: Any) -> Any:
        self._strategy.setup_environment()
        with _replace_dunder_methods(DataLoader, "dataset"), _replace_dunder_methods(BatchSampler):
            return function(*args, **kwargs)

    def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
        initial_device = next(model.parameters()).device
        if any(param.device != initial_device for param in model.parameters()):
@@ -450,3 +473,6 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:

        if any(not isinstance(dl, DataLoader) for dl in dataloaders):
            raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")


def _do_nothing(*_): pass
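A minimal sketch (not part of the commit) of what the LT_* precedence added to __init__ above means in practice, assuming a plain single-process CPU run:

import os

from lightning_lite import LightningLite

# Pretend the CLI exported these before the script started.
os.environ["LT_ACCELERATOR"] = "cpu"
os.environ["LT_DEVICES"] = "1"

# The environment variables win over the constructor arguments.
lite = LightningLite(accelerator="cuda", devices=2)
print(lite.device)  # cpu, a single device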
