Sample datatype for Serve Component #15623

Merged
merged 41 commits on Nov 10, 2022
Commits (41)
aa7da85
introducing serve component
Nov 9, 2022
bdb9ee0
Merge branch 'master' into feature/serve-component
Nov 9, 2022
03bda90
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2022
3d51575
clean up tests
Nov 9, 2022
865c678
clean up tests
Nov 9, 2022
f06b29d
doctest
Nov 9, 2022
df0e770
mypy
Nov 9, 2022
baa16f4
Merge branch 'master' into feature/serve-component
Nov 9, 2022
81370bd
structure-fix
Nov 9, 2022
64203b7
Merge branch 'feature/serve-component' of github.com:Lightning-AI/lig…
Nov 9, 2022
b53832a
Merge branch 'master' into feature/serve-component
Nov 9, 2022
8300c34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 9, 2022
3ab7c3e
cleanup
Nov 10, 2022
62c5d43
Merge branch 'master' into feature/serve-component
Nov 10, 2022
66167e2
master
Nov 10, 2022
efd5125
cleanup
Nov 10, 2022
0736c68
test fix
Nov 10, 2022
a10e363
Merge branch 'master' into feature/serve-component
Nov 10, 2022
95e2fa1
addition
Nov 10, 2022
eeffc62
Merge branch 'feature/serve-component' of github.com:Lightning-AI/lig…
Nov 10, 2022
0efeeb0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
100fb0b
test fix
Nov 10, 2022
6888d96
Merge branch 'feature/serve-component' of github.com:Lightning-AI/lig…
Nov 10, 2022
2485f26
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
731aabf
requirements
Nov 10, 2022
187a9f9
Merge branch 'feature/serve-component' of github.com:Lightning-AI/lig…
Nov 10, 2022
f46478c
Merge branch 'master' into feature/serve-component
Nov 10, 2022
2568e6a
getting future url
Nov 10, 2022
82f98d8
Merge branch 'feature/serve-component' of github.com:Lightning-AI/lig…
Nov 10, 2022
210f1a2
url for local
Nov 10, 2022
4ff98a2
sample data typeg
Nov 10, 2022
583d5c3
master
Nov 10, 2022
39d54ad
changes
Nov 10, 2022
e618748
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 10, 2022
1de071d
prediction
Nov 10, 2022
fe9fb76
Merge branch 'sample-datatype' of github.com:Lightning-AI/lightning i…
Nov 10, 2022
00e7bd0
updates
rlizzo Nov 10, 2022
30c2ddb
updates
rlizzo Nov 10, 2022
d3124df
manifest
Borda Nov 10, 2022
bef5bd5
fix type error
rlizzo Nov 10, 2022
d65ef20
fixed test
rlizzo Nov 10, 2022
42 changes: 42 additions & 0 deletions examples/app_server/app.py
@@ -0,0 +1,42 @@
# !pip install torchvision pydantic
import base64
import io

import torch
import torchvision
from PIL import Image
from pydantic import BaseModel

import lightning as L
from lightning.app.components.serve import Image as InputImage
from lightning.app.components.serve import PythonServer


class PyTorchServer(PythonServer):
    def setup(self):
        self._model = torchvision.models.resnet18(pretrained=True)
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self._model.to(self._device)

    def predict(self, request):
        image = base64.b64decode(request.image.encode("utf-8"))
        image = Image.open(io.BytesIO(image))
        transforms = torchvision.transforms.Compose(
            [
                torchvision.transforms.Resize(224),
                torchvision.transforms.ToTensor(),
                torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
            ]
        )
        image = transforms(image)
        image = image.to(self._device)
        prediction = self._model(image.unsqueeze(0))
        return {"prediction": prediction.argmax().item()}


class OutputData(BaseModel):
    prediction: int


component = PyTorchServer(input_type=InputImage, output_type=OutputData, cloud_compute=L.CloudCompute("gpu"))
app = L.LightningApp(component)
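
For context, a client calls this server by posting a base64-encoded image to its /predict route, mirroring the InputImage datatype. A minimal client sketch (the host, port, and image path below are assumptions; the real URL comes from the running app):

import base64

import requests

# Assumption: the example server above is reachable at this address.
url = "http://127.0.0.1:7777/predict"

# Encode a local image as base64, matching the {"image": "<base64>"} schema.
with open("catimage.png", "rb") as f:
    payload = {"image": base64.b64encode(f.read()).decode("utf-8")}

response = requests.post(url, json=payload)
print(response.json())  # e.g. {"prediction": 285}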
1 change: 1 addition & 0 deletions src/lightning/__setup__.py
@@ -35,6 +35,7 @@ def _adjust_manifest(**kwargs: Any) -> None:
        "recursive-include requirements *.txt",
        "recursive-include src/lightning/app/ui *",
        "recursive-include src/lightning/cli/*-template *",  # Add templates as build-in
        "include src/lightning/app/components/serve/catimage.png" + os.linesep,
        # fixme: this is strange, this shall work with setup find package - include
        "prune src/lightning_app",
        "prune src/lightning_lite",
1 change: 1 addition & 0 deletions src/lightning_app/__setup__.py
@@ -50,6 +50,7 @@ def _adjust_manifest(**__: Any) -> None:
        "recursive-exclude src *.md" + os.linesep,
        "recursive-exclude requirements *.txt" + os.linesep,
        "recursive-include src/lightning_app *.md" + os.linesep,
        "include src/lightning_app/components/serve/catimage.png" + os.linesep,
        "recursive-include requirements/app *.txt" + os.linesep,
        "recursive-include src/lightning_app/cli/*-template *" + os.linesep,  # Add templates
    ]
4 changes: 3 additions & 1 deletion src/lightning_app/components/__init__.py
@@ -9,7 +9,7 @@
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.serve import ModelInferenceAPI
from lightning_app.components.serve.streamlit import ServeStreamlit
from lightning_app.components.training import LightningTrainingComponent, PyTorchLightningScriptRunner
@@ -24,6 +24,8 @@
    "ServeStreamlit",
    "ModelInferenceAPI",
    "PythonServer",
    "Image",
    "Number",
    "MultiNode",
    "LiteMultiNode",
    "LightningTrainingComponent",
4 changes: 2 additions & 2 deletions src/lightning_app/components/serve/__init__.py
@@ -1,5 +1,5 @@
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import PythonServer
from lightning_app.components.serve.python_server import Image, Number, PythonServer
from lightning_app.components.serve.streamlit import ServeStreamlit

__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer"]
__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number"]
Binary file added src/lightning_app/components/serve/catimage.png
34 changes: 32 additions & 2 deletions src/lightning_app/components/serve/python_server.py
@@ -1,5 +1,7 @@
import abc
from typing import Any, Dict
import base64
from pathlib import Path
from typing import Any, Dict, Optional

import uvicorn
from fastapi import FastAPI
@@ -12,6 +14,12 @@
logger = Logger(__name__)


def image_to_base64(image_path):
    with open(image_path, "rb") as image_file:
        encoded_string = base64.b64encode(image_file.read())
    return encoded_string.decode("UTF-8")


class _DefaultInputData(BaseModel):
    payload: str

@@ -20,6 +28,25 @@ class _DefaultOutputData(BaseModel):
    prediction: str


class Image(BaseModel):
    image: Optional[str]

    @staticmethod
    def _get_sample_data() -> Dict[Any, Any]:
        imagepath = Path(__file__).absolute().parent / "catimage.png"
        with open(imagepath, "rb") as image_file:
            encoded_string = base64.b64encode(image_file.read())
        return {"image": encoded_string.decode("UTF-8")}


class Number(BaseModel):
    prediction: Optional[int]

    @staticmethod
    def _get_sample_data() -> Dict[Any, Any]:
        return {"prediction": 463}


class PythonServer(LightningWork, abc.ABC):
    def __init__(  # type: ignore
        self,
@@ -110,6 +137,9 @@ def predict(self, request: Any) -> Any:

    @staticmethod
    def _get_sample_dict_from_datatype(datatype: Any) -> dict:
        if hasattr(datatype, "_get_sample_data"):
            return datatype._get_sample_data()

        datatype_props = datatype.schema()["properties"]
        out: Dict[str, Any] = {}
        for k, v in datatype_props.items():
@@ -141,7 +171,7 @@ def _attach_frontend(self, fastapi_app: FastAPI) -> None:
        url = self._future_url if self._future_url else self.url
        if not url:
            # if the url is still empty, point it to localhost
            url = f"http://127.0.0.1{self.port}"
            url = f"http://127.0.0.1:{self.port}"
        url = f"{url}/predict"
        datatype_parse_error = False
        try:
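
The _get_sample_dict_from_datatype hook above means any Pydantic datatype can supply its own example payload for the auto-generated API page by defining _get_sample_data; models without it fall back to the schema-based defaults. A hedged sketch of a hypothetical custom datatype following the same pattern (not part of this PR):

from typing import Any, Dict, Optional

from pydantic import BaseModel


class Text(BaseModel):
    # Hypothetical input type carrying a free-form text payload.
    text: Optional[str]

    @staticmethod
    def _get_sample_data() -> Dict[Any, Any]:
        # Whatever this returns is shown as the example request body.
        return {"text": "A cat sitting on a window sill."}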
16 changes: 15 additions & 1 deletion tests/tests_app/components/serve/test_python_server.py
@@ -1,6 +1,6 @@
import multiprocessing as mp

from lightning_app.components import PythonServer
from lightning_app.components import Image, Number, PythonServer
from lightning_app.utilities.network import _configure_session, find_free_network_port


@@ -29,3 +29,17 @@ def test_python_server_component():
    res = session.post(f"http://127.0.0.1:{port}/predict", json={"payload": "test"})
    process.terminate()
    assert res.json()["prediction"] == "test"


def test_image_sample_data():
    data = Image()._get_sample_data()
    assert isinstance(data, dict)
    assert "image" in data
    assert len(data["image"]) > 100


def test_number_sample_data():
    data = Number()._get_sample_data()
    assert isinstance(data, dict)
    assert "prediction" in data
    assert data["prediction"] == 463