diff --git a/examples/app_multi_node/train_pl.py b/examples/app_multi_node/train_lt.py similarity index 69% rename from examples/app_multi_node/train_pl.py rename to examples/app_multi_node/train_lt.py index e887eaef7c075..5cbee32dd8132 100644 --- a/examples/app_multi_node/train_pl.py +++ b/examples/app_multi_node/train_lt.py @@ -1,9 +1,9 @@ import lightning as L -from lightning.app.components import PyTorchLightningMultiNode +from lightning.app.components import LightningTrainerMultiNode from lightning.pytorch.demos.boring_classes import BoringModel -class PyTorchLightningDistributed(L.LightningWork): +class LightningTrainerDistributed(L.LightningWork): @staticmethod def run(): model = BoringModel() @@ -16,8 +16,8 @@ def run(): # Run over 2 nodes of 4 x V100 app = L.LightningApp( - PyTorchLightningMultiNode( - PyTorchLightningDistributed, + LightningTrainerMultiNode( + LightningTrainerDistributed, num_nodes=2, cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100 ) diff --git a/examples/app_multi_node/train_pl_script.py b/examples/app_multi_node/train_lt_script.py similarity index 100% rename from examples/app_multi_node/train_pl_script.py rename to examples/app_multi_node/train_lt_script.py diff --git a/pyproject.toml b/pyproject.toml index 005eba2846a31..bc8d9c7658dcd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -62,7 +62,7 @@ module = [ "lightning_app.components.multi_node.lite", "lightning_app.components.multi_node.base", "lightning_app.components.multi_node.pytorch_spawn", - "lightning_app.components.multi_node.pl", + "lightning_app.components.multi_node.trainer", "lightning_app.api.http_methods", "lightning_app.api.request_types", "lightning_app.cli.commands.app_commands", diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index e72d1f443b33c..2426a9042b516 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -1,9 +1,9 @@ from lightning_app.components.database.client import DatabaseClient from lightning_app.components.database.server import Database from lightning_app.components.multi_node import ( + LightningTrainerMultiNode, LiteMultiNode, MultiNode, - PyTorchLightningMultiNode, PyTorchSpawnMultiNode, ) from lightning_app.components.python.popen import PopenPythonScript @@ -29,5 +29,5 @@ "LightningTrainingComponent", "PyTorchLightningScriptRunner", "PyTorchSpawnMultiNode", - "PyTorchLightningMultiNode", + "LightningTrainerMultiNode", ] diff --git a/src/lightning_app/components/multi_node/__init__.py b/src/lightning_app/components/multi_node/__init__.py index 2921f79dc75b5..b2d45a2610a58 100644 --- a/src/lightning_app/components/multi_node/__init__.py +++ b/src/lightning_app/components/multi_node/__init__.py @@ -1,6 +1,6 @@ from lightning_app.components.multi_node.base import MultiNode from lightning_app.components.multi_node.lite import LiteMultiNode -from lightning_app.components.multi_node.pl import PyTorchLightningMultiNode from lightning_app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode +from lightning_app.components.multi_node.trainer import LightningTrainerMultiNode -__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "PyTorchLightningMultiNode"] +__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"] diff --git a/src/lightning_app/components/multi_node/pl.py b/src/lightning_app/components/multi_node/trainer.py similarity index 92% rename from src/lightning_app/components/multi_node/pl.py rename to src/lightning_app/components/multi_node/trainer.py index c11b72b6ce68d..ea33106a7ece9 100644 --- a/src/lightning_app/components/multi_node/pl.py +++ b/src/lightning_app/components/multi_node/trainer.py @@ -13,14 +13,14 @@ @runtime_checkable -class _PyTorchLightningWorkProtocol(Protocol): +class _LightningTrainerWorkProtocol(Protocol): @staticmethod def run() -> None: ... @dataclass -class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor): +class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor): @staticmethod def run( local_rank: int, @@ -71,7 +71,7 @@ def pre_fn(trainer, *args, **kwargs): tracer._restore() -class PyTorchLightningMultiNode(MultiNode): +class LightningTrainerMultiNode(MultiNode): def __init__( self, work_cls: Type["LightningWork"], @@ -80,7 +80,7 @@ def __init__( *work_args: Any, **work_kwargs: Any, ) -> None: - assert issubclass(work_cls, _PyTorchLightningWorkProtocol) + assert issubclass(work_cls, _LightningTrainerWorkProtocol) if not is_static_method(work_cls, "run"): raise TypeError( f"The provided {work_cls} run method needs to be static for now." @@ -89,7 +89,7 @@ def __init__( # Note: Private way to modify the work run executor # Probably exposed to the users in the future if needed. - work_cls._run_executor_cls = _PyTorchLightningRunExecutor + work_cls._run_executor_cls = _LightningTrainerRunExecutor super().__init__( work_cls,