Skip to content

Commit

Permalink
[App] Rename to new convention (#15621)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
tchaton committed Nov 10, 2022
1 parent 6ba00af commit 7ec15ae
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 14 deletions.
@@ -1,9 +1,9 @@
import lightning as L
from lightning.app.components import PyTorchLightningMultiNode
from lightning.app.components import LightningTrainerMultiNode
from lightning.pytorch.demos.boring_classes import BoringModel


class PyTorchLightningDistributed(L.LightningWork):
class LightningTrainerDistributed(L.LightningWork):
@staticmethod
def run():
model = BoringModel()
Expand All @@ -16,8 +16,8 @@ def run():

# Run over 2 nodes of 4 x V100
app = L.LightningApp(
PyTorchLightningMultiNode(
PyTorchLightningDistributed,
LightningTrainerMultiNode(
LightningTrainerDistributed,
num_nodes=2,
cloud_compute=L.CloudCompute("gpu-fast-multi"), # 4 x V100
)
Expand Down
File renamed without changes.
2 changes: 1 addition & 1 deletion pyproject.toml
Expand Up @@ -62,7 +62,7 @@ module = [
"lightning_app.components.multi_node.lite",
"lightning_app.components.multi_node.base",
"lightning_app.components.multi_node.pytorch_spawn",
"lightning_app.components.multi_node.pl",
"lightning_app.components.multi_node.trainer",
"lightning_app.api.http_methods",
"lightning_app.api.request_types",
"lightning_app.cli.commands.app_commands",
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_app/components/__init__.py
@@ -1,9 +1,9 @@
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (
LightningTrainerMultiNode,
LiteMultiNode,
MultiNode,
PyTorchLightningMultiNode,
PyTorchSpawnMultiNode,
)
from lightning_app.components.python.popen import PopenPythonScript
Expand All @@ -29,5 +29,5 @@
"LightningTrainingComponent",
"PyTorchLightningScriptRunner",
"PyTorchSpawnMultiNode",
"PyTorchLightningMultiNode",
"LightningTrainerMultiNode",
]
4 changes: 2 additions & 2 deletions src/lightning_app/components/multi_node/__init__.py
@@ -1,6 +1,6 @@
from lightning_app.components.multi_node.base import MultiNode
from lightning_app.components.multi_node.lite import LiteMultiNode
from lightning_app.components.multi_node.pl import PyTorchLightningMultiNode
from lightning_app.components.multi_node.pytorch_spawn import PyTorchSpawnMultiNode
from lightning_app.components.multi_node.trainer import LightningTrainerMultiNode

__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "PyTorchLightningMultiNode"]
__all__ = ["LiteMultiNode", "MultiNode", "PyTorchSpawnMultiNode", "LightningTrainerMultiNode"]
Expand Up @@ -13,14 +13,14 @@


@runtime_checkable
class _PyTorchLightningWorkProtocol(Protocol):
class _LightningTrainerWorkProtocol(Protocol):
@staticmethod
def run() -> None:
...


@dataclass
class _PyTorchLightningRunExecutor(_PyTorchSpawnRunExecutor):
class _LightningTrainerRunExecutor(_PyTorchSpawnRunExecutor):
@staticmethod
def run(
local_rank: int,
Expand Down Expand Up @@ -71,7 +71,7 @@ def pre_fn(trainer, *args, **kwargs):
tracer._restore()


class PyTorchLightningMultiNode(MultiNode):
class LightningTrainerMultiNode(MultiNode):
def __init__(
self,
work_cls: Type["LightningWork"],
Expand All @@ -80,7 +80,7 @@ def __init__(
*work_args: Any,
**work_kwargs: Any,
) -> None:
assert issubclass(work_cls, _PyTorchLightningWorkProtocol)
assert issubclass(work_cls, _LightningTrainerWorkProtocol)
if not is_static_method(work_cls, "run"):
raise TypeError(
f"The provided {work_cls} run method needs to be static for now."
Expand All @@ -89,7 +89,7 @@ def __init__(

# Note: Private way to modify the work run executor
# Probably exposed to the users in the future if needed.
work_cls._run_executor_cls = _PyTorchLightningRunExecutor
work_cls._run_executor_cls = _LightningTrainerRunExecutor

super().__init__(
work_cls,
Expand Down

0 comments on commit 7ec15ae

Please sign in to comment.