
[App] Resolve PythonServer on M1 #15949

Merged
merged 32 commits into master from resolve_python_server on Dec 8, 2022
Commits (32 total; changes shown from 21 commits)
229677a  update (Dec 7, 2022)
c6191c2  update (Dec 7, 2022)
6805e0a  update (Dec 7, 2022)
3abe245  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 7, 2022)
b05927d  update (Dec 7, 2022)
a625c82  Merge branch 'resolve_python_server' of https://github.com/Lightning-… (Dec 7, 2022)
95b7068  update (Dec 7, 2022)
a82d8ca  update (Dec 7, 2022)
e1e41da  update (Dec 7, 2022)
04dec73  Merge branch 'master' into resolve_python_server (tchaton, Dec 7, 2022)
5dd92c0  update (Dec 7, 2022)
e303da2  Merge branch 'resolve_python_server' of https://github.com/Lightning-… (Dec 7, 2022)
6df486f  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 7, 2022)
1e9d1e2  update (Dec 7, 2022)
89379a8  Merge branch 'resolve_python_server' of https://github.com/Lightning-… (Dec 7, 2022)
cdd8226  update (Dec 7, 2022)
07e9b35  update (Dec 7, 2022)
8c48628  update (Dec 7, 2022)
02aff87  update (Dec 7, 2022)
9073c1a  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 7, 2022)
d74b178  Merge branch 'master' into resolve_python_server (tchaton, Dec 8, 2022)
ac48dd7  update (Dec 8, 2022)
9981c43  update (Dec 8, 2022)
bee3c70  update (Dec 8, 2022)
a362f33  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 8, 2022)
33b5e76  update (Dec 8, 2022)
6390a36  Merge branch 'resolve_python_server' of https://github.com/Lightning-… (Dec 8, 2022)
dbf5a63  update (Dec 8, 2022)
7843615  Merge branch 'master' into resolve_python_server (tchaton, Dec 8, 2022)
6accc44  Merge branch 'master' into resolve_python_server (tchaton, Dec 8, 2022)
f45c61a  update (Dec 8, 2022)
9133f93  Merge branch 'resolve_python_server' of https://github.com/Lightning-… (Dec 8, 2022)
2 changes: 1 addition & 1 deletion requirements/app/ui.txt
@@ -1,2 +1,2 @@
streamlit>=1.3.1, <=1.11.1
streamlit
panel>=0.12.7, <=0.13.1
4 changes: 2 additions & 2 deletions src/lightning_app/CHANGELOG.md
@@ -50,14 +50,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))


- Fixed an issue where a work passed directly to the LightningApp was not stopped after finishing successfully ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))


- Fixed the `enable_spawn` method of the `WorkRunExecutor` ([#15812](https://github.com/Lightning-AI/lightning/pull/15812))

- Fixed Sigterm Handler causing thread lock which caused KeyboardInterrupt to hang ([#15881](https://github.com/Lightning-AI/lightning/pull/15881))

- Fixed PythonServer generating noise on M1 ([#15949](https://github.com/Lightning-AI/lightning/pull/15949))

- Fixed a bug where using `L.app.structures` would cause multiple apps to be opened and fail with an error in the cloud ([#15911](https://github.com/Lightning-AI/lightning/pull/15911))


4 changes: 1 addition & 3 deletions src/lightning_app/components/auto_scaler.py
@@ -269,9 +269,7 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint))

@fastapi_app.put("/system/update-servers")
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
self.servers = servers

self.servers = servers
self._iter = cycle(self.servers)

@fastapi_app.post(self.endpoint, response_model=self._output_type)
8 changes: 2 additions & 6 deletions src/lightning_app/components/serve/gradio.py
@@ -1,10 +1,8 @@
import abc
import os
from functools import partial
from types import ModuleType
from typing import Any, List, Optional

from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
from lightning_app.core.work import LightningWork
from lightning_app.utilities.imports import _is_gradio_available, requires

@@ -36,15 +34,13 @@ class ServeGradio(LightningWork, abc.ABC):
title: Optional[str] = None
description: Optional[str] = None

_start_method = "spawn"

def __init__(self, *args, **kwargs):
requires("gradio")(super().__init__(*args, **kwargs))
assert self.inputs
assert self.outputs
self._model = None
# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)

@property
def model(self):
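For context on what this change means for users of `ServeGradio`: the work no longer swaps in `_PyTorchSpawnRunExecutor` at runtime but declares `_start_method = "spawn"`, so the work process is spawned rather than forked and CUDA/MPS can be initialised inside it (the removed docstring below describes the fork limitation this worked around). A minimal, hypothetical subclass is sketched here; `DemoGradio` and its method bodies are illustrative only and assume `gradio` is installed:

from lightning_app.components.serve.gradio import ServeGradio


class DemoGradio(ServeGradio):
    # Gradio component shorthands; any valid gradio inputs/outputs work here.
    inputs = "text"
    outputs = "text"

    def build_model(self):
        # Placeholder "model"; a real subclass would load weights here.
        return lambda text: text.upper()

    def predict(self, text: str) -> str:
        return self.model(text)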
63 changes: 19 additions & 44 deletions src/lightning_app/components/serve/python_server.py
@@ -1,20 +1,19 @@
import abc
import base64
import os
import platform
from pathlib import Path
from typing import Any, Dict, Optional

import uvicorn
from fastapi import FastAPI
from lightning_utilities.core.imports import module_available
from lightning_utilities.core.imports import compare_version, module_available
from pydantic import BaseModel
from starlette.staticfiles import StaticFiles

from lightning_app.core.queues import MultiProcessQueue
from lightning_app.core.work import LightningWork
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires
from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver

logger = Logger(__name__)

@@ -28,44 +27,19 @@
__doctest_skip__ += ["PythonServer", "PythonServer.*"]


class _PyTorchSpawnRunExecutor(WorkRunExecutor):
def _get_device():
import operator

"""This Executor enables to move PyTorch tensors on GPU.
import torch

Without this executor, it would raise the following exception:
RuntimeError: Cannot re-initialize CUDA in forked subprocess.
To use CUDA with multiprocessing, you must use the 'spawn' start method
"""
_TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")

enable_start_observer: bool = False
local_rank = int(os.getenv("LOCAL_RANK", "0"))

def __call__(self, *args: Any, **kwargs: Any):
import torch

with self.enable_spawn():
queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
torch.multiprocessing.spawn(
self.dispatch_run,
args=(self.__class__, self.work, queue, args, kwargs),
nprocs=1,
)

@staticmethod
def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
if local_rank == 0:
if isinstance(delta_queue, dict):
delta_queue = cls.process_queue(delta_queue)
work._request_queue = cls.process_queue(work._request_queue)
work._response_queue = cls.process_queue(work._response_queue)

state_observer = WorkStateObserver(work, delta_queue=delta_queue)
state_observer.start()
_proxy_setattr(work, delta_queue, state_observer)

unwrap(work.run)(*args, **kwargs)

if local_rank == 0:
state_observer.join(0)
if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
return torch.device("mps", local_rank)
else:
return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
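Because the removed executor and the new helper are interleaved in the hunk above, the added `_get_device` helper is easier to read in one piece. A consolidated sketch assembled from the added lines (MPS is only chosen on arm machines with torch >= 1.12; otherwise the previous CUDA/CPU selection is kept):

def _get_device():
    import operator

    import torch

    _TORCH_GREATER_EQUAL_1_12 = compare_version("torch", operator.ge, "1.12.0")

    local_rank = int(os.getenv("LOCAL_RANK", "0"))

    # Prefer the Apple MPS backend on arm Macs with a new enough torch,
    # otherwise fall back to CUDA when available, and finally to CPU.
    if _TORCH_GREATER_EQUAL_1_12 and torch.backends.mps.is_available() and platform.processor() in ("arm", "arm64"):
        return torch.device("mps", local_rank)
    return torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")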


class _DefaultInputData(BaseModel):
@@ -96,6 +70,9 @@ def _get_sample_data() -> Dict[Any, Any]:


class PythonServer(LightningWork, abc.ABC):

_start_method = "spawn"

@requires(["torch", "lightning_api_access"])
def __init__( # type: ignore
self,
@@ -161,11 +138,6 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type

# Note: Enable to run inference on GPUs.
self._run_executor_cls = (
WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
)

def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
@@ -211,13 +183,16 @@ def _get_sample_dict_from_datatype(datatype: Any) -> dict:
return out

def _attach_predict_fn(self, fastapi_app: FastAPI) -> None:
from torch import inference_mode
from torch import inference_mode, no_grad

input_type: type = self.configure_input_type()
output_type: type = self.configure_output_type()

device = _get_device()
context = no_grad if device.type == "mps" else inference_mode

def predict_fn(request: input_type): # type: ignore
with inference_mode():
with context():
return self.predict(request)

fastapi_app.post("/predict", response_model=output_type)(predict_fn)
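For reference, a minimal `PythonServer` subclass exercising this path; `EchoServer` and `Text` below are illustrative names, not part of this PR. With this change, inference on the mps device runs under `torch.no_grad()` instead of `torch.inference_mode()` (which, per the changelog entry above, produced noise on M1), while CUDA/CPU deployments keep using `inference_mode()`:

from pydantic import BaseModel

from lightning_app.components.serve.python_server import PythonServer


class Text(BaseModel):
    text: str


class EchoServer(PythonServer):
    def __init__(self):
        super().__init__(input_type=Text, output_type=Text)

    def setup(self, *args, **kwargs) -> None:
        # A real server would load model weights here.
        self._model = lambda s: s[::-1]

    def predict(self, request: Text) -> Text:
        return Text(text=self._model(request.text))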
4 changes: 4 additions & 0 deletions src/lightning_app/frontend/streamlit_base.py
@@ -2,6 +2,7 @@

From here, we will call the render function that the user provided in ``configure_layout``.
"""
import asyncio
import os
import pydoc
from typing import Callable
@@ -20,6 +21,9 @@ def _get_render_fn_from_environment() -> Callable:

def _main():
"""Run the render_fn with the current flow_state."""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

app_state = AppState(plugin=StreamLitStatePlugin())

# Fetch the information of which flow attaches to this streamlit instance
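The two added lines give the Streamlit script thread its own asyncio event loop before `AppState` is created; Streamlit runs this script in a worker thread where no default loop exists, so downstream code that expects `asyncio.get_event_loop()` to succeed would otherwise fail (this rationale is inferred, the PR does not state it). The resulting prologue of `_main` reads roughly as follows:

import asyncio


def _main():
    """Run the render_fn with the current flow_state."""
    # New in this PR: create and register an event loop for this thread.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    app_state = AppState(plugin=StreamLitStatePlugin())
    ...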