Launch options for Lightning Lite (#14992)
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Carlos Mocholí <carlossmocholi@gmail.com>
3 people committed Nov 2, 2022
1 parent 5d638cf commit 6aa6423
Showing 8 changed files with 546 additions and 14 deletions.
6 changes: 4 additions & 2 deletions src/lightning_lite/CHANGELOG.md
@@ -10,7 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-

- Added `LightningLite.launch()` to programmatically launch processes (e.g. in Jupyter notebook) ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
- Added the option to launch Lightning Lite scripts from the CLI, without the need to wrap the code into the `run` method ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))

-

@@ -19,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

-
- The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))

-

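Taken together, these changelog entries describe two entry points into the same multiprocess setup. Below is a minimal sketch of the CLI path, assuming the entry point added in this commit is reachable as `python -m lightning_lite.cli` (the exact console-script name is not shown on this page):

# train.py, started externally with e.g.:
#   python -m lightning_lite.cli train.py --accelerator=cpu --devices=2 --strategy=ddp
from lightning_lite import LightningLite  # import path assumed

lite = LightningLite()  # picks up accelerator/strategy/devices from the LT_* env vars
print(f"rank {lite.global_rank} of {lite.world_size}")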
171 changes: 171 additions & 0 deletions src/lightning_lite/cli.py
@@ -0,0 +1,171 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import ArgumentParser, Namespace
from typing import List, Tuple

from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning_lite.utilities.device_parser import _parse_gpu_ids
from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13

_log = logging.getLogger(__name__)

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_STRATEGIES = (None, "ddp", "dp", "deepspeed")
_SUPPORTED_PRECISION = ("64", "32", "16", "bf16")


def _parse_args() -> Tuple[Namespace, List[str]]:
    parser = ArgumentParser(description="Launch your script with the Lightning Lite CLI.")
    parser.add_argument("script", type=str, help="Path to the Python script with Lightning Lite inside.")
    parser.add_argument(
        "--accelerator",
        type=str,
        default="cpu",
        choices=_SUPPORTED_ACCELERATORS,
        help="The hardware accelerator to run on.",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default=None,
        choices=_SUPPORTED_STRATEGIES,
        help="Strategy for how to run across multiple devices.",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="1",
        help=(
            "Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
            " The value applies per node."
        ),
    )
    parser.add_argument(
        "--num-nodes",
        "--num_nodes",
        type=int,
        default=1,
        help="Number of machines (nodes) for distributed execution.",
    )
    parser.add_argument(
        "--node-rank",
        "--node_rank",
        type=int,
        default=0,
        help=(
            "The index of the machine (node) this command gets started on. Must be a number in the range"
            " 0, ..., num_nodes - 1."
        ),
    )
    parser.add_argument(
        "--main-address",
        "--main_address",
        type=str,
        default="127.0.0.1",
        help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
    )
    parser.add_argument(
        "--main-port",
        "--main_port",
        type=int,
        default=29400,
        help="The main port to connect to the main machine.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="32",
        choices=_SUPPORTED_PRECISION,
        help=(
            "Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
            " (``'bf16'``)"
        ),
    )

    args, script_args = parser.parse_known_args()
    return args, script_args
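
# Illustrative only (not part of the committed file): parsing
#   train.py --accelerator=cuda --devices=2 --lr=0.01
# yields args.accelerator == "cuda", args.devices == "2", and
# script_args == ["--lr=0.01"], because parse_known_args() passes unknown
# flags through to the user's script.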


def _set_env_variables(args: Namespace) -> None:
    """Set the environment variables for the new processes.

    The Lite connector will parse the arguments set here.
    """
    os.environ["LT_CLI_USED"] = "1"
    os.environ["LT_ACCELERATOR"] = str(args.accelerator)
    if args.strategy is not None:
        os.environ["LT_STRATEGY"] = str(args.strategy)
    os.environ["LT_DEVICES"] = str(args.devices)
    os.environ["LT_NUM_NODES"] = str(args.num_nodes)
    os.environ["LT_PRECISION"] = str(args.precision)


def _get_num_processes(accelerator: str, devices: str) -> int:
    """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
    if accelerator == "gpu":
        parsed_devices = _parse_gpu_ids(devices, include_cuda=True, include_mps=True)
    elif accelerator == "cuda":
        parsed_devices = CUDAAccelerator.parse_devices(devices)
    elif accelerator == "mps":
        parsed_devices = MPSAccelerator.parse_devices(devices)
    elif accelerator == "tpu":
        raise ValueError("Launching processes for TPU through the CLI is not supported.")
    else:
        return CPUAccelerator.parse_devices(devices)
    return len(parsed_devices) if parsed_devices is not None else 0
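
# Illustrative semantics, assumed from the accelerator APIs used above:
#   accelerator="cuda", devices="2"   -> 2 processes (parsed device ids [0, 1])
#   accelerator="cuda", devices="0,1" -> 2 processes (the parsed id list has length 2)
#   accelerator="cpu",  devices="2"   -> CPUAccelerator.parse_devices() returns the count directly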


def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
    """This will invoke `torchrun` programmatically to launch the given script in new processes."""

    if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13:
        # TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427
        _log.error(
            "On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug"
            " upstream: https://github.com/pytorch/pytorch/issues/85427"
        )
        exit(1)

    import torch.distributed.run as torchrun

    if args.strategy == "dp":
        num_processes = 1
    else:
        num_processes = _get_num_processes(args.accelerator, args.devices)

    torchrun_args = [
        f"--nproc_per_node={num_processes}",
        f"--nnodes={args.num_nodes}",
        f"--node_rank={args.node_rank}",
        f"--master_addr={args.main_address}",
        f"--master_port={args.main_port}",
        args.script,
    ]
    torchrun_args.extend(script_args)

    # set a good default number of threads for OMP to avoid warnings being emitted to the user
    os.environ.setdefault("OMP_NUM_THREADS", str(max(1, (os.cpu_count() or 1) // num_processes)))
    torchrun.main(torchrun_args)


def main() -> None:
    args, script_args = _parse_args()
    _set_env_variables(args)
    _torchrun_launch(args, script_args)


if __name__ == "__main__":
    main()
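For orientation, a sketch of the handshake this file sets up; the invocation form is an assumption, since the diff does not show how the CLI is registered as a console script:

# Running, e.g.:
#   python -m lightning_lite.cli train.py --accelerator=cuda --devices=2 --precision=16
# roughly amounts to:
import os

# 1. _set_env_variables() records the flags for the child processes:
os.environ["LT_CLI_USED"] = "1"
os.environ["LT_ACCELERATOR"] = "cuda"
os.environ["LT_DEVICES"] = "2"
os.environ["LT_NUM_NODES"] = "1"
os.environ["LT_PRECISION"] = "16"

# 2. _torchrun_launch() derives the process count from the device flags and
#    hands off to torch.distributed.run:
#      torchrun --nproc_per_node=2 --nnodes=1 --node_rank=0 \
#               --master_addr=127.0.0.1 --master_port=29400 train.py
# 3. Each spawned process then constructs LightningLite(), whose connector
#    reads the LT_* variables back (see connector.py below).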
33 changes: 31 additions & 2 deletions src/lightning_lite/connector.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import Counter
-from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from typing_extensions import Literal
@@ -101,6 +101,14 @@ def __init__(
        precision: _PRECISION_INPUT = 32,
        plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
    ) -> None:

        # These arguments can be set through environment variables set by the CLI
        accelerator = self._argument_from_env("accelerator", accelerator, default=None)
        strategy = self._argument_from_env("strategy", strategy, default=None)
        devices = self._argument_from_env("devices", devices, default=None)
        num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
        precision = self._argument_from_env("precision", precision, default=32)

        # 1. Parsing flags
        # Get registered strategies, built-in accelerators and precision plugins
        self._registered_strategies = STRATEGY_REGISTRY.available_strategies()
@@ -514,6 +522,27 @@ def _lazy_init_strategy(self) -> None:
f" found {self.strategy.__class__.__name__}."
)

    @staticmethod
    def _argument_from_env(name: str, current: Any, default: Any) -> Any:
        env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper())

        if env_value is None:
            return current

        if name == "precision":
            # TODO: support precision input as string, then this special handling is not needed
            env_value = int(env_value) if env_value in ("16", "32", "64") else env_value

        if env_value != current and current != default:
            raise ValueError(
                f"Your code has `LightningLite({name}={current!r}, ...)` but it conflicts with the value"
                f" `--{name}={env_value}` set through the CLI. Remove it either from the CLI or from the"
                " Lightning Lite object."
            )
        return env_value

    @property
    def is_distributed(self) -> bool:
        # TODO: deprecate this property
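The precedence rule in `_argument_from_env` can be exercised directly. A sketch, assuming the class in this file is importable as `_Connector` (the class name is not visible in this excerpt):

import os

from lightning_lite.connector import _Connector  # class name assumed

os.environ["LT_PRECISION"] = "16"

# The constructor argument was left at its default, so the CLI value wins:
assert _Connector._argument_from_env("precision", current=32, default=32) == 16

# An explicit, conflicting constructor argument raises instead:
try:
    _Connector._argument_from_env("precision", current=64, default=32)
except ValueError as err:
    print(err)  # suggests removing the flag from either the CLI or the code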
52 changes: 44 additions & 8 deletions src/lightning_lite/lite.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
-from abc import ABC, abstractmethod
from abc import ABC
from contextlib import contextmanager, nullcontext
from functools import partial
from pathlib import Path
@@ -21,6 +22,7 @@
import torch
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer
@@ -90,8 +92,11 @@ def __init__(
        self._precision: Precision = self._strategy.precision
        self._models_setup: int = 0

-        # wrap the run method so we can inject setup logic or spawn processes for the user
-        setattr(self, "run", partial(self._run_impl, self.run))
        self._prepare_run_method()
        if _is_using_cli():
            # when the CLI is used to launch the script, we need to set up the environment (init processes) here so
            # that the user can immediately use all functionality in strategies
            self._strategy.setup_environment()

    @property
    def device(self) -> torch.device:
@@ -126,7 +131,6 @@ def is_global_zero(self) -> bool:
"""Wether this rank is rank zero."""
return self._strategy.is_global_zero

@abstractmethod
def run(self, *args: Any, **kwargs: Any) -> Any:
"""All the code inside this run method gets accelerated by Lite.
Expand Down Expand Up @@ -413,6 +417,23 @@ def load(self, filepath: Union[str, Path]) -> Any:
"""
return self._strategy.load_checkpoint(filepath)

    def launch(self, function: Optional[Callable[["LightningLite"], Any]] = None, *args: Any, **kwargs: Any) -> Any:
        if _is_using_cli():
            raise RuntimeError(
                "This script was launched through the CLI, and processes have already been created. Calling"
                " `.launch()` again is not allowed."
            )
        if function is not None and not inspect.signature(function).parameters:
            raise TypeError(
                "The function passed to `Lite.launch()` needs to take at least one argument. The launcher will pass"
                " in the `LightningLite` object so you can use it inside the function."
            )
        function = partial(self._run_with_setup, function or _do_nothing)
        args = [self, *args]
        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(function, *args, **kwargs)
        return function(*args, **kwargs)

    @staticmethod
    def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
        """Helper function to seed everything without explicitly importing Lightning.

@@ -426,21 +447,19 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
        return seed_everything(seed=seed, workers=workers)

    def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
-        # wrap the real run method with setup logic
        run_method = partial(self._run_with_setup, run_method)

        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(run_method, *args, **kwargs)
        else:
            return run_method(*args, **kwargs)

-    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
    def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) -> Any:
        self._strategy.setup_environment()
        # apply sharded context to prevent OOM
        with self._strategy.module_sharded_context(), _replace_dunder_methods(
            DataLoader, "dataset"
        ), _replace_dunder_methods(BatchSampler):
-            return run_method(*args, **kwargs)
            return run_function(*args, **kwargs)

    def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
        initial_device = next(model.parameters()).device

@@ -481,6 +500,15 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> Distribut
        kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
        return DistributedSamplerWrapper(dataloader.sampler, **kwargs)

    def _prepare_run_method(self) -> None:
        if is_overridden("run", self, LightningLite) and _is_using_cli():
            raise TypeError(
                "Overriding `LightningLite.run()` and launching from the CLI is not allowed. Run the script normally,"
                " or change your code to directly call `lite = LightningLite(...); lite.setup(...)` etc."
            )
        # wrap the run method, so we can inject setup logic or spawn processes for the user
        setattr(self, "run", partial(self._run_impl, self.run))

    @staticmethod
    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
        if isinstance(model, _LiteModule):

@@ -496,3 +524,11 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:

        if any(not isinstance(dl, DataLoader) for dl in dataloaders):
            raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")


def _is_using_cli() -> bool:
    return bool(int(os.environ.get("LT_CLI_USED", "0")))


def _do_nothing(*_: Any) -> None:
    pass
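With `run()` no longer abstract, the new `launch()` API can be used without subclassing. A minimal usage sketch under that assumption (import path assumed from this repo layout):

from lightning_lite import LightningLite

def train(lite: LightningLite) -> None:
    # runs once per process; the launcher passes the LightningLite object in,
    # which is why launch() requires the function to accept an argument
    print(f"rank {lite.global_rank} of {lite.world_size}")

if __name__ == "__main__":
    lite = LightningLite(accelerator="cpu", devices=2, strategy="ddp")
    lite.launch(train)  # spawns the processes and runs train() in each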
5 changes: 3 additions & 2 deletions src/lightning_lite/utilities/data.py
@@ -327,8 +327,9 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:


def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
-    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
-    :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently
    :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re-
    instantiation of custom subclasses."""

    @functools.wraps(method)
    def wrapper(obj: Any, *args: Any) -> None:
