Launch options for Lightning Lite #14992

Merged 55 commits on Nov 2, 2022 (changes shown are from 40 of the 55 commits).
ec164b9
squash all
awaelchli Oct 4, 2022
43b9768
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 4, 2022
bd2cde2
support script args
awaelchli Oct 4, 2022
bae0c67
wip
awaelchli Oct 4, 2022
71c7b60
refactor
awaelchli Oct 4, 2022
39ff946
reset
awaelchli Oct 4, 2022
98cb643
cli stuff
awaelchli Oct 4, 2022
2450512
cli tests
awaelchli Oct 5, 2022
e541119
test connector
awaelchli Oct 5, 2022
198c59c
function inspection
awaelchli Oct 5, 2022
7f71abe
types
awaelchli Oct 5, 2022
a956e55
tests for collision
awaelchli Oct 5, 2022
52894b7
add notice
awaelchli Oct 5, 2022
71cafff
docs
awaelchli Oct 5, 2022
002e5f7
mypy stuff
awaelchli Oct 5, 2022
145d88f
changelog
awaelchli Oct 5, 2022
a8cf41f
remove demo examples
awaelchli Oct 5, 2022
70d8b78
error handling for run and cli
awaelchli Oct 5, 2022
36f177d
Merge branch 'master' into lite/launcher-poc
awaelchli Oct 5, 2022
1592d7d
Merge branch 'master' into lite/launcher-poc
awaelchli Oct 5, 2022
d08791d
remove handled todo
awaelchli Oct 5, 2022
71421dc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 5, 2022
784d59b
Merge branch 'master' into lite/launcher-poc
awaelchli Oct 6, 2022
3a5964a
fix test
awaelchli Oct 6, 2022
5d9dd67
update cli detection
awaelchli Oct 8, 2022
7762963
mypy
awaelchli Oct 8, 2022
ffa18e6
Merge branch 'master' into lite/launcher-poc
awaelchli Oct 8, 2022
5d6d983
add description
awaelchli Oct 8, 2022
9b200a9
address review comments
awaelchli Oct 8, 2022
5a0e6fc
fix env variable selection
awaelchli Oct 8, 2022
85c9463
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 8, 2022
1a9c982
Merge branch 'master' into lite/launcher-poc
awaelchli Oct 19, 2022
43742d3
Merge branch 'master' into lite/launcher-poc
awaelchli Nov 1, 2022
7332980
fix test
awaelchli Nov 1, 2022
9b8ec68
unused import
awaelchli Nov 1, 2022
bda824a
Merge branch 'master' into lite/launcher-poc
awaelchli Nov 1, 2022
d92b37f
notebook
awaelchli Nov 1, 2022
c08b3f7
notebook
awaelchli Nov 1, 2022
9698440
raise error on win + 1.13
awaelchli Nov 1, 2022
c47a0af
fix
awaelchli Nov 1, 2022
ce470c8
fix
awaelchli Nov 1, 2022
0a1c0f5
Update src/lightning_lite/CHANGELOG.md
awaelchli Nov 1, 2022
84184ab
fix type
awaelchli Nov 1, 2022
0c937a2
nit
awaelchli Nov 1, 2022
6a63051
update gpu parsing
awaelchli Nov 1, 2022
c37a4fe
chlog
awaelchli Nov 2, 2022
add70ef
skip
awaelchli Nov 2, 2022
5f829fa
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2022
0544a7b
Update src/lightning_lite/cli.py
awaelchli Nov 2, 2022
2d79cb3
update test from code review
awaelchli Nov 2, 2022
d69523a
local import
awaelchli Nov 2, 2022
63f5931
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2022
c00ff74
address review
awaelchli Nov 2, 2022
61d091e
Merge branch 'lite/launcher-poc' of github.com:Lightning-AI/lightning…
awaelchli Nov 2, 2022
20ee566
fix windows test import
awaelchli Nov 2, 2022
5 changes: 4 additions & 1 deletion src/lightning_lite/CHANGELOG.md
@@ -10,7 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

-
### Added

- Added `LightningLite.launch()` to programmatically launch processes (e.g. in Jupyter notebook) ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))
- Added the option to launch Lightning Lite scripts from the CLI, without the need to wrap the code into the `run` method ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))

-
168 changes: 168 additions & 0 deletions src/lightning_lite/cli.py
@@ -0,0 +1,168 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
from argparse import ArgumentParser, Namespace
from typing import List, Tuple

from lightning_lite.accelerators import CPUAccelerator, CUDAAccelerator, MPSAccelerator
from lightning_lite.utilities.imports import _IS_WINDOWS, _TORCH_GREATER_EQUAL_1_13

# torchrun in PyTorch 1.13.0 has a bug on the Windows platform and is thus not importable:
# https://github.com/pytorch/pytorch/issues/85427
if not (_IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13):
    import torch.distributed.run as torchrun
else:
    torchrun = None

_log = logging.getLogger(__name__)

_SUPPORTED_ACCELERATORS = ("cpu", "gpu", "cuda", "mps", "tpu")
_SUPPORTED_STRATEGIES = (None, "ddp", "dp", "deepspeed")
_SUPPORTED_PRECISION = ("64", "32", "16", "bf16")


def _parse_args() -> Tuple[Namespace, List[str]]:
    parser = ArgumentParser(description="Launch your script with the Lightning Lite CLI.")
    parser.add_argument("script", type=str, help="Path to the Python script with Lightning Lite inside.")
    parser.add_argument(
        "--accelerator",
        type=str,
        default="cpu",
        choices=_SUPPORTED_ACCELERATORS,
        help="The hardware accelerator to run on.",
    )
    parser.add_argument(
        "--strategy",
        type=str,
        default=None,
        choices=_SUPPORTED_STRATEGIES,
        help="Strategy for how to run across multiple devices.",
    )
    parser.add_argument(
        "--devices",
        type=str,
        default="1",
        help=(
            "Number of devices to run on (``int``), which devices to run on (``list`` or ``str``), or ``'auto'``."
            " The value applies per node."
        ),
    )
    parser.add_argument(
        "--num-nodes",
        type=int,
        default=1,
        help="Number of machines (nodes) for distributed execution.",
    )
    parser.add_argument(
        "--node-rank",
        type=int,
        default=0,
        help=(
            "The index of the machine (node) this command gets started on. Must be a number in the range"
            " 0, ..., num_nodes - 1."
        ),
    )
    parser.add_argument(
        "--main-address",
        type=str,
        default="127.0.0.1",
        help="The hostname or IP address of the main machine (usually the one with node_rank = 0).",
    )
    parser.add_argument(
        "--main-port",
        type=int,
        default=29400,
        help="The main port to connect to the main machine.",
    )
    parser.add_argument(
        "--precision",
        type=str,
        default="32",
        choices=_SUPPORTED_PRECISION,
        help=(
            "Double precision (``64``), full precision (``32``), half precision (``16``) or bfloat16 precision"
            " (``'bf16'``)"
        ),
    )

    args, script_args = parser.parse_known_args()
    return args, script_args


def _set_env_variables(args: Namespace) -> None:
    """Set the environment variables for the new processes.

    The Lite connector will parse the arguments set here.
    """
    os.environ["LT_CLI_USED"] = "1"
    os.environ["LT_ACCELERATOR"] = str(args.accelerator)
    if args.strategy is not None:
        os.environ["LT_STRATEGY"] = str(args.strategy)
    os.environ["LT_DEVICES"] = str(args.devices)
    os.environ["LT_NUM_NODES"] = str(args.num_nodes)
    os.environ["LT_PRECISION"] = str(args.precision)


def _get_num_processes(accelerator: str, devices: str) -> int:
    """Parse the `devices` argument to determine how many processes need to be launched on the current machine."""
    if accelerator in ("cuda", "gpu"):
        parsed_devices = CUDAAccelerator.parse_devices(devices)
    elif accelerator == "mps":
        parsed_devices = MPSAccelerator.parse_devices(devices)
    elif accelerator == "tpu":
        raise ValueError("Launching processes for TPU through the CLI is not supported.")
    else:
        return CPUAccelerator.parse_devices(devices)
    return len(parsed_devices) if parsed_devices is not None else 0


def _torchrun_launch(args: Namespace, script_args: List[str]) -> None:
    """Invoke `torchrun` programmatically to launch the given script in new processes."""
    if _IS_WINDOWS and _TORCH_GREATER_EQUAL_1_13:
        # TODO: remove once import issue is resolved: https://github.com/pytorch/pytorch/issues/85427
        _log.error(
            "On the Windows platform, this launcher is currently only supported on torch < 1.13 due to a bug"
            " upstream: https://github.com/pytorch/pytorch/issues/85427"
        )
        exit(1)

    if args.strategy == "dp":
        num_processes = 1
    else:
        num_processes = _get_num_processes(args.accelerator, args.devices)

    torchrun_args = []
    torchrun_args.extend(["--nproc_per_node", str(num_processes)])
    torchrun_args.extend(["--nnodes", str(args.num_nodes)])
    torchrun_args.extend(["--node_rank", str(args.node_rank)])
    torchrun_args.extend(["--master_addr", args.main_address])
    torchrun_args.extend(["--master_port", str(args.main_port)])
    torchrun_args.append(args.script)
    torchrun_args.extend(script_args)

    # set a good default number of threads for OMP to avoid warnings being emitted to the user
    os.environ.setdefault("OMP_NUM_THREADS", str(max(1, (os.cpu_count() or 1) // num_processes)))
    torchrun.main(torchrun_args)


def main() -> None:
    args, script_args = _parse_args()
    _set_env_variables(args)
    _torchrun_launch(args, script_args)


if __name__ == "__main__":
    main()
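The handoff from the CLI to the child processes happens entirely through `LT_*` environment variables: `main()` writes them before `torchrun` spawns the script, and the connector reads them back. A minimal, self-contained sketch of that step (the `Namespace` is built by hand here as a stand-in for the real `_parse_args()` result):

```python
import os
from argparse import Namespace

def set_env_variables(args: Namespace) -> None:
    """Mimic of `_set_env_variables` above: persist the CLI choices so the
    connector in each child process can pick them up."""
    os.environ["LT_CLI_USED"] = "1"
    os.environ["LT_ACCELERATOR"] = str(args.accelerator)
    if args.strategy is not None:
        os.environ["LT_STRATEGY"] = str(args.strategy)
    os.environ["LT_DEVICES"] = str(args.devices)
    os.environ["LT_NUM_NODES"] = str(args.num_nodes)
    os.environ["LT_PRECISION"] = str(args.precision)

os.environ.pop("LT_STRATEGY", None)  # start from a clean slate for the demo
args = Namespace(accelerator="cpu", strategy=None, devices="2", num_nodes=1, precision="32")
set_env_variables(args)
print(os.environ["LT_DEVICES"])  # -> 2
```

Note that `LT_STRATEGY` is only exported when a strategy was given, which is what lets the connector distinguish "not set on the CLI" from an explicit choice.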
33 changes: 31 additions & 2 deletions src/lightning_lite/connector.py
@@ -11,9 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from collections import Counter
from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import torch
from typing_extensions import Literal

@@ -101,6 +101,14 @@ def __init__(
        precision: _PRECISION_INPUT = 32,
        plugins: Optional[Union[_PLUGIN_INPUT, List[_PLUGIN_INPUT]]] = None,
    ) -> None:

        # These arguments can be set through environment variables set by the CLI
        accelerator = self._argument_from_env("accelerator", accelerator, default=None)
        strategy = self._argument_from_env("strategy", strategy, default=None)
        devices = self._argument_from_env("devices", devices, default=None)
        num_nodes = self._argument_from_env("num_nodes", num_nodes, default=1)
        precision = self._argument_from_env("precision", precision, default=32)

        # 1. Parsing flags
        # Get registered strategies, built-in accelerators and precision plugins
        self._registered_strategies = STRATEGY_REGISTRY.available_strategies()

@@ -514,6 +522,27 @@ def _lazy_init_strategy(self) -> None:
            f" found {self.strategy.__class__.__name__}."
        )

    @staticmethod
    def _argument_from_env(name: str, current: Any, default: Any) -> Any:
        env_value: Optional[Union[str, int]] = os.environ.get("LT_" + name.upper())

        if env_value is None:
            return current

        if name == "precision":
            # TODO: support precision input as string, then this special handling is not needed
            env_value = int(env_value) if env_value in ("16", "32", "64") else env_value

        if env_value is not None and env_value != current and current != default:
            raise ValueError(
                f"Your code has `LightningLite({name}={current!r}, ...)` but it conflicts with the value"
                f" `--{name}={env_value}` set through the CLI. Remove it either from the CLI or from the"
                " Lightning Lite object."
            )
        if env_value is None:
            return current
        return env_value

    @property
    def is_distributed(self) -> bool:
        # TODO: deprecate this property
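The precedence rule `_argument_from_env` implements is: an environment variable set by the CLI overrides the constructor's default, but an explicit constructor value that disagrees with the CLI is an error. A simplified stand-alone mimic (the precision coercion is omitted, and comparisons are done on strings for brevity):

```python
import os
from typing import Any

def argument_from_env(name: str, current: Any, default: Any) -> Any:
    """Simplified mimic of `_argument_from_env`: the env value beats the default,
    but a conflicting explicit constructor value raises."""
    env_value = os.environ.get("LT_" + name.upper())
    if env_value is None:
        return current  # nothing set by the CLI: keep whatever the constructor got
    if str(env_value) != str(current) and current != default:
        raise ValueError(
            f"`LightningLite({name}={current!r}, ...)` conflicts with `--{name}={env_value}` from the CLI."
        )
    return env_value

os.environ["LT_ACCELERATOR"] = "cuda"
print(argument_from_env("accelerator", None, default=None))  # -> cuda
```

The double check for `env_value is None` in the original is redundant after the early return; the merged code likely keeps it for safety around the precision coercion.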
52 changes: 44 additions & 8 deletions src/lightning_lite/lite.py
@@ -11,8 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
import os
from abc import ABC, abstractmethod
from abc import ABC
from contextlib import contextmanager, nullcontext
from functools import partial
from pathlib import Path

@@ -21,6 +22,7 @@
import torch
import torch.nn as nn
from lightning_utilities.core.apply_func import apply_to_collection
from lightning_utilities.core.overrides import is_overridden
from lightning_utilities.core.rank_zero import rank_zero_warn
from torch import Tensor
from torch.optim import Optimizer

@@ -90,8 +92,11 @@ def __init__(
        self._precision: Precision = self._strategy.precision
        self._models_setup: int = 0

        # wrap the run method so we can inject setup logic or spawn processes for the user
        setattr(self, "run", partial(self._run_impl, self.run))
        self._prepare_run_method()
        if _is_using_cli():
            # when the CLI is used to launch the script, we need to set up the environment (init processes) here so
            # that the user can immediately use all functionality in strategies
            self._strategy.setup_environment()

    @property
    def device(self) -> torch.device:

@@ -126,7 +131,6 @@ def is_global_zero(self) -> bool:
        """Whether this rank is rank zero."""
        return self._strategy.is_global_zero

    @abstractmethod
    def run(self, *args: Any, **kwargs: Any) -> Any:
        """All the code inside this run method gets accelerated by Lite.

@@ -413,6 +417,23 @@ def load(self, filepath: Union[str, Path]) -> Any:
        """
        return self._strategy.load_checkpoint(filepath)

    def launch(self, function: Optional[Callable[["LightningLite"], Any]] = None, *args: Any, **kwargs: Any) -> Any:
        if _is_using_cli():
            raise RuntimeError(
                "This script was launched through the CLI, and processes have already been created. Calling"
                " `.launch()` again is not allowed."
            )
        if function is not None and not inspect.signature(function).parameters:
            raise TypeError(
                "The function passed to `Lite.launch()` needs to take at least one argument. The launcher will pass"
                " in the `LightningLite` object so you can use it inside the function."
            )
        function = partial(self._run_with_setup, function or _do_nothing)
        args = [self, *args]
        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(function, *args, **kwargs)
        return function(*args, **kwargs)

    @staticmethod
    def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None) -> int:
        """Helper function to seed everything without explicitly importing Lightning.

@@ -426,21 +447,19 @@ def seed_everything(seed: Optional[int] = None, workers: Optional[bool] = None)
        return seed_everything(seed=seed, workers=workers)

    def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
        # wrap the real run method with setup logic
        run_method = partial(self._run_with_setup, run_method)

        if self._strategy.launcher is not None:
            return self._strategy.launcher.launch(run_method, *args, **kwargs)
        else:
            return run_method(*args, **kwargs)

    def _run_with_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
    def _run_with_setup(self, run_function: Callable, *args: Any, **kwargs: Any) -> Any:
        self._strategy.setup_environment()
        # apply sharded context to prevent OOM
        with self._strategy.module_sharded_context(), _replace_dunder_methods(
            DataLoader, "dataset"
        ), _replace_dunder_methods(BatchSampler):
            return run_method(*args, **kwargs)
            return run_function(*args, **kwargs)

    def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:
        initial_device = next(model.parameters()).device

@@ -481,6 +500,15 @@ def _get_distributed_sampler(dataloader: DataLoader, **kwargs: Any) -> DistributedSamplerWrapper:
        kwargs.setdefault("seed", int(os.getenv("PL_GLOBAL_SEED", 0)))
        return DistributedSamplerWrapper(dataloader.sampler, **kwargs)

    def _prepare_run_method(self) -> None:
        if is_overridden("run", self, LightningLite) and _is_using_cli():
            raise TypeError(
                "Overriding `LightningLite.run()` and launching from the CLI is not allowed. Run the script normally,"
                " or change your code to directly call `lite = LightningLite(...); lite.setup(...)` etc."
            )
        # wrap the run method, so we can inject setup logic or spawn processes for the user
        setattr(self, "run", partial(self._run_impl, self.run))

    @staticmethod
    def _validate_setup(model: nn.Module, optimizers: Sequence[Optimizer]) -> None:
        if isinstance(model, _LiteModule):

@@ -496,3 +524,11 @@ def _validate_setup_dataloaders(dataloaders: Sequence[DataLoader]) -> None:

        if any(not isinstance(dl, DataLoader) for dl in dataloaders):
            raise TypeError("Only PyTorch DataLoader are currently supported in `setup_dataloaders`.")


def _is_using_cli() -> bool:
    return bool(int(os.environ.get("LT_CLI_USED", "0")))


def _do_nothing(*_: Any) -> None:
    pass
5 changes: 3 additions & 2 deletions src/lightning_lite/utilities/data.py
@@ -327,8 +327,9 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:


def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
    :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently
    :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re-
    instantiation of custom subclasses."""

    @functools.wraps(method)
    def wrapper(obj: Any, *args: Any) -> None:
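The idea behind `_wrap_attr_method` can be shown in isolation: intercept a class's `__setattr__` so attribute mutations are recorded and could later be replayed when re-instantiating a custom subclass. This is an illustrative sketch, not the Lightning implementation (`FakeLoader` and the `_attrs_set` record are invented for the demo):

```python
import functools
from typing import Any, Callable

def wrap_setattr(method: Callable) -> Callable:
    """Sketch of the `_wrap_attr_method` idea: perform the real assignment,
    then record it so the instance's mutations can be replayed later."""
    @functools.wraps(method)
    def wrapper(obj: Any, name: str, value: Any) -> None:
        method(obj, name, value)  # the real attribute assignment
        # record the change via __dict__ to avoid recursing through the patched __setattr__
        obj.__dict__.setdefault("_attrs_set", {})[name] = value
    return wrapper

class FakeLoader:  # stand-in for DataLoader / BatchSampler
    pass

FakeLoader.__setattr__ = wrap_setattr(object.__setattr__)
loader = FakeLoader()
loader.batch_size = 8
print(loader._attrs_set)  # -> {'batch_size': 8}
```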
6 changes: 5 additions & 1 deletion src/pytorch_lightning/CHANGELOG.md
@@ -55,8 +55,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed an issue with the `BaseFinetuning` callback not setting the `track_running_stats` attribute for batch normalization layers ([#15063](https://github.com/Lightning-AI/lightning/pull/15063))


## [1.8.0] - 2022-11-01
### Changed

## [1.8.0] - 2022-11-01

### Added

@@ -119,6 +120,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- The `NeptuneLogger` now uses `neptune.init_run` instead of the deprecated `neptune.init` to initialize a run ([#15393](https://github.com/Lightning-AI/lightning/pull/15393))


- The `LightningLite.run()` method is no longer abstract ([#14992](https://github.com/Lightning-AI/lightning/issues/14992))


### Deprecated

- Deprecated `LightningDeepSpeedModule` ([#14000](https://github.com/Lightning-AI/lightning/pull/14000))