Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[App] PoC: Add support for Request #16047

Merged
merged 15 commits into from Dec 16, 2022
4 changes: 2 additions & 2 deletions .gitignore
Expand Up @@ -110,8 +110,8 @@ celerybeat-schedule

# dotenv
.env
.env_staging
tchaton marked this conversation as resolved.
Show resolved Hide resolved
.env_local
.env.staging
.env.local

# virtualenv
.venv
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -13,6 +13,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))


- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))


### Changed

-
Expand Down
109 changes: 108 additions & 1 deletion src/lightning_app/api/http_methods.py
Expand Up @@ -2,12 +2,14 @@
import inspect
import time
from copy import deepcopy
from dataclasses import dataclass
from functools import wraps
from multiprocessing import Queue
from typing import Any, Callable, Dict, List, Optional
from uuid import uuid4

from fastapi import FastAPI, HTTPException
from fastapi import FastAPI, HTTPException, Request
from lightning_utilities.core.apply_func import apply_to_collection

from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse
from lightning_app.utilities.app_helpers import Logger
Expand All @@ -19,6 +21,77 @@ def _signature_proxy_function():
pass


@dataclass
class FastApiMockRequest:
"""This class is meant to mock FastAPI Request class that isn't pickalable.
tchaton marked this conversation as resolved.
Show resolved Hide resolved

If a user relies on FastAPI Request annotation, the Lightning framework
patches the annotation before pickling and replace them right after.

Finally, the FastAPI request is converting back to the FastApiMockRequest
tchaton marked this conversation as resolved.
Show resolved Hide resolved
before being delivered to the users.

Example:

import lightning as L
from fastapi import Request
from lightning.app.api import Post

class Flow(L.LightningFlow):

def request(self, request: Request) -> OutputRequestModel:
...

def configure_api(self):
return [Post("/api/v1/request", self.request)]
"""

_body: Optional[str] = None
_json: Optional[str] = None
_method: Optional[str] = None
_headers: Optional[Dict] = None

@property
def receive(self):
raise NotImplementedError

@property
def method(self):
raise self._method

@property
def headers(self):
return self._headers

def body(self):
tchaton marked this conversation as resolved.
Show resolved Hide resolved
return self._body

def json(self):
return self._json

def stream(self):
raise NotImplementedError

def form(self):
raise NotImplementedError

def close(self):
raise NotImplementedError

def is_disconnected(self):
raise NotImplementedError


async def _mock_fastapi_request(request: Request):
# TODO: Add more requests parameters.
return FastApiMockRequest(
_body=await request.body(),
_json=await request.json(),
_headers=request.headers,
_method=request.method,
)


class _HttpMethod:
def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs):
"""This class is used to inject user defined methods within the App Rest API.
Expand All @@ -34,6 +107,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
self.method_annotations = method.__annotations__
# TODO: Validate the signature contains only pydantic models.
self.method_signature = inspect.signature(method)

if not self.attached_to_flow:
self.component_name = method.__name__
self.method = method
Expand All @@ -43,10 +117,16 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No
self.timeout = timeout
self.kwargs = kwargs

# Enable the users to rely on FastAPI annotation typing with Request.
# Note: Only a part of the Request functionatilities are supported.
self._patch_fast_api_request()

def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None:
# 1: Get the route associated with the http method.
route = getattr(app, self.__class__.__name__.lower())

self._unpatch_fast_api_request()

# 2: Create a proxy function with the signature of the wrapped method.
fn = deepcopy(_signature_proxy_function)
fn.__annotations__ = self.method_annotations
Expand All @@ -69,6 +149,11 @@ async def _handle_request(*args, **kwargs):
@wraps(_signature_proxy_function)
async def _handle_request(*args, **kwargs):
async def fn(*args, **kwargs):
args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request)
for k, v in kwargs.items():
if hasattr(v, "__await__"):
kwargs[k] = await v

request_id = str(uuid4()).split("-")[0]
logger.debug(f"Processing request {request_id} for route: {self.route}")
request_queue.put(
Expand Down Expand Up @@ -101,6 +186,28 @@ async def fn(*args, **kwargs):
# 4: Register the user provided route to the Rest API.
route(self.route, **self.kwargs)(_handle_request)

def _patch_fast_api_request(self):
"""This function replaces signature annotation for Request with its mock."""
for k, v in self.method_annotations.items():
if v == Request:
v = FastApiMockRequest
self.method_annotations[k] = v

for v in self.method_signature.parameters.values():
if v._annotation == Request:
v._annotation = FastApiMockRequest

def _unpatch_fast_api_request(self):
"""This function replaces bacl signature annotation to fastapi Request."""
tchaton marked this conversation as resolved.
Show resolved Hide resolved
for k, v in self.method_annotations.items():
if v == FastApiMockRequest:
v = Request
self.method_annotations[k] = v
tchaton marked this conversation as resolved.
Show resolved Hide resolved

for v in self.method_signature.parameters.values():
if v._annotation == FastApiMockRequest:
v._annotation = Request


class Post(_HttpMethod):
pass
Expand Down
3 changes: 2 additions & 1 deletion src/lightning_app/core/app.py
Expand Up @@ -184,7 +184,8 @@ def __init__(
def _update_index_file(self):
# update index.html,
# this should happen once for all apps before the ui server starts running.
frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path)
if self.root_path != "":
tchaton marked this conversation as resolved.
Show resolved Hide resolved
tchaton marked this conversation as resolved.
Show resolved Hide resolved
frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path)

def get_component_by_name(self, component_name: str):
"""Returns the instance corresponding to the given component name."""
Expand Down
7 changes: 5 additions & 2 deletions tests/tests_app/core/test_lightning_api.py
Expand Up @@ -12,7 +12,7 @@
import pytest
import requests
from deepdiff import DeepDiff, Delta
from fastapi import HTTPException
from fastapi import HTTPException, Request
from httpx import AsyncClient
from pydantic import BaseModel

Expand Down Expand Up @@ -479,10 +479,13 @@ def run(self):
if self.counter == 501:
self._exit()

def request(self, config: InputRequestModel) -> OutputRequestModel:
def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel:
self.counter += 1
if config.index % 5 == 0:
raise HTTPException(status_code=400, detail="HERE")
assert request.body()
assert request.json()
assert request.headers
return OutputRequestModel(name=config.name, counter=self.counter)

def configure_api(self):
Expand Down