diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md
index 305ee591b0257..4c5da2c96e2e4 100644
--- a/src/lightning_app/CHANGELOG.md
+++ b/src/lightning_app/CHANGELOG.md
@@ -40,6 +40,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))
 
+- Fixed running PyTorch inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
+
 
 ## [1.8.2] - 2022-11-17
 
diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py
index 328e70e743b43..6e9b1d8777f67 100644
--- a/src/lightning_app/components/serve/gradio.py
+++ b/src/lightning_app/components/serve/gradio.py
@@ -1,8 +1,10 @@
 import abc
+import os
 from functools import partial
 from types import ModuleType
 from typing import Any, List, Optional
 
+from lightning_app.components.serve.python_server import _PyTorchSpawnRunExecutor, WorkRunExecutor
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.imports import _is_gradio_available, requires
 
@@ -39,6 +41,10 @@ def __init__(self, *args, **kwargs):
         assert self.inputs
         assert self.outputs
         self._model = None
+        # Note: Enables running inference on GPUs locally.
+        self._run_executor_cls = (
+            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
+        )
 
     @property
     def model(self):
diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py
index 731bf1c37e969..9ce1b23701059 100644
--- a/src/lightning_app/components/serve/python_server.py
+++ b/src/lightning_app/components/serve/python_server.py
@@ -1,5 +1,6 @@
 import abc
 import base64
+import os
 from pathlib import Path
 from typing import Any, Dict, Optional
 
@@ -9,12 +10,54 @@
 from pydantic import BaseModel
 from starlette.staticfiles import StaticFiles
 
+from lightning_app.core.queues import MultiProcessQueue
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.proxies import _proxy_setattr, unwrap, WorkRunExecutor, WorkStateObserver
 
 logger = Logger(__name__)
 
 
+class _PyTorchSpawnRunExecutor(WorkRunExecutor):
+
+    """This executor enables moving PyTorch tensors to the GPU.
+
+    Without this executor, the run would raise the following exception:
+        RuntimeError: Cannot re-initialize CUDA in forked subprocess.
+        To use CUDA with multiprocessing, you must use the 'spawn' start method
+    """
+
+    enable_start_observer: bool = False
+
+    def __call__(self, *args: Any, **kwargs: Any):
+        import torch
+
+        with self.enable_spawn():
+            queue = self.delta_queue if isinstance(self.delta_queue, MultiProcessQueue) else self.delta_queue.to_dict()
+            torch.multiprocessing.spawn(
+                self.dispatch_run,
+                args=(self.__class__, self.work, queue, args, kwargs),
+                nprocs=1,
+            )
+
+    @staticmethod
+    def dispatch_run(local_rank, cls, work, delta_queue, args, kwargs):
+        if local_rank == 0:
+            if isinstance(delta_queue, dict):
+                delta_queue = cls.process_queue(delta_queue)
+                work._request_queue = cls.process_queue(work._request_queue)
+                work._response_queue = cls.process_queue(work._response_queue)
+
+            state_observer = WorkStateObserver(work, delta_queue=delta_queue)
+            state_observer.start()
+            _proxy_setattr(work, delta_queue, state_observer)
+
+        unwrap(work.run)(*args, **kwargs)
+
+        if local_rank == 0:
+            state_observer.join(0)
+
+
 class _DefaultInputData(BaseModel):
     payload: str
 
@@ -106,6 +149,11 @@ def predict(self, request):
         self._input_type = input_type
         self._output_type = output_type
 
+        # Note: Enables running inference on GPUs locally.
+        self._run_executor_cls = (
+            WorkRunExecutor if os.getenv("LIGHTNING_CLOUD_APP_ID", None) else _PyTorchSpawnRunExecutor
+        )
+
     def setup(self, *args, **kwargs) -> None:
         """This method is called before the server starts. Override this if you need to download the model or
         initialize the weights, setting up pipelines etc.
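For reviewers, the standalone Python sketch below (not part of this diff) illustrates the failure mode `_PyTorchSpawnRunExecutor` works around: once the parent process has initialized CUDA, a forked child cannot use the GPU, whereas a child started through `torch.multiprocessing.spawn` with `nprocs=1`, as the new executor does, can. The `_run_inference` helper name is illustrative only; it assumes a CUDA-capable machine with `torch` installed.

# Illustrative sketch only (not part of the PR): shows why the new executor
# dispatches the work's run method through torch.multiprocessing.spawn.
import torch
import torch.multiprocessing as mp


def _run_inference(local_rank: int) -> None:
    # Move a tensor to the GPU inside the child process. Under the default
    # "fork" start method this is what raises
    # "RuntimeError: Cannot re-initialize CUDA in forked subprocess"
    # once the parent has already touched CUDA; "spawn" avoids it.
    x = torch.randn(4, 4).to("cuda")
    print(f"rank {local_rank}: tensor on {x.device}")


if __name__ == "__main__":
    # The parent initializes CUDA first, mirroring a LightningWork process
    # that has already used torch before dispatching its run method.
    torch.cuda.init()

    # spawn-based dispatch with a single process, as _PyTorchSpawnRunExecutor does.
    mp.spawn(_run_inference, nprocs=1)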