diff --git a/.gitignore b/.gitignore index 054a5ba16aff5..835a98a19efd9 100644 --- a/.gitignore +++ b/.gitignore @@ -109,8 +109,8 @@ celerybeat-schedule # dotenv .env -.env_staging -.env_local +.env.staging +.env.local # virtualenv .venv diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index 17419c06871f5..2fcc31422f450 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047)) + + ### Changed diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index ca09a9a83eecc..982e02d959f08 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -2,12 +2,14 @@ import inspect import time from copy import deepcopy +from dataclasses import dataclass from functools import wraps from multiprocessing import Queue from typing import Any, Callable, Dict, List, Optional from uuid import uuid4 -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request +from lightning_utilities.core.apply_func import apply_to_collection from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse from lightning_app.utilities.app_helpers import Logger @@ -19,6 +21,77 @@ def _signature_proxy_function(): pass +@dataclass +class _FastApiMockRequest: + """This class is meant to mock FastAPI Request class that isn't pickle-able. + + If a user relies on FastAPI Request annotation, the Lightning framework + patches the annotation before pickling and replace them right after. + + Finally, the FastAPI request is converted back to the _FastApiMockRequest + before being delivered to the users. + + Example: + + import lightning as L + from fastapi import Request + from lightning.app.api import Post + + class Flow(L.LightningFlow): + + def request(self, request: Request) -> OutputRequestModel: + ... + + def configure_api(self): + return [Post("/api/v1/request", self.request)] + """ + + _body: Optional[str] = None + _json: Optional[str] = None + _method: Optional[str] = None + _headers: Optional[Dict] = None + + @property + def receive(self): + raise NotImplementedError + + @property + def method(self): + raise self._method + + @property + def headers(self): + return self._headers + + def body(self): + return self._body + + def json(self): + return self._json + + def stream(self): + raise NotImplementedError + + def form(self): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + def is_disconnected(self): + raise NotImplementedError + + +async def _mock_fastapi_request(request: Request): + # TODO: Add more requests parameters. + return _FastApiMockRequest( + _body=await request.body(), + _json=await request.json(), + _headers=request.headers, + _method=request.method, + ) + + class _HttpMethod: def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs): """This class is used to inject user defined methods within the App Rest API. @@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No self.method_annotations = method.__annotations__ # TODO: Validate the signature contains only pydantic models. self.method_signature = inspect.signature(method) + if not self.attached_to_flow: self.component_name = method.__name__ self.method = method @@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No self.timeout = timeout self.kwargs = kwargs + # Enable the users to rely on FastAPI annotation typing with Request. + # Note: Only a part of the Request functionatilities are supported. + self._patch_fast_api_request() + def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None: # 1: Get the route associated with the http method. route = getattr(app, self.__class__.__name__.lower()) + self._unpatch_fast_api_request() + # 2: Create a proxy function with the signature of the wrapped method. fn = deepcopy(_signature_proxy_function) fn.__annotations__ = self.method_annotations @@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs): @wraps(_signature_proxy_function) async def _handle_request(*args, **kwargs): async def fn(*args, **kwargs): + args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request) + for k, v in kwargs.items(): + if hasattr(v, "__await__"): + kwargs[k] = await v + request_id = str(uuid4()).split("-")[0] logger.debug(f"Processing request {request_id} for route: {self.route}") request_queue.put( @@ -101,6 +186,26 @@ async def fn(*args, **kwargs): # 4: Register the user provided route to the Rest API. route(self.route, **self.kwargs)(_handle_request) + def _patch_fast_api_request(self): + """This function replaces signature annotation for Request with its mock.""" + for k, v in self.method_annotations.items(): + if v == Request: + self.method_annotations[k] = _FastApiMockRequest + + for v in self.method_signature.parameters.values(): + if v._annotation == Request: + v._annotation = _FastApiMockRequest + + def _unpatch_fast_api_request(self): + """This function replaces back signature annotation to fastapi Request.""" + for k, v in self.method_annotations.items(): + if v == _FastApiMockRequest: + self.method_annotations[k] = Request + + for v in self.method_signature.parameters.values(): + if v._annotation == _FastApiMockRequest: + v._annotation = Request + class Post(_HttpMethod): pass diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 04b89c927941a..f3eb8f9bacda9 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -12,7 +12,7 @@ import pytest import requests from deepdiff import DeepDiff, Delta -from fastapi import HTTPException +from fastapi import HTTPException, Request from httpx import AsyncClient from pydantic import BaseModel @@ -479,10 +479,13 @@ def run(self): if self.counter == 501: self._exit() - def request(self, config: InputRequestModel) -> OutputRequestModel: + def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel: self.counter += 1 if config.index % 5 == 0: raise HTTPException(status_code=400, detail="HERE") + assert request.body() + assert request.json() + assert request.headers return OutputRequestModel(name=config.name, counter=self.counter) def configure_api(self):