Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[App] Serve datatypes with better client code #16018

Merged
merged 20 commits into from Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
Expand Up @@ -12,9 +12,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035))


- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))

- Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))

### Changed

Expand Down
4 changes: 3 additions & 1 deletion src/lightning_app/components/__init__.py
Expand Up @@ -10,7 +10,7 @@
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner
Expand All @@ -28,6 +28,8 @@
"PythonServer",
"Image",
"Number",
"Category",
"Text",
"MultiNode",
"LiteMultiNode",
"LightningTrainerScript",
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_app/components/serve/__init__.py
@@ -1,5 +1,5 @@
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.streamlit import ServeStreamlit

__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text"]
119 changes: 97 additions & 22 deletions src/lightning_app/components/serve/python_server.py
Expand Up @@ -2,9 +2,9 @@
import base64
import os
import platform
from pathlib import Path
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, TYPE_CHECKING

import requests
import uvicorn
from fastapi import FastAPI
from lightning_utilities.core.imports import compare_version, module_available
Expand All @@ -14,6 +14,9 @@
from lightning_app.utilities.app_helpers import Logger
from lightning_app.utilities.imports import _is_torch_available, requires

if TYPE_CHECKING:
from lightning_app.frontend.frontend import Frontend

logger = Logger(__name__)

# Skip doctests if requirements aren't available
Expand Down Expand Up @@ -48,18 +51,80 @@ class Image(BaseModel):
image: Optional[str]

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
imagepath = Path(__file__).parent / "catimage.png"
with open(imagepath, "rb") as image_file:
encoded_string = base64.b64encode(image_file.read())
return {"image": encoded_string.decode("UTF-8")}
def get_sample_data() -> Dict[Any, Any]:
url = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
img = requests.get(url).content
img = base64.b64encode(img).decode("UTF-8")
return {"image": img}

@staticmethod
def request_code_sample(url: str) -> str:
return (
"""import base64
from pathlib import Path
import requests

imgurl = "https://raw.githubusercontent.com/Lightning-AI/LAI-Triton-Server-Component/main/catimage.png"
img = requests.get(imgurl).content
img = base64.b64encode(img).decode("UTF-8")
response = requests.post('"""
+ url
+ """', json={
"image": img
})"""
)

@staticmethod
def response_code_sample() -> str:
return """img = response.json()["image"]
img = base64.b64decode(img.encode("utf-8"))
Path("response.png").write_bytes(img)
"""


class Category(BaseModel):
category: Optional[int]

@staticmethod
def get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}

@staticmethod
def response_code_sample() -> str:
return """print("Predicted category is: ", response.json()["category"])
"""


class Text(BaseModel):
text: Optional[str]

@staticmethod
def get_sample_data() -> Dict[Any, Any]:
return {"text": "A portrait of a person looking away from the camera"}

@staticmethod
def request_code_sample(url: str) -> str:
return (
"""import base64
from pathlib import Path
import requests

response = requests.post('"""
+ url
+ """', json={
"text": "A portrait of a person looking away from the camera"
})
"""
)


class Number(BaseModel):
# deprecated
# TODO remove this in favour of Category
prediction: Optional[int]

@staticmethod
def _get_sample_data() -> Dict[Any, Any]:
def get_sample_data() -> Dict[Any, Any]:
return {"prediction": 463}


Expand Down Expand Up @@ -154,8 +219,8 @@ def predict(self, request: Any) -> Any:

@staticmethod
def _get_sample_dict_from_datatype(datatype: Any) -> dict:
if hasattr(datatype, "_get_sample_data"):
return datatype._get_sample_data()
if hasattr(datatype, "get_sample_data"):
return datatype.get_sample_data()

datatype_props = datatype.schema()["properties"]
out: Dict[str, Any] = {}
Expand Down Expand Up @@ -187,7 +252,15 @@ def predict_fn(request: input_type): # type: ignore

fastapi_app.post("/predict", response_model=output_type)(predict_fn)

def configure_layout(self) -> None:
def get_code_sample(self, url: str) -> Optional[str]:
input_type: Any = self.configure_input_type()
output_type: Any = self.configure_output_type()

if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
return None
return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"

def configure_layout(self) -> Optional["Frontend"]:
try:
from lightning_api_access import APIAccessFrontend
except ModuleNotFoundError:
Expand All @@ -203,17 +276,19 @@ def configure_layout(self) -> None:
except TypeError:
return None

return APIAccessFrontend(
apis=[
{
"name": class_name,
"url": url,
"method": "POST",
"request": request,
"response": response,
}
]
)
frontend_payload = {
"name": class_name,
"url": url,
"method": "POST",
"request": request,
"response": response,
}

code_sample = self.get_code_sample(url)
if code_sample:
frontend_payload["code_sample"] = code_sample

return APIAccessFrontend(apis=[frontend_payload])

def run(self, *args: Any, **kwargs: Any) -> Any:
"""Run method takes care of configuring and setting up a FastAPI server behind the scenes.
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_app/components/serve/test_python_server.py
Expand Up @@ -32,14 +32,14 @@ def test_python_server_component():


def test_image_sample_data():
data = Image()._get_sample_data()
data = Image().get_sample_data()
assert isinstance(data, dict)
assert "image" in data
assert len(data["image"]) > 100


def test_number_sample_data():
data = Number()._get_sample_data()
data = Number().get_sample_data()
assert isinstance(data, dict)
assert "prediction" in data
assert data["prediction"] == 463