Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[App] Rename to new convention #15621

Merged
merged 3 commits into from Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
@@ -1,9 +1,9 @@
import lightning as L
from lightning.app.components import PyTorchLightningMultiNode
from lightning.app.components import LightningTrainerMultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
class LightningTrainerDistributed(L.LightningWork):
@staticmethod
def run():
model = BoringModel()
Expand All @@ -16,8 +16,8 @@ def run():

# Run over 2 nodes of 4 x V100
app = L.LightningApp(
PyTorchLightningMultiNode(
PyTorchLightningDistributed,
LightningTrainerMultiNode(
LightningTrainerDistributed,
num_nodes=2,
cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -62,7 +62,7 @@ module = [
"lightning_app.components.multi_node.lite",
"lightning_app.components.multi_node.base",
"lightning_app.components.multi_node.pytorch_spawn",
"lightning_app.components.multi_node.pl",
"lightning_app.components.multi_node.trainer",
"lightning_app.api.http_methods",
"lightning_app.api.request_types",
"lightning_app.cli.commands.app_commands",
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_app/components/__init__.py
@@ -1,9 +1,9 @@
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (
LightningTrainerMultiNode,
LiteMultiNode,
MultiNode,
PyTorchLightningMultiNode,
PyTorchSpawnMultiNode,
)
from lightning_app.components.python.popen import PopenPythonScript
Expand All @@ -29,5 +29,5 @@
"LightningTrainingComponent",
"PyTorchLightningScriptRunner",
"PyTorchSpawnMultiNode",
"PyTorchLightningMultiNode",
"LightningTrainerMultiNode",
]
4 changes: 2 additions & 2 deletions src/lightning_app/components/multi_node/__init__.py
@@ -1,6 +1,6 @@
from lightning_app.components.multi_node.base import MultiNode
from lightning_app.components.multi_node.lite import LiteMultiNode
from lightning_app.components.multi_node.pl import PyTorchLightningMultiNode
from lightning_app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode
from lightning_app.components.multi_node.trainer import LightningTrainerMultiNode

__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "PyTorchLightningMultiNode"]
__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"]
Expand Up @@ -13,14 +13,14 @@


@runtime_checkable
class _PyTorchLightningWorkProtocol(Protocol):
class _LightningTrainerWorkProtocol(Protocol):
@staticmethod
def run() -> None:
...


@dataclass
class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor):
class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
@staticmethod
def run(
local_rank: int,
Expand Down Expand Up @@ -71,7 +71,7 @@ def pre_fn(trainer, *args, **kwargs):
tracer._restore()


class PyTorchLightningMultiNode(MultiNode):
class LightningTrainerMultiNode(MultiNode):
def __init__(
self,
work_cls: Type["LightningWork"],
Expand All @@ -80,7 +80,7 @@ def __init__(
*work_args: Any,
**work_kwargs: Any,
) -> None:
assert issubclass(work_cls, _PyTorchLightningWorkProtocol)
assert issubclass(work_cls, _LightningTrainerWorkProtocol)
if not is_static_method(work_cls, "run"):
raise TypeError(
f"The provided {work_cls} run method needs to be static for now."
Expand All @@ -89,7 +89,7 @@ def __init__(

# Note: Private way to modify the work run executor
# Probably exposed to the users in the future if needed.
work_cls._run_executor_cls = _PyTorchLightningRunExecutor
work_cls._run_executor_cls = _LightningTrainerRunExecutor

super().__init__(
work_cls,
Expand Down