From e89dde8a0031756f36cb4d20725710e0c88af4ce Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 14 Dec 2022 04:43:52 +0000 Subject: [PATCH 01/10] update --- src/lightning_app/api/http_methods.py | 45 +++++++++++++++++++++- src/lightning_app/core/app.py | 3 +- tests/tests_app/core/test_lightning_api.py | 7 +++- 3 files changed, 51 insertions(+), 4 deletions(-) diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index ca09a9a83eecc..3c1a1922c8eea 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -2,12 +2,14 @@ import inspect import time from copy import deepcopy +from dataclasses import dataclass from functools import wraps from multiprocessing import Queue from typing import Any, Callable, Dict, List, Optional from uuid import uuid4 -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, Request +from lightning_utilities.core.apply_func import apply_to_collection from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse from lightning_app.utilities.app_helpers import Logger @@ -19,6 +21,17 @@ def _signature_proxy_function(): pass +@dataclass +class FastApiMockRequest: + headers: Optional[str] = None + + +async def _mock_fastapi_request(request: Request): + data = await request.json() + # TODO: Add more requests parameters. + return FastApiMockRequest(data) + + class _HttpMethod: def __init__(self, route: str, method: Callable, method_name: Optional[str] = None, timeout: int = 30, **kwargs): """This class is used to inject user defined methods within the App Rest API. @@ -34,6 +47,7 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No self.method_annotations = method.__annotations__ # TODO: Validate the signature contains only pydantic models. self.method_signature = inspect.signature(method) + if not self.attached_to_flow: self.component_name = method.__name__ self.method = method @@ -43,10 +57,14 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No self.timeout = timeout self.kwargs = kwargs + self._patch_fast_api_request() + def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None: # 1: Get the route associated with the http method. route = getattr(app, self.__class__.__name__.lower()) + self._unpatch_fast_api_request() + # 2: Create a proxy function with the signature of the wrapped method. fn = deepcopy(_signature_proxy_function) fn.__annotations__ = self.method_annotations @@ -69,6 +87,11 @@ async def _handle_request(*args, **kwargs): @wraps(_signature_proxy_function) async def _handle_request(*args, **kwargs): async def fn(*args, **kwargs): + args, kwargs = apply_to_collection((args, kwargs), Request, _mock_fastapi_request) + for k, v in kwargs.items(): + if hasattr(v, "__await__"): + kwargs[k] = await v + request_id = str(uuid4()).split("-")[0] logger.debug(f"Processing request {request_id} for route: {self.route}") request_queue.put( @@ -101,6 +124,26 @@ async def fn(*args, **kwargs): # 4: Register the user provided route to the Rest API. route(self.route, **self.kwargs)(_handle_request) + def _patch_fast_api_request(self): + for k, v in self.method_annotations.items(): + if v == Request: + v = FastApiMockRequest + self.method_annotations[k] = v + + for v in self.method_signature.parameters.values(): + if v._annotation == Request: + v._annotation = FastApiMockRequest + + def _unpatch_fast_api_request(self): + for k, v in self.method_annotations.items(): + if v == FastApiMockRequest: + v = Request + self.method_annotations[k] = v + + for v in self.method_signature.parameters.values(): + if v._annotation == FastApiMockRequest: + v._annotation = Request + class Post(_HttpMethod): pass diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 9c3aeeb650de0..87910170d3d74 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -184,7 +184,8 @@ def __init__( def _update_index_file(self): # update index.html, # this should happen once for all apps before the ui server starts running. - frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path) + if self.root_path != "": + frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path) def get_component_by_name(self, component_name: str): """Returns the instance corresponding to the given component name.""" diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 04b89c927941a..846ae974bcf5f 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -12,7 +12,7 @@ import pytest import requests from deepdiff import DeepDiff, Delta -from fastapi import HTTPException +from fastapi import HTTPException, Request from httpx import AsyncClient from pydantic import BaseModel @@ -479,10 +479,11 @@ def run(self): if self.counter == 501: self._exit() - def request(self, config: InputRequestModel) -> OutputRequestModel: + def request(self, config: InputRequestModel, request: Request) -> OutputRequestModel: self.counter += 1 if config.index % 5 == 0: raise HTTPException(status_code=400, detail="HERE") + print(request.headers) return OutputRequestModel(name=config.name, counter=self.counter) def configure_api(self): @@ -514,6 +515,8 @@ def test_configure_api(): sleep(0.1) time_left -= 0.1 + # sleep(1000) + # Test Upload File files = {"uploaded_file": open(__file__, "rb")} From 317182b8ae55834393e5de110c3da2f3aed0e20f Mon Sep 17 00:00:00 2001 From: thomas Date: Wed, 14 Dec 2022 14:03:02 +0000 Subject: [PATCH 02/10] update --- src/lightning_app/api/http_methods.py | 40 ++++++++++++++++++++-- tests/tests_app/core/test_lightning_api.py | 2 +- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index 3c1a1922c8eea..60db89bb76b13 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -23,13 +23,47 @@ def _signature_proxy_function(): @dataclass class FastApiMockRequest: - headers: Optional[str] = None + _body: Optional[str] = None + _json: Optional[str] = None + _method: Optional[str] = None + _headers: Optional[Dict] = None + + @property + def receive(self): + raise NotImplementedError + + @property + def method(self): + raise self._method + + @property + def headers(self): + return self._headers + + def body(self): + return self._body + + def json(self): + return self._json + + def stream(self): + raise NotImplementedError + + def form(self): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + def is_disconnected(self): + raise NotImplementedError async def _mock_fastapi_request(request: Request): - data = await request.json() # TODO: Add more requests parameters. - return FastApiMockRequest(data) + return FastApiMockRequest( + _body=await request.body(), _json=await request.json(), _headers=request.headers, _method=request.method + ) class _HttpMethod: diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 846ae974bcf5f..0b7819349aa03 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -483,7 +483,7 @@ def request(self, config: InputRequestModel, request: Request) -> OutputRequestM self.counter += 1 if config.index % 5 == 0: raise HTTPException(status_code=400, detail="HERE") - print(request.headers) + print(request.body()) return OutputRequestModel(name=config.name, counter=self.counter) def configure_api(self): From 912f89fafd5d86e41b014119aadd60cae0c8b2bf Mon Sep 17 00:00:00 2001 From: Noha Alon Date: Wed, 14 Dec 2022 17:52:20 +0200 Subject: [PATCH 03/10] Update tests/tests_app/core/test_lightning_api.py --- tests/tests_app/core/test_lightning_api.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 0b7819349aa03..e319db862f2c7 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -515,8 +515,6 @@ def test_configure_api(): sleep(0.1) time_left -= 0.1 - # sleep(1000) - # Test Upload File files = {"uploaded_file": open(__file__, "rb")} From 7087b09aac4cc0b24a3fc1aaff7bcb2c099b0f36 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 08:49:24 +0100 Subject: [PATCH 04/10] update --- .gitignore | 4 ++-- src/lightning_app/api/http_methods.py | 32 ++++++++++++++++++++++++++- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/.gitignore b/.gitignore index 982309a8a356a..03c30bc8caa2a 100644 --- a/.gitignore +++ b/.gitignore @@ -110,8 +110,8 @@ celerybeat-schedule # dotenv .env -.env_staging -.env_local +.env.staging +.env.local # virtualenv .venv diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index 60db89bb76b13..5b097b569540c 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -23,6 +23,29 @@ def _signature_proxy_function(): @dataclass class FastApiMockRequest: + """This class is meant to mock FastAPI Request class that isn't pickalable. + + If a user relies on FastAPI Request annotation, the Lightning framework + patches the annotation before pickling and replace them right after. + + Finally, the FastAPI request is converting back to the FastApiMockRequest + before being delivered to the users. + + Example: + + import lightning as L + from fastapi import Request + from lightning.app.api import Post + + class Flow(L.LightningFlow): + + def request(self, request: Request) -> OutputRequestModel: + ... + + def configure_api(self): + return [Post("/api/v1/request", self.request)] + """ + _body: Optional[str] = None _json: Optional[str] = None _method: Optional[str] = None @@ -62,7 +85,10 @@ def is_disconnected(self): async def _mock_fastapi_request(request: Request): # TODO: Add more requests parameters. return FastApiMockRequest( - _body=await request.body(), _json=await request.json(), _headers=request.headers, _method=request.method + _body=await request.body(), + _json=await request.json(), + _headers=request.headers, + _method=request.method, ) @@ -91,6 +117,8 @@ def __init__(self, route: str, method: Callable, method_name: Optional[str] = No self.timeout = timeout self.kwargs = kwargs + # Enable the users to rely on FastAPI annotation typing with Request. + # Note: Only a part of the Request functionatilities are supported. self._patch_fast_api_request() def add_route(self, app: FastAPI, request_queue: Queue, responses_store: Dict[str, Any]) -> None: @@ -159,6 +187,7 @@ async def fn(*args, **kwargs): route(self.route, **self.kwargs)(_handle_request) def _patch_fast_api_request(self): + """This function replaces signature annotation for Request with its mock.""" for k, v in self.method_annotations.items(): if v == Request: v = FastApiMockRequest @@ -169,6 +198,7 @@ def _patch_fast_api_request(self): v._annotation = FastApiMockRequest def _unpatch_fast_api_request(self): + """This function replaces bacl signature annotation to fastapi Request.""" for k, v in self.method_annotations.items(): if v == FastApiMockRequest: v = Request From 7394a66efcfe91adbf04969428354f56f836d6a5 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 08:51:50 +0100 Subject: [PATCH 05/10] update --- tests/tests_app/core/test_lightning_api.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index e319db862f2c7..f3eb8f9bacda9 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -483,7 +483,9 @@ def request(self, config: InputRequestModel, request: Request) -> OutputRequestM self.counter += 1 if config.index % 5 == 0: raise HTTPException(status_code=400, detail="HERE") - print(request.body()) + assert request.body() + assert request.json() + assert request.headers return OutputRequestModel(name=config.name, counter=self.counter) def configure_api(self): From c809940cc95b92637960adb3ec0c967fa3086310 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 08:52:44 +0100 Subject: [PATCH 06/10] update --- src/lightning_app/CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index fbddc4210a869..d7cb861134035 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -11,6 +11,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Lightning{Flow,Work}.lightningignores` attributes to programmatically ignore files before uploading to the cloud ([#15818](https://github.com/Lightning-AI/lightning/pull/15818)) +- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047)) + + ### Changed - From ff10e0415418a90bd31a487a6d87e42d44adeb5d Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 09:33:56 +0100 Subject: [PATCH 07/10] update --- src/lightning_app/api/http_methods.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index 5b097b569540c..c727f299e124b 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -22,13 +22,13 @@ def _signature_proxy_function(): @dataclass -class FastApiMockRequest: +class _FastApiMockRequest: """This class is meant to mock FastAPI Request class that isn't pickalable. If a user relies on FastAPI Request annotation, the Lightning framework patches the annotation before pickling and replace them right after. - Finally, the FastAPI request is converting back to the FastApiMockRequest + Finally, the FastAPI request is converting back to the _FastApiMockRequest before being delivered to the users. Example: @@ -84,7 +84,7 @@ def is_disconnected(self): async def _mock_fastapi_request(request: Request): # TODO: Add more requests parameters. - return FastApiMockRequest( + return _FastApiMockRequest( _body=await request.body(), _json=await request.json(), _headers=request.headers, @@ -190,22 +190,22 @@ def _patch_fast_api_request(self): """This function replaces signature annotation for Request with its mock.""" for k, v in self.method_annotations.items(): if v == Request: - v = FastApiMockRequest + v = _FastApiMockRequest self.method_annotations[k] = v for v in self.method_signature.parameters.values(): if v._annotation == Request: - v._annotation = FastApiMockRequest + v._annotation = _FastApiMockRequest def _unpatch_fast_api_request(self): """This function replaces bacl signature annotation to fastapi Request.""" for k, v in self.method_annotations.items(): - if v == FastApiMockRequest: + if v == _FastApiMockRequest: v = Request self.method_annotations[k] = v for v in self.method_signature.parameters.values(): - if v._annotation == FastApiMockRequest: + if v._annotation == _FastApiMockRequest: v._annotation = Request From 8cb8b1aa00a4573d48ad54fec94f1a356a252743 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 09:40:54 +0100 Subject: [PATCH 08/10] update --- src/lightning_app/api/http_methods.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index c727f299e124b..015324165c181 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -23,12 +23,12 @@ def _signature_proxy_function(): @dataclass class _FastApiMockRequest: - """This class is meant to mock FastAPI Request class that isn't pickalable. + """This class is meant to mock FastAPI Request class that isn't pickle-able. If a user relies on FastAPI Request annotation, the Lightning framework patches the annotation before pickling and replace them right after. - Finally, the FastAPI request is converting back to the _FastApiMockRequest + Finally, the FastAPI request is converted back to the _FastApiMockRequest before being delivered to the users. Example: @@ -198,7 +198,7 @@ def _patch_fast_api_request(self): v._annotation = _FastApiMockRequest def _unpatch_fast_api_request(self): - """This function replaces bacl signature annotation to fastapi Request.""" + """This function replaces back signature annotation to fastapi Request.""" for k, v in self.method_annotations.items(): if v == _FastApiMockRequest: v = Request From bf0db20220b90252bb556dd4bacdc84d7670da80 Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 12:43:10 +0100 Subject: [PATCH 09/10] cleanup --- src/lightning_app/api/http_methods.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/lightning_app/api/http_methods.py b/src/lightning_app/api/http_methods.py index 015324165c181..982e02d959f08 100644 --- a/src/lightning_app/api/http_methods.py +++ b/src/lightning_app/api/http_methods.py @@ -190,8 +190,7 @@ def _patch_fast_api_request(self): """This function replaces signature annotation for Request with its mock.""" for k, v in self.method_annotations.items(): if v == Request: - v = _FastApiMockRequest - self.method_annotations[k] = v + self.method_annotations[k] = _FastApiMockRequest for v in self.method_signature.parameters.values(): if v._annotation == Request: @@ -201,8 +200,7 @@ def _unpatch_fast_api_request(self): """This function replaces back signature annotation to fastapi Request.""" for k, v in self.method_annotations.items(): if v == _FastApiMockRequest: - v = Request - self.method_annotations[k] = v + self.method_annotations[k] = Request for v in self.method_signature.parameters.values(): if v._annotation == _FastApiMockRequest: From 1b7e353a9b0bab506ea17e4f0c8d70a19ae4a4dc Mon Sep 17 00:00:00 2001 From: thomas Date: Fri, 16 Dec 2022 12:54:53 +0100 Subject: [PATCH 10/10] update --- src/lightning_app/core/app.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/lightning_app/core/app.py b/src/lightning_app/core/app.py index 87910170d3d74..9c3aeeb650de0 100644 --- a/src/lightning_app/core/app.py +++ b/src/lightning_app/core/app.py @@ -184,8 +184,7 @@ def __init__( def _update_index_file(self): # update index.html, # this should happen once for all apps before the ui server starts running. - if self.root_path != "": - frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path) + frontend.update_index_file(FRONTEND_DIR, info=self.info, root_path=self.root_path) def get_component_by_name(self, component_name: str): """Returns the instance corresponding to the given component name."""