From a0f8f70947f584c7ba78167d97a53d3f88337181 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 19 Dec 2022 19:25:24 +0900 Subject: [PATCH] [App] Improve the autoscaler UI (#16063) [App] Improve the autoscaler UI (#16063) (cherry picked from commit 39d27f637096b911fb08b4ec5f55fe3ce8c2c585) --- docs/source-app/api_references.rst | 2 +- examples/app_server_with_auto_scaler/app.py | 14 ++- pyproject.toml | 2 +- src/lightning_app/CHANGELOG.md | 2 + src/lightning_app/components/__init__.py | 2 +- .../components/serve/__init__.py | 3 +- .../components/{ => serve}/auto_scaler.py | 86 +++++++++++++++++-- .../{ => serve}/test_auto_scaler.py | 25 ++++-- 8 files changed, 110 insertions(+), 26 deletions(-) rename src/lightning_app/components/{ => serve}/auto_scaler.py (85%) rename tests/tests_app/components/{ => serve}/test_auto_scaler.py (74%) diff --git a/docs/source-app/api_references.rst b/docs/source-app/api_references.rst index 2272f7bf13c41..931a9864d261f 100644 --- a/docs/source-app/api_references.rst +++ b/docs/source-app/api_references.rst @@ -45,7 +45,7 @@ ___________________ ~multi_node.lite.LiteMultiNode ~multi_node.pytorch_spawn.PyTorchSpawnMultiNode ~multi_node.trainer.LightningTrainerMultiNode - ~auto_scaler.AutoScaler + ~serve.auto_scaler.AutoScaler ---- diff --git a/examples/app_server_with_auto_scaler/app.py b/examples/app_server_with_auto_scaler/app.py index 70799827776a8..453db2424b404 100644 --- a/examples/app_server_with_auto_scaler/app.py +++ b/examples/app_server_with_auto_scaler/app.py @@ -1,5 +1,5 @@ # ! pip install torch torchvision -from typing import Any, List +from typing import List import torch import torchvision @@ -8,16 +8,12 @@ import lightning as L -class RequestModel(BaseModel): - image: str # bytecode - - class BatchRequestModel(BaseModel): - inputs: List[RequestModel] + inputs: List[L.app.components.Image] class BatchResponse(BaseModel): - outputs: List[Any] + outputs: List[L.app.components.Number] class PyTorchServer(L.app.components.PythonServer): @@ -81,8 +77,8 @@ def scale(self, replicas: int, metrics: dict) -> int: max_replicas=4, autoscale_interval=10, endpoint="predict", - input_type=RequestModel, - output_type=Any, + input_type=L.app.components.Image, + output_type=L.app.components.Number, timeout_batching=1, max_batch_size=8, ) diff --git a/pyproject.toml b/pyproject.toml index 8611ef9323deb..4461d956634c6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,8 +79,8 @@ module = [ "lightning_app.components.serve.types.image", "lightning_app.components.serve.types.type", "lightning_app.components.serve.python_server", + "lightning_app.components.serve.auto_scaler", "lightning_app.components.training", - "lightning_app.components.auto_scaler", "lightning_app.core.api", "lightning_app.core.app", "lightning_app.core.flow", diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index cf868ed52faf2..6586ec4eac8cf 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Added partial support for fastapi `Request` annotation in `configure_api` handlers ([#16047](https://github.com/Lightning-AI/lightning/pull/16047)) +- Added a nicer UI with URL and examples for the autoscaler component ([#16063](https://github.com/Lightning-AI/lightning/pull/16063)) + - Added more datatypes to serving component ([#16018](https://github.com/Lightning-AI/lightning/pull/16018)) - Added `work.delete` method to delete the work ([#16103](https://github.com/Lightning-AI/lightning/pull/16103)) diff --git a/src/lightning_app/components/__init__.py b/src/lightning_app/components/__init__.py index 18208aa316f2d..5fd8af6b055de 100644 --- a/src/lightning_app/components/__init__.py +++ b/src/lightning_app/components/__init__.py @@ -1,4 +1,3 @@ -from lightning_app.components.auto_scaler import AutoScaler from lightning_app.components.database.client import DatabaseClient from lightning_app.components.database.server import Database from lightning_app.components.multi_node import ( @@ -9,6 +8,7 @@ ) from lightning_app.components.python.popen import PopenPythonScript from lightning_app.components.python.tracer import Code, TracerPythonScript +from lightning_app.components.serve.auto_scaler import AutoScaler from lightning_app.components.serve.gradio import ServeGradio from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text from lightning_app.components.serve.serve import ModelInferenceAPI diff --git a/src/lightning_app/components/serve/__init__.py b/src/lightning_app/components/serve/__init__.py index a12cb1b45ee71..ac02e69c4f2ab 100644 --- a/src/lightning_app/components/serve/__init__.py +++ b/src/lightning_app/components/serve/__init__.py @@ -1,5 +1,6 @@ +from lightning_app.components.serve.auto_scaler import AutoScaler from lightning_app.components.serve.gradio import ServeGradio from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text from lightning_app.components.serve.streamlit import ServeStreamlit -__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text"] +__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text", "AutoScaler"] diff --git a/src/lightning_app/components/auto_scaler.py b/src/lightning_app/components/serve/auto_scaler.py similarity index 85% rename from src/lightning_app/components/auto_scaler.py rename to src/lightning_app/components/serve/auto_scaler.py index 13948ba50af89..6027249de850f 100644 --- a/src/lightning_app/components/auto_scaler.py +++ b/src/lightning_app/components/serve/auto_scaler.py @@ -6,7 +6,7 @@ import uuid from base64 import b64encode from itertools import cycle -from typing import Any, Dict, List, Tuple, Type +from typing import Any, Dict, List, Optional, Tuple, Type import requests import uvicorn @@ -15,11 +15,13 @@ from fastapi.responses import RedirectResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials from pydantic import BaseModel +from starlette.staticfiles import StaticFiles from starlette.status import HTTP_401_UNAUTHORIZED from lightning_app.core.flow import LightningFlow from lightning_app.core.work import LightningWork from lightning_app.utilities.app_helpers import Logger +from lightning_app.utilities.cloud import is_running_in_cloud from lightning_app.utilities.imports import _is_aiohttp_available, requires from lightning_app.utilities.packaging.cloud_compute import CloudCompute @@ -114,20 +116,21 @@ class _LoadBalancer(LightningWork): requests 
to be batched. In any case, requests are processed as soon as `max_batch_size` is reached. timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received. timeout_inference_request: The number of seconds to wait for inference. - \**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc. + **kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc. """ @requires(["aiohttp"]) def __init__( self, - input_type: BaseModel, - output_type: BaseModel, + input_type: Type[BaseModel], + output_type: Type[BaseModel], endpoint: str, max_batch_size: int = 8, # all timeout args are in seconds - timeout_batching: int = 1, + timeout_batching: float = 1, timeout_keep_alive: int = 60, timeout_inference_request: int = 60, + work_name: Optional[str] = "API", # used for displaying the name in the UI **kwargs: Any, ) -> None: super().__init__(cloud_compute=CloudCompute("default"), **kwargs) @@ -142,6 +145,7 @@ def __init__( self._batch = [] self._responses = {} # {request_id: response} self._last_batch_sent = 0 + self._work_name = work_name if not endpoint.startswith("/"): endpoint = "/" + endpoint @@ -280,6 +284,14 @@ async def update_servers(servers: List[str], authenticated: bool = Depends(authe async def balance_api(inputs: self._input_type): return await self.process_request(inputs) + endpoint_info_page = self._get_endpoint_info_page() + if endpoint_info_page: + fastapi_app.mount( + "/endpoint-info", StaticFiles(directory=endpoint_info_page.serve_dir, html=True), name="static" + ) + + logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'") + uvicorn.run( fastapi_app, host=self.host, @@ -332,6 +344,60 @@ def send_request_to_update_servers(self, servers: List[str]): response = requests.put(f"{self.url}/system/update-servers", json=servers, headers=headers, timeout=10) response.raise_for_status() + @staticmethod + def _get_sample_dict_from_datatype(datatype: Any) -> dict: + if not hasattr(datatype, "schema"): + # not a pydantic model + raise TypeError(f"datatype must be a pydantic model, for the UI to be generated. but got {datatype}") + + if hasattr(datatype, "_get_sample_data"): + return datatype._get_sample_data() + + datatype_props = datatype.schema()["properties"] + out: Dict[str, Any] = {} + lut = {"string": "data string", "number": 0.0, "integer": 0, "boolean": False} + for k, v in datatype_props.items(): + if v["type"] not in lut: + raise TypeError("Unsupported type") + out[k] = lut[v["type"]] + return out + + def get_code_sample(self, url: str) -> Optional[str]: + input_type: Any = self._input_type + output_type: Any = self._output_type + + if not (hasattr(input_type, "request_code_sample") and hasattr(output_type, "response_code_sample")): + return None + return f"{input_type.request_code_sample(url)}\n{output_type.response_code_sample()}" + + def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]: # noqa: F821 + try: + from lightning_api_access import APIAccessFrontend + except ModuleNotFoundError: + logger.warn("APIAccessFrontend not found. 
Please install lightning-api-access to enable the UI") + return + + if is_running_in_cloud(): + url = f"{self._future_url}{self.endpoint}" + else: + url = f"http://localhost:{self.port}{self.endpoint}" + + frontend_objects = {"name": self._work_name, "url": url, "method": "POST", "request": None, "response": None} + code_samples = self.get_code_sample(url) + if code_samples: + frontend_objects["code_samples"] = code_samples + # TODO also set request/response for JS UI + else: + try: + request = self._get_sample_dict_from_datatype(self._input_type) + response = self._get_sample_dict_from_datatype(self._output_type) + except TypeError: + return None + else: + frontend_objects["request"] = request + frontend_objects["response"] = response + return APIAccessFrontend(apis=[frontend_objects]) + class AutoScaler(LightningFlow): """The ``AutoScaler`` can be used to automatically change the number of replicas of the given server in @@ -403,8 +469,8 @@ def __init__( max_batch_size: int = 8, timeout_batching: float = 1, endpoint: str = "api/predict", - input_type: BaseModel = Dict, - output_type: BaseModel = Dict, + input_type: Type[BaseModel] = Dict, + output_type: Type[BaseModel] = Dict, *work_args: Any, **work_kwargs: Any, ) -> None: @@ -438,6 +504,7 @@ def __init__( timeout_batching=timeout_batching, cache_calls=True, parallel=True, + work_name=self._work_cls.__name__, ) for _ in range(min_replicas): work = self.create_work() @@ -574,5 +641,8 @@ def autoscale(self) -> None: self._last_autoscale = time.time() def configure_layout(self): - tabs = [{"name": "Swagger", "content": self.load_balancer.url}] + tabs = [ + {"name": "Endpoint Info", "content": f"{self.load_balancer}/endpoint-info"}, + {"name": "Swagger", "content": self.load_balancer.url}, + ] return tabs diff --git a/tests/tests_app/components/test_auto_scaler.py b/tests/tests_app/components/serve/test_auto_scaler.py similarity index 74% rename from tests/tests_app/components/test_auto_scaler.py rename to tests/tests_app/components/serve/test_auto_scaler.py index 672b05bbc9a15..6bd5aa958b6bf 100644 --- a/tests/tests_app/components/test_auto_scaler.py +++ b/tests/tests_app/components/serve/test_auto_scaler.py @@ -1,10 +1,11 @@ import time +from unittest import mock from unittest.mock import patch import pytest from lightning_app import CloudCompute, LightningWork -from lightning_app.components import AutoScaler +from lightning_app.components import AutoScaler, Text class EmptyWork(LightningWork): @@ -32,8 +33,8 @@ def test_num_replicas_after_init(): @patch("uvicorn.run") -@patch("lightning_app.components.auto_scaler._LoadBalancer.url") -@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests") +@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url") +@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests") def test_num_replicas_not_above_max_replicas(*_): """Test self.num_replicas doesn't exceed max_replicas.""" max_replicas = 6 @@ -52,8 +53,8 @@ def test_num_replicas_not_above_max_replicas(*_): @patch("uvicorn.run") -@patch("lightning_app.components.auto_scaler._LoadBalancer.url") -@patch("lightning_app.components.auto_scaler.AutoScaler.num_pending_requests") +@patch("lightning_app.components.serve.auto_scaler._LoadBalancer.url") +@patch("lightning_app.components.serve.auto_scaler.AutoScaler.num_pending_requests") def test_num_replicas_not_belo_min_replicas(*_): """Test self.num_replicas doesn't exceed max_replicas.""" min_replicas = 1 @@ -98,3 +99,17 @@ def 
test_create_work_cloud_compute_cloned(): auto_scaler = AutoScaler(EmptyWork, cloud_compute=cloud_compute) _ = auto_scaler.create_work() assert auto_scaler._work_kwargs["cloud_compute"] is not cloud_compute + + +fastapi_mock = mock.MagicMock() +mocked_fastapi_creater = mock.MagicMock(return_value=fastapi_mock) + + +@patch("lightning_app.components.serve.auto_scaler._create_fastapi", mocked_fastapi_creater) +@patch("lightning_app.components.serve.auto_scaler.uvicorn.run", mock.MagicMock()) +def test_API_ACCESS_ENDPOINT_creation(): + auto_scaler = AutoScaler(EmptyWork, input_type=Text, output_type=Text) + assert auto_scaler.load_balancer._work_name == "EmptyWork" + + auto_scaler.load_balancer.run() + fastapi_mock.mount.assert_called_once_with("/endpoint-info", mock.ANY, name="static")
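
Usage sketch (illustrative, not from the commit): the "Endpoint Info" tab added above builds its code sample from two optional hooks on the request/response models, `input_type.request_code_sample(url)` and `output_type.response_code_sample()`, and falls back to a sample dict derived from the pydantic schema (string/number/integer/boolean fields only) when either hook is missing. The names `TextPrompt`, `TextReply`, and `EchoServer` below are hypothetical, and the sketch assumes `PythonServer` accepts `input_type`/`output_type` keyword arguments as in the bundled example app.

from pydantic import BaseModel

import lightning as L


class TextPrompt(BaseModel):
    prompt: str

    @staticmethod
    def request_code_sample(url: str) -> str:
        # Rendered verbatim on the endpoint-info page as a client snippet.
        return "\n".join(
            [
                "import requests",
                f'requests.post("{url}", json={{"prompt": "hello"}})',
            ]
        )


class TextReply(BaseModel):
    text: str

    @staticmethod
    def response_code_sample() -> str:
        return 'print(response.json()["text"])'


class EchoServer(L.app.components.PythonServer):
    def __init__(self, **kwargs):
        super().__init__(input_type=TextPrompt, output_type=TextReply, **kwargs)

    def setup(self) -> None:
        # No model to load in this sketch.
        pass

    def predict(self, request: TextPrompt) -> TextReply:
        # Stand-in for a real model: echo the prompt back.
        return TextReply(text=request.prompt)


app = L.LightningApp(
    L.app.components.AutoScaler(
        EchoServer,
        min_replicas=1,
        max_replicas=2,
        endpoint="predict",
        input_type=TextPrompt,
        output_type=TextReply,
    )
)

With the hooks present, the load balancer serves the generated snippet at the "/endpoint-info" mount shown in `configure_layout`; without them, it falls back to `_get_sample_dict_from_datatype` and populates the request/response JSON examples instead.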