[App] Improve the autoscaler UI (#16063)

(cherry picked from commit 39d27f6)
akihironitta authored and Borda committed Dec 20, 2022
1 parent 6996dc8 commit a0f8f70
Showing 8 changed files with 110 additions and 26 deletions.
2 changes: 1 addition & 1 deletion docs/source-app/api_references.rst
@@ -45,7 +45,7 @@ ___________________
     ~multi_node.lite.LiteMultiNode
     ~multi_node.pytorch_spawn.PyTorchSpawnMultiNode
     ~multi_node.trainer.LightningTrainerMultiNode
-    ~auto_scaler.AutoScaler
+    ~serve.auto_scaler.AutoScaler

 ----

14 changes: 5 additions & 9 deletions examples/app_server_with_auto_scaler/app.py
@@ -1,5 +1,5 @@
 # ! pip install torch torchvision
-from typing import Any, List
+from typing import List

 import torch
 import torchvision
@@ -8,16 +8,12 @@
 import lightning as L


-class RequestModel(BaseModel):
-    image: str  # bytecode
-
-
 class BatchRequestModel(BaseModel):
-    inputs: List[RequestModel]
+    inputs: List[L.app.components.Image]


 class BatchResponse(BaseModel):
-    outputs: List[Any]
+    outputs: List[L.app.components.Number]


 class PyTorchServer(L.app.components.PythonServer):
@@ -81,8 +77,8 @@ def scale(self, replicas: int, metrics: dict) -> int:
         max_replicas=4,
         autoscale_interval=10,
         endpoint="predict",
-        input_type=RequestModel,
-        output_type=Any,
+        input_type=L.app.components.Image,
+        output_type=L.app.components.Number,
         timeout_batching=1,
         max_batch_size=8,
     )
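Condensed, the updated example now feeds the built-in `Image`/`Number` datatypes to both the server and the scaler. A minimal sketch under stated assumptions: `MyServer` is a hypothetical stand-in for the example's `PyTorchServer`, and the `prediction` field mirrors the `Number` datatype:

```python
# Sketch of the rewired example; MyServer is hypothetical and skips the real
# torchvision model that the example file loads.
import lightning as L


class MyServer(L.app.components.PythonServer):
    def __init__(self):
        super().__init__(
            input_type=L.app.components.Image,
            output_type=L.app.components.Number,
        )

    def predict(self, request):
        # A real server would decode request.image and run the model here.
        return {"prediction": 0}


app = L.LightningApp(
    L.app.components.AutoScaler(
        MyServer,
        min_replicas=1,
        max_replicas=4,
        autoscale_interval=10,
        endpoint="predict",
        input_type=L.app.components.Image,    # was the hand-rolled RequestModel
        output_type=L.app.components.Number,  # was a bare Any
        timeout_batching=1,
        max_batch_size=8,
    )
)
```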
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -79,8 +79,8 @@ module = [
     "lightning_app.components.serve.types.image",
     "lightning_app.components.serve.types.type",
     "lightning_app.components.serve.python_server",
+    "lightning_app.components.serve.auto_scaler",
     "lightning_app.components.training",
-    "lightning_app.components.auto_scaler",
     "lightning_app.core.api",
     "lightning_app.core.app",
     "lightning_app.core.flow",
2 changes: 2 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047))

+- Added a nicer UI with URL and examples for the autoscaler component ([#16063](https://github.com/Lightning-AI/lightning/pull/16063))
+
 - Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018))

 - Added `work.delete` method to delete the work ([#16103](https://github.com/Lightning-AI/lightning/pull/16103))
2 changes: 1 addition & 1 deletion src/lightning_app/components/__init__.py
@@ -1,4 +1,3 @@
-from lightning_app.components.auto_scaler import AutoScaler
 from lightning_app.components.database.client import DatabaseClient
 from lightning_app.components.database.server import Database
 from lightning_app.components.multi_node import (
@@ -9,6 +8,7 @@
 )
 from lightning_app.components.python.popen import PopenPythonScript
 from lightning_app.components.python.tracer import Code, TracerPythonScript
+from lightning_app.components.serve.auto_scaler import AutoScaler
 from lightning_app.components.serve.gradio import ServeGradio
 from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
 from lightning_app.components.serve.serve import ModelInferenceAPI
3 changes: 2 additions & 1 deletion src/lightning_app/components/serve/__init__.py
@@ -1,5 +1,6 @@
+from lightning_app.components.serve.auto_scaler import AutoScaler
 from lightning_app.components.serve.gradio import ServeGradio
 from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
 from lightning_app.components.serve.streamlit import ServeStreamlit

-__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text"]
+__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text", "AutoScaler"]
86 changes: 78 additions & 8 deletions src/lightning_app/components/{ → serve}/auto_scaler.py
@@ -6,7 +6,7 @@
 import uuid
 from base64 import b64encode
 from itertools import cycle
-from typing import Any, Dict, List, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type

 import requests
 import uvicorn
@@ -15,11 +15,13 @@
 from fastapi.responses import RedirectResponse
 from fastapi.security import HTTPBasic, HTTPBasicCredentials
 from pydantic import BaseModel
+from starlette.staticfiles import StaticFiles
 from starlette.status import HTTP_401_UNAUTHORIZED

 from lightning_app.core.flow import LightningFlow
 from lightning_app.core.work import LightningWork
 from lightning_app.utilities.app_helpers import Logger
+from lightning_app.utilities.cloud import is_running_in_cloud
 from lightning_app.utilities.imports import _is_aiohttp_available, requires
 from lightning_app.utilities.packaging.cloud_compute import CloudCompute
@@ -114,20 +116,21 @@ class _LoadBalancer(LightningWork):
         requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
         timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
         timeout_inference_request: The number of seconds to wait for inference.
-        \**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
+        **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
     """

     @requires(["aiohttp"])
     def __init__(
         self,
-        input_type: BaseModel,
-        output_type: BaseModel,
+        input_type: Type[BaseModel],
+        output_type: Type[BaseModel],
         endpoint: str,
         max_batch_size: int = 8,
         # all timeout args are in seconds
-        timeout_batching: int = 1,
+        timeout_batching: float = 1,
         timeout_keep_alive: int = 60,
         timeout_inference_request: int = 60,
+        work_name: Optional[str] = "API",  # used for displaying the name in the UI
         **kwargs: Any,
     ) -> None:
         super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
@@ -142,6 +145,7 @@ def __init__(
         self._batch = []
         self._responses = {}  # {request_id: response}
         self._last_batch_sent = 0
+        self._work_name = work_name

         if not endpoint.startswith("/"):
             endpoint = "/" + endpoint
@@ -280,6 +284,14 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe
         async def balance_api(inputs: self._input_type):
             return await self.process_request(inputs)

+        endpoint_info_page = self._get_endpoint_info_page()
+        if endpoint_info_page:
+            fastapi_app.mount(
+                "/endpoint-info", StaticFiles(directory=endpoint_info_page.serve_dir, html=True), name="static"
+            )
+
+        logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
+
         uvicorn.run(
             fastapi_app,
             host=self.host,
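The mount above is plain Starlette machinery rather than anything autoscaler-specific; a standalone sketch of the same technique, with a hypothetical `ui_build` directory:

```python
# Serve a prebuilt UI next to an API; html=True makes StaticFiles return
# index.html for directory requests. "ui_build" is an assumed local directory.
from fastapi import FastAPI
from starlette.staticfiles import StaticFiles

fastapi_app = FastAPI()
fastapi_app.mount("/endpoint-info", StaticFiles(directory="ui_build", html=True), name="static")
```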
@@ -332,6 +344,60 @@ def send_request_to_update_servers(self, servers: List[str]):
         response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10)
         response.raise_for_status()

+    @staticmethod
+    def _get_sample_dict_from_datatype(datatype: Any) -> dict:
+        if not hasattr(datatype, "schema"):
+            # not a pydantic model
+            raise TypeError(f"datatype must be a pydantic model, for the UI to be generated. but got {datatype}")
+
+        if hasattr(datatype, "_get_sample_data"):
+            return datatype._get_sample_data()
+
+        datatype_props = datatype.schema()["properties"]
+        out: Dict[str, Any] = {}
+        lut = {"string": "data string", "number": 0.0, "integer": 0, "boolean": False}
+        for k, v in datatype_props.items():
+            if v["type"] not in lut:
+                raise TypeError("Unsupported type")
+            out[k] = lut[v["type"]]
+        return out
+
+    def get_code_sample(self, url: str) -> Optional[str]:
+        input_type: Any = self._input_type
+        output_type: Any = self._output_type
+
+        if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")):
+            return None
+        return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}"
+
+    def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]:  # noqa: F821
+        try:
+            from lightning_api_access import APIAccessFrontend
+        except ModuleNotFoundError:
+            logger.warn("APIAccessFrontend not found. Please install lightning-api-access to enable the UI")
+            return
+
+        if is_running_in_cloud():
+            url = f"{self._future_url}{self.endpoint}"
+        else:
+            url = f"http://localhost:{self.port}{self.endpoint}"
+
+        frontend_objects = {"name": self._work_name, "url": url, "method": "POST", "request": None, "response": None}
+        code_samples = self.get_code_sample(url)
+        if code_samples:
+            frontend_objects["code_samples"] = code_samples
+            # TODO also set request/response for JS UI
+        else:
+            try:
+                request = self._get_sample_dict_from_datatype(self._input_type)
+                response = self._get_sample_dict_from_datatype(self._output_type)
+            except TypeError:
+                return None
+            else:
+                frontend_objects["request"] = request
+                frontend_objects["response"] = response
+        return APIAccessFrontend(apis=[frontend_objects])
+

 class AutoScaler(LightningFlow):
     """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in
@@ -403,8 +469,8 @@ def __init__(
         max_batch_size: int = 8,
         timeout_batching: float = 1,
         endpoint: str = "api/predict",
-        input_type: BaseModel = Dict,
-        output_type: BaseModel = Dict,
+        input_type: Type[BaseModel] = Dict,
+        output_type: Type[BaseModel] = Dict,
         *work_args: Any,
         **work_kwargs: Any,
     ) -> None:
@@ -438,6 +504,7 @@ def __init__(
             timeout_batching=timeout_batching,
             cache_calls=True,
             parallel=True,
+            work_name=self._work_cls.__name__,
         )
         for _ in range(min_replicas):
             work = self.create_work()
@@ -574,5 +641,8 @@ def autoscale(self) -> None:
         self._last_autoscale = time.time()

     def configure_layout(self):
-        tabs = [{"name": "Swagger", "content": self.load_balancer.url}]
+        tabs = [
+            {"name": "Endpoint Info", "content": f"{self.load_balancer}/endpoint-info"},
+            {"name": "Swagger", "content": self.load_balancer.url},
+        ]
         return tabs
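With `lightning-api-access` installed, the new "Endpoint Info" tab simply renders the page mounted at `/endpoint-info`. A hedged local smoke check; host and port are assumptions that depend on where the load balancer binds:

```python
# Assumes the app is running locally and the load balancer listens on 7501.
import requests

resp = requests.get("http://127.0.0.1:7501/endpoint-info/", timeout=5)
resp.raise_for_status()
assert "text/html" in resp.headers.get("content-type", "")
```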
25 changes: 20 additions & 5 deletions tests/tests_app/components/test_auto_scaler.py
@@ -1,10 +1,11 @@
 import time
+from unittest import mock
 from unittest.mock import patch

 import pytest

 from lightning_app import CloudCompute, LightningWork
-from lightning_app.components import AutoScaler
+from lightning_app.components import AutoScaler, Text


 class EmptyWork(LightningWork):
@@ -32,8 +33,8 @@ def test_num_replicas_after_init():


 @patch("uvicorn.run")
-@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
-@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
+@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
+@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
 def test_num_replicas_not_above_max_replicas(*_):
     """Test self.num_replicas doesn't exceed max_replicas."""
     max_replicas = 6
@@ -52,8 +53,8 @@ def test_num_replicas_not_above_max_replicas(*_):


 @patch("uvicorn.run")
-@patch("lightning_app.components.auto_scaler._LoadBalancer.url")
-@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests")
+@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url")
+@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests")
 def test_num_replicas_not_belo_min_replicas(*_):
     """Test self.num_replicas doesn't exceed max_replicas."""
     min_replicas = 1
@@ -98,3 +99,17 @@ def test_create_work_cloud_compute_cloned():
     auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute)
     _ = auto_scaler.create_work()
     assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute
+
+
+fastapi_mock = mock.MagicMock()
+mocked_fastapi_creater = mock.MagicMock(return_value=fastapi_mock)
+
+
+@patch("lightning_app.components.serve.auto_scaler._create_fastapi", mocked_fastapi_creater)
+@patch("lightning_app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock())
+def test_API_ACCESS_ENDPOINT_creation():
+    auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text)
+    assert auto_scaler.load_balancer._work_name == "EmptyWork"
+
+    auto_scaler.load_balancer.run()
+    fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static")
