Skip to content

Commit

Permalink
[App] PoC: Add support for Request (#16047)
Browse files Browse the repository at this point in the history
(cherry picked from commit 592b126)
  • Loading branch information
tchaton authored and Borda committed Dec 20, 2022
1 parent 3652718 commit 17e15af
Show file tree
Hide file tree
Showing 4 changed files with 116 additions and 5 deletions.
4 changes: 2 additions & 2 deletions .gitignore
Expand Up @@ -109,8 +109,8 @@ celerybeat-schedule

# dotenv
.env
.env_staging
.env_local
.env.staging
.env.local

# virtualenv
.venv
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))


### Changed


Expand Down
107 changes: 106 additions & 1 deletion src/lightning_app/api/http_methods.py
Expand Up @@ -2,12 +2,14 @@
import inspect
import time
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
from multiprocessing import Queue
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from lightning_utilities.core.apply_func import apply_to_collection

from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
from lightning_app.utilities.app_helpers import Logger
Expand All @@ -19,6 +21,77 @@ def _signature_proxy_function():
pass


@dataclass
class _FastApiMockRequest:
"""This class is meant to mock FastAPI Request class that isn't pickle-able.
If a user relies on FastAPI Request annotation, the Lightning framework
patches the annotation before pickling and replace them right after.
Finally, the FastAPI request is converted back to the _FastApiMockRequest
before being delivered to the users.
Example:
import lightning as L
from fastapi import Request
from lightning.app.api import Post
class Flow(L.LightningFlow):
def request(self, request: Request) -> OutputRequestModel:
...
def configure_api(self):
return [Post("/api/v1/request", self.request)]
"""

_body: Optional[str] = None
_json: Optional[str] = None
_method: Optional[str] = None
_headers: Optional[Dict] = None

@property
def receive(self):
raise NotImplementedError

@property
def method(self):
raise self._method

@property
def headers(self):
return self._headers

def body(self):
return self._body

def json(self):
return self._json

def stream(self):
raise NotImplementedError

def form(self):
raise NotImplementedError

def close(self):
raise NotImplementedError

def is_disconnected(self):
raise NotImplementedError


async def _mock_fastapi_request(request: Request):
# TODO: Add more requests parameters.
return _FastApiMockRequest(
_body=await request.body(),
_json=await request.json(),
_headers=request.headers,
_method=request.method,
)


class _HttpMethod:
def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
"""This class is used to inject user defined methods within the App Rest API.
Expand All @@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
self.method_annotations = method.__annotations__
# TODO: Validate the signature contains only pydantic models.
self.method_signature = inspect.signature(method)

if not self.attached_to_flow:
self.component_name = method.__name__
self.method = method
Expand All @@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
self.timeout = timeout
self.kwargs = kwargs

# Enable the users to rely on FastAPI annotation typing with Request.
# Note: Only a part of the Request functionatilities are supported.
self._patch_fast_api_request()

def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
# 1: Get the route associated with the http method.
route = getattr(app, self.__class__.__name__.lower())

self._unpatch_fast_api_request()

# 2: Create a proxy function with the signature of the wrapped method.
fn = deepcopy(_signature_proxy_function)
fn.__annotations__ = self.method_annotations
Expand All @@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs):
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
async def fn(*args, **kwargs):
args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
for k, v in kwargs.items():
if hasattr(v, "__await__"):
kwargs[k] = await v

request_id = str(uuid4()).split("-")[0]
logger.debug(f"Processing request {request_id} for route: {self.route}")
request_queue.put(
Expand Down Expand Up @@ -101,6 +186,26 @@ async def fn(*args, **kwargs):
# 4: Register the user provided route to the Rest API.
route(self.route, **self.kwargs)(_handle_request)

def _patch_fast_api_request(self):
"""This function replaces signature annotation for Request with its mock."""
for k, v in self.method_annotations.items():
if v == Request:
self.method_annotations[k] = _FastApiMockRequest

for v in self.method_signature.parameters.values():
if v._annotation == Request:
v._annotation = _FastApiMockRequest

def _unpatch_fast_api_request(self):
"""This function replaces back signature annotation to fastapi Request."""
for k, v in self.method_annotations.items():
if v == _FastApiMockRequest:
self.method_annotations[k] = Request

for v in self.method_signature.parameters.values():
if v._annotation == _FastApiMockRequest:
v._annotation = Request


class Post(_HttpMethod):
pass
Expand Down
7 changes: 5 additions & 2 deletions tests/tests_app/core/test_lightning_api.py
Expand Up @@ -12,7 +12,7 @@
import pytest
import requests
from deepdiff import DeepDiff, Delta
from fastapi import HTTPException
from fastapi import HTTPException, Request
from httpx import AsyncClient
from pydantic import BaseModel

Expand Down Expand Up @@ -479,10 +479,13 @@ def run(self):
if self.counter == 501:
self._exit()

def request(self, config: InputRequestModel) -> OutputRequestModel:
def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
self.counter += 1
if config.index % 5 == 0:
raise HTTPException(status_code=400, detail="HERE")
assert request.body()
assert request.json()
assert request.headers
return OutputRequestModel(name=config.name, counter=self.counter)

def configure_api(self):
Expand Down

0 comments on commit 17e15af

Please sign in to comment.