From 23e0d1b2460470a2e98bee07b95f97dd0f0e0fcc Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 25 Nov 2022 13:43:16 +0000
Subject: [PATCH 1/3] [App] Enable Python Server and Gradio Serve to run on
 accelerated device such as GPU CUDA / MPS (#15813)

---
 src/lightning_app/CHANGELOG.md                |  2 +
 src/lightning_app/components/serve/gradio.py  |  6 +++
 .../components/serve/python_server.py         | 48 +++++++++++++++++++
 3 files changed, 56 insertions(+)

diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index 23419a48c92b3..dc1a90a40cee4 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -20,6 +20,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747))
 - Fixed setting property to the `LightningFlow` ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
 
+- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
+
 
 ## [1.8.2] - 2022-11-17
 
diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py
index 328e70e743b43..6e9b1d8777f67 100644
--- a/src/lightning_app/components/serve/gradio.py
+++ b/src/lightning_app/components/serve/gradio.py
@@ -1,8 +1,10 @@
 import abc
+import os
 from functools import partial
 from types import ModuleType
 from typing import Any, List, Optional
 
+from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.imports import _is_gradio_available, requires
 
@@ -39,6 +41,10 @@ def __init__(self, *args, **kwargs):
         assert self.inputs
         assert self.outputs
         self._model = None
+        # Note: Enables running inference on GPUs.
+        self._run_executor_cls = (
+            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
+        )
 
     @property
     def model(self):
diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py
index 731bf1c37e969..9ce1b23701059 100644
--- a/src/lightning_app/components/serve/python_server.py
+++ b/src/lightning_app/components/serve/python_server.py
@@ -1,5 +1,6 @@
 import abc
 import base64
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional
 
@@ -9,12 +10,54 @@
 from pydantic import BaseModel
 from starlette.staticfiles import StaticFiles
 
+from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
 
 logger = Logger(__name__)
 
 
+class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+
+    """This executor enables moving PyTorch tensors to the GPU.
+
+    Without this executor, it would raise the following exception:
+    RuntimeError: Cannot re-initialize CUDA in forked subprocess.
+    To use CUDA with multiprocessing, you must use the 'spawn' start method
+    """
+
+    enable_start_observer: bool = False
+
+    def __call__(self, *args: Any, **kwargs: Any):
+        import torch
+
+        with self.enable_spawn():
+            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
+            torch.multiprocessing.spawn(
+                self.dispatch_run,
+                args=(self.__class__, self.work, queue, args, kwargs),
+                nprocs=1,
+            )
+
+    @staticmethod
+    def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
+        if local_rank == 0:
+            if isinstance(delta_queue, dict):
+                delta_queue = cls.process_queue(delta_queue)
+                work._request_queue = cls.process_queue(work._request_queue)
+                work._response_queue = cls.process_queue(work._response_queue)
+
+            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
+            state_observer.start()
+            _proxy_setattr(work, delta_queue, state_observer)
+
+        unwrap(work.run)(*args, **kwargs)
+
+        if local_rank == 0:
+            state_observer.join(0)
+
+
 class _DefaultInputData(BaseModel):
     payload: str
 
@@ -106,6 +149,11 @@ def predict(self, request):
         self._input_type = input_type
         self._output_type = output_type
 
+        # Note: Enables running inference on GPUs.
+        self._run_executor_cls = (
+            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
+        )
+
     def setup(self, *args, **kwargs) -> None:
         """This method is called before the server starts. Override this if you need to download the model or
         initialize the weights, setting up pipelines etc.

From f29fdd5650b39b6623aebff612138b8d688cb6c6 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 25 Nov 2022 16:37:48 +0000
Subject: [PATCH 2/3] bump version to 1.8.3.post1

---
 src/lightning/__version__.py         | 2 +-
 src/lightning_app/__version__.py     | 2 +-
 src/lightning_lite/__version__.py    | 2 +-
 src/pytorch_lightning/__version__.py | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/lightning/__version__.py b/src/lightning/__version__.py
index 293a0ef04c47a..882d8f8541cd4 100644
--- a/src/lightning/__version__.py
+++ b/src/lightning/__version__.py
@@ -1 +1 @@
-version = "1.8.3.post0"
+version = "1.8.3.post1"
diff --git a/src/lightning_app/__version__.py b/src/lightning_app/__version__.py
index 293a0ef04c47a..882d8f8541cd4 100644
--- a/src/lightning_app/__version__.py
+++ b/src/lightning_app/__version__.py
@@ -1 +1 @@
-version = "1.8.3.post0"
+version = "1.8.3.post1"
diff --git a/src/lightning_lite/__version__.py b/src/lightning_lite/__version__.py
index 293a0ef04c47a..882d8f8541cd4 100644
--- a/src/lightning_lite/__version__.py
+++ b/src/lightning_lite/__version__.py
@@ -1 +1 @@
-version = "1.8.3.post0"
+version = "1.8.3.post1"
diff --git a/src/pytorch_lightning/__version__.py b/src/pytorch_lightning/__version__.py
index 293a0ef04c47a..882d8f8541cd4 100644
--- a/src/pytorch_lightning/__version__.py
+++ b/src/pytorch_lightning/__version__.py
@@ -1 +1 @@
-version = "1.8.3.post0"
+version = "1.8.3.post1"

From 7eb8ccf4b62dac644c297564ea3efe8fe1e7f288 Mon Sep 17 00:00:00 2001
From: thomas chaton
Date: Fri, 25 Nov 2022 16:42:53 +0000
Subject: [PATCH 3/3] update changelog

---
 src/lightning_app/CHANGELOG.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index dc1a90a40cee4..c23e163fd1606 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -19,7 +19,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747))
 - Fixed setting property to the `LightningFlow` ([#15750](https://github.com/Lightning-AI/lightning/pull/15750))
 
-- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
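
Below is a minimal sketch of how the GPU-enabled `PythonServer` from PATCH 1/3 might be used in a local app. Only the `input_type`/`output_type` constructor arguments, the `setup`/`predict` hooks and the `LIGHTNING_CLOUD_APP_ID` switch come from the patch above; the `TextInput`/`TextOutput` models, the `GPUServer` and `Root` classes, the toy `torch.nn.Linear` model and the exact import paths are illustrative assumptions for a lightning_app 1.8.x installation.

# serve_gpu_app.py -- illustrative sketch, not part of the patch.
import torch
from pydantic import BaseModel

from lightning_app import LightningApp, LightningFlow
from lightning_app.components.serve.python_server import PythonServer


class TextInput(BaseModel):
    text: str


class TextOutput(BaseModel):
    prediction: str


class GPUServer(PythonServer):
    def __init__(self, **kwargs):
        # input_type / output_type are plain pydantic models, as in the patched module.
        super().__init__(input_type=TextInput, output_type=TextOutput, **kwargs)

    def setup(self, *args, **kwargs) -> None:
        # Called once before the server starts. Because _PyTorchSpawnRunExecutor starts
        # the work with torch.multiprocessing.spawn when running locally, CUDA/MPS can be
        # initialized without "Cannot re-initialize CUDA in forked subprocess".
        if torch.cuda.is_available():
            self._device = torch.device("cuda")
        elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
            self._device = torch.device("mps")
        else:
            self._device = torch.device("cpu")
        self._model = torch.nn.Linear(8, 1).to(self._device)

    def predict(self, request: TextInput) -> TextOutput:
        # Toy inference on the accelerated device: encode the request length into a tensor.
        x = torch.full((1, 8), float(len(request.text)), device=self._device)
        y = self._model(x)
        return TextOutput(prediction=f"{y.item():.4f} (computed on {self._device})")


class Root(LightningFlow):
    def __init__(self):
        super().__init__()
        self.server = GPUServer()

    def run(self):
        self.server.run()


app = LightningApp(Root())

Run locally with `lightning run app serve_gpu_app.py`: since `LIGHTNING_CLOUD_APP_ID` is unset there, the work is started through `_PyTorchSpawnRunExecutor` and the model can live on CUDA or MPS, while in the cloud the default `WorkRunExecutor` is kept.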