[App] Cold start proxy in autoscaler (#16094)
* cold start proxy

* Update src/lightning_app/components/serve/auto_scaler.py

* changelog

* better-doc

Co-authored-by: Akihiro Nitta <nitta@akihironitta.com>
Co-authored-by: thomas chaton <thomas@grid.ai>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
4 people authored and carmocca committed Dec 20, 2022
1 parent dcc42dd commit cfb6c8d
Showing 7 changed files with 229 additions and 37 deletions.
1 change: 1 addition & 0 deletions docs/source-app/api_references.rst
@@ -46,6 +46,7 @@ ___________________
~multi_node.pytorch_spawn.PyTorchSpawnMultiNode
~multi_node.trainer.LightningTrainerMultiNode
~serve.auto_scaler.AutoScaler
~serve.auto_scaler.ColdStartProxy

----

1 change: 1 addition & 0 deletions requirements/app/test.txt
@@ -4,6 +4,7 @@ pytest==7.2.0
pytest-timeout==2.1.0
pytest-cov==4.0.0
pytest-doctestplus>=0.9.0
pytest-asyncio==0.20.3
playwright==1.28.0
httpx
trio<0.22.0
4 changes: 4 additions & 0 deletions src/lightning_app/CHANGELOG.md
@@ -25,6 +25,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `display_name` property to LightningWork for the cloud ([#16095](https://github.com/Lightning-AI/lightning/pull/16095))

- Added `ColdStartProxy` to the AutoScaler ([#16094](https://github.com/Lightning-AI/lightning/pull/16094))


### Changed

@@ -52,6 +54,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed a bug where `AutoScaler` would fail with `min_replicas=0` ([#16092](https://github.com/Lightning-AI/lightning/pull/16092))

- Fixed auto-batching so that requests arriving after the batching interval but still waiting in the queue are batched ([#16110](https://github.com/Lightning-AI/lightning/pull/16110))


- Fixed a non-thread safe deepcopy in the scheduler ([#16114](https://github.com/Lightning-AI/lightning/pull/16114))

3 changes: 2 additions & 1 deletion src/lightning_app/components/__init__.py
@@ -8,7 +8,7 @@
)
from lightning_app.components.python.popen import PopenPythonScript
from lightning_app.components.python.tracer import Code, TracerPythonScript
from lightning_app.components.serve.auto_scaler import AutoScaler
from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.serve import ModelInferenceAPI
@@ -17,6 +17,7 @@

__all__ = [
"AutoScaler",
"ColdStartProxy",
"DatabaseClient",
"Database",
"PopenPythonScript",
14 changes: 12 additions & 2 deletions src/lightning_app/components/serve/__init__.py
@@ -1,6 +1,16 @@
from lightning_app.components.serve.auto_scaler import AutoScaler
from lightning_app.components.serve.auto_scaler import AutoScaler, ColdStartProxy
from lightning_app.components.serve.gradio import ServeGradio
from lightning_app.components.serve.python_server import Category, Image, Number, PythonServer, Text
from lightning_app.components.serve.streamlit import ServeStreamlit

__all__ = ["ServeGradio", "ServeStreamlit", "PythonServer", "Image", "Number", "Category", "Text", "AutoScaler"]
__all__ = [
"ServeGradio",
"ServeStreamlit",
"PythonServer",
"Image",
"Number",
"Category",
"Text",
"AutoScaler",
"ColdStartProxy",
]
184 changes: 152 additions & 32 deletions src/lightning_app/components/serve/auto_scaler.py
@@ -6,7 +6,7 @@
import uuid
from base64 import b64encode
from itertools import cycle
from typing import Any, Dict, List, Optional, Tuple, Type
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import requests
import uvicorn
@@ -32,7 +32,53 @@
logger = Logger(__name__)


def _raise_granular_exception(exception: Exception) -> None:
class ColdStartProxy:
"""ColdStartProxy allows users to configure the load balancer to use a proxy service while the work is cold
starting. This is useful with services that gets realtime requests but startup time for workers is high.
If the request body is same and the method is POST for the proxy service,
then the default implementation of `handle_request` can be used. In that case
initialize the proxy with the proxy url. Otherwise, the user can override the `handle_request`
Args:
proxy_url (str): The url of the proxy service
"""

def __init__(self, proxy_url):
self.proxy_url = proxy_url
self.proxy_timeout = 50
# checking `asyncio.iscoroutinefunction` instead of `inspect.iscoroutinefunction`
# because AsyncMock in the tests requires the former to pass
if not asyncio.iscoroutinefunction(self.handle_request):
raise TypeError("handle_request must be an `async` function")

async def handle_request(self, request: BaseModel) -> Any:
"""This method is called when the request is received while the work is cold starting. The default
implementation of this method is to forward the request body to the proxy service with POST method but the
user can override this method to handle the request in any way.
Args:
request (BaseModel): The request body, a pydantic model that is being
forwarded by load balancer which is a FastAPI service
"""
try:
async with aiohttp.ClientSession() as session:
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
async with session.post(
self.proxy_url,
json=request.dict(),
timeout=self.proxy_timeout,
headers=headers,
) as response:
return await response.json()
except Exception as ex:
raise HTTPException(status_code=500, detail=f"Error in proxy: {ex}")
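
For illustration, a minimal sketch of how this hook can be used; the endpoint URL, token, and `CustomColdStartProxy` class below are placeholders, not part of this commit. A subclass may override `handle_request` to reach any warm service, while passing a plain URL keeps the default POST-forwarding behaviour.

```python
import aiohttp
from pydantic import BaseModel

from lightning_app.components.serve import ColdStartProxy


class CustomColdStartProxy(ColdStartProxy):
    """Hypothetical proxy that adds an auth header before forwarding the request."""

    async def handle_request(self, request: BaseModel) -> dict:
        # the override must stay `async`; the base class enforces this at init time
        async with aiohttp.ClientSession() as session:
            async with session.post(
                self.proxy_url,
                json=request.dict(),
                headers={"Authorization": "Bearer <token>"},  # placeholder credential
                timeout=self.proxy_timeout,
            ) as response:
                return await response.json()


# either form can be handed to the load balancer / AutoScaler:
proxy = CustomColdStartProxy(proxy_url="https://example.com/predict")
proxy_url = "https://example.com/predict"  # a bare URL is wrapped in the default ColdStartProxy
```
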


def _maybe_raise_granular_exception(exception: Exception) -> None:
"""Handle an exception from hitting the model servers."""
if not isinstance(exception, Exception):
return
@@ -116,6 +162,8 @@ class _LoadBalancer(LightningWork):
requests to be batched. In any case, requests are processed as soon as `max_batch_size` is reached.
timeout_keep_alive: The number of seconds until it closes Keep-Alive connections if no new data is received.
timeout_inference_request: The number of seconds to wait for inference.
api_name: The name to be displayed in the UI. Normally, it is the name of the work class.
cold_start_proxy: The proxy service to use while the work is cold starting.
**kwargs: Arguments passed to :func:`LightningWork.init` like ``CloudCompute``, ``BuildConfig``, etc.
"""

@@ -130,45 +178,63 @@ def __init__(
timeout_batching: float = 1,
timeout_keep_alive: int = 60,
timeout_inference_request: int = 60,
work_name: Optional[str] = "API", # used for displaying the name in the UI
api_name: Optional[str] = "API", # used for displaying the name in the UI
cold_start_proxy: Union[ColdStartProxy, str, None] = None,
**kwargs: Any,
) -> None:
super().__init__(cloud_compute=CloudCompute("default"), **kwargs)
self._input_type = input_type
self._output_type = output_type
self._timeout_keep_alive = timeout_keep_alive
self._timeout_inference_request = timeout_inference_request
self.servers = []
self._servers = []
self.max_batch_size = max_batch_size
self.timeout_batching = timeout_batching
self._iter = None
self._batch = []
self._responses = {} # {request_id: response}
self._last_batch_sent = 0
self._work_name = work_name
self._server_status = {}
self._api_name = api_name

if not endpoint.startswith("/"):
endpoint = "/" + endpoint

self.endpoint = endpoint
self._fastapi_app = None

self._cold_start_proxy = None
if cold_start_proxy:
if isinstance(cold_start_proxy, str):
self._cold_start_proxy = ColdStartProxy(proxy_url=cold_start_proxy)
elif isinstance(cold_start_proxy, ColdStartProxy):
self._cold_start_proxy = cold_start_proxy
else:
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")

async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
server = next(self._iter) # round-robin
async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
batch_request_data = _BatchRequestModel(inputs=request_data)

try:
self._server_status[server_url] = False
async with aiohttp.ClientSession() as session:
headers = {
"accept": "application/json",
"Content-Type": "application/json",
}
async with session.post(
f"{server}{self.endpoint}",
f"{server_url}{self.endpoint}",
json=batch_request_data.dict(),
timeout=self._timeout_inference_request,
headers=headers,
) as response:
# resetting the server status so other requests can be
# scheduled on this node
if server_url in self._server_status:
# TODO - if the server returns an error, track that so
# we don't send more requests to it
self._server_status[server_url] = True
if response.status == 408:
raise HTTPException(408, "Request timed out")
response.raise_for_status()
@@ -181,48 +247,87 @@ async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]]):
except Exception as ex:
result = {request[0]: ex for request in batch}
self._responses.update(result)
finally:
self._server_status[server_url] = True

def _find_free_server(self) -> Optional[str]:
existing = set(self._server_status.keys())
for server in existing:
status = self._server_status.get(server, None)
if status is None:
logger.error("Server is not found in the status list. This should not happen.")
if status:
return server

async def consumer(self):
"""The consumer process that continuously checks for new requests and sends them to the API.
Two instances of this function should not be running with shared `_state_server` as that would create race
conditions
"""
self._last_batch_sent = time.time()
while True:
await asyncio.sleep(0.05)

batch = self._batch[: self.max_batch_size]
while batch and (
(len(batch) == self.max_batch_size) or ((time.time() - self._last_batch_sent) > self.timeout_batching)
):
asyncio.create_task(self.send_batch(batch))

self._batch = self._batch[self.max_batch_size :]
batch = self._batch[: self.max_batch_size]
is_batch_ready = len(batch) == self.max_batch_size
is_batch_timeout = time.time() - self._last_batch_sent > self.timeout_batching
server_url = self._find_free_server()
# setting the server status to be busy! This will be reset by
# the send_batch function after the server responds
if server_url is None:
continue
if batch and (is_batch_ready or is_batch_timeout):
# find server with capacity
asyncio.create_task(self.send_batch(batch, server_url))
# resetting the batch array, TODO - not locking the array
self._batch = self._batch[len(batch) :]
self._last_batch_sent = time.time()

async def process_request(self, data: BaseModel):
if not self.servers:
async def process_request(self, data: BaseModel, request_id: Optional[str] = None):
    # generate the id per call; a `uuid.uuid4().hex` default argument would be evaluated only once
    request_id = request_id or uuid.uuid4().hex
if not self._servers and not self._cold_start_proxy:
raise HTTPException(500, "None of the workers are healthy!")

request_id = uuid.uuid4().hex
request: Tuple = (request_id, data)
self._batch.append(request)
# if no servers are available, proxy the request to the cold start proxy handler
if not self._servers and self._cold_start_proxy:
return await self._cold_start_proxy.handle_request(data)

# if out of capacity, proxy the request to the cold start proxy handler
if not self._has_processing_capacity() and self._cold_start_proxy:
return await self._cold_start_proxy.handle_request(data)

# if we have capacity, process the request
self._batch.append((request_id, data))
while True:
await asyncio.sleep(0.05)

if request_id in self._responses:
result = self._responses[request_id]
del self._responses[request_id]
_raise_granular_exception(result)
_maybe_raise_granular_exception(result)
return result

def _has_processing_capacity(self):
"""This function checks if we have processing capacity for one more request or not.
Depends on the value from here, we decide whether we should proxy the request or not
"""
if not self._fastapi_app:
return False
active_server_count = len(self._servers)
max_processable = self.max_batch_size * active_server_count
current_req_count = self._fastapi_app.num_current_requests
return current_req_count < max_processable
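
A rough illustration of the capacity heuristic above, with made-up numbers: total capacity is `max_batch_size` times the number of registered servers, and a request is forwarded to the cold start proxy once the in-flight count reaches that limit.

```python
# illustrative numbers only; mirrors the check in `_has_processing_capacity`
max_batch_size = 8
active_server_count = 2       # len(self._servers)
num_current_requests = 16     # tracked by the FastAPI app

max_processable = max_batch_size * active_server_count  # 16
has_capacity = num_current_requests < max_processable   # False -> proxy the request
```
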

def run(self):
logger.info(f"servers: {self.servers}")
logger.info(f"servers: {self._servers}")
lock = asyncio.Lock()

self._iter = cycle(self.servers)
self._iter = cycle(self._servers)
self._last_batch_sent = time.time()

fastapi_app = _create_fastapi("Load Balancer")
security = HTTPBasic()
fastapi_app.SEND_TASK = None
self._fastapi_app = fastapi_app

input_type = self._input_type

@@ -269,8 +374,8 @@ def authenticate_private_endpoint(credentials: HTTPBasicCredentials = Depends(se
@fastapi_app.get("/system/info", response_model=_SysInfo)
async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint)):
return _SysInfo(
num_workers=len(self.servers),
servers=self.servers,
num_workers=len(self._servers),
servers=self._servers,
num_requests=fastapi_app.num_current_requests,
processing_time=fastapi_app.last_processing_time,
global_request_count=fastapi_app.global_request_count,
@@ -279,8 +384,20 @@ async def sys_info(authenticated: bool = Depends(authenticate_private_endpoint))
@fastapi_app.put("/system/update-servers")
async def update_servers(servers: List[str], authenticated: bool = Depends(authenticate_private_endpoint)):
async with lock:
self.servers = servers
self._iter = cycle(self.servers)
self._servers = servers
self._iter = cycle(self._servers)
updated_servers = set()
# do not try to loop over the dict keys as the dict might change from other places
existing_servers = list(self._server_status.keys())
for server in servers:
updated_servers.add(server)
if server not in existing_servers:
self._server_status[server] = True
logger.info(f"Registering server {server}", self._server_status)
for existing in existing_servers:
if existing not in updated_servers:
logger.info(f"De-Registering server {existing}", self._server_status)
del self._server_status[existing]

@fastapi_app.post(self.endpoint, response_model=self._output_type)
async def balance_api(inputs: input_type):
@@ -308,7 +425,7 @@ def update_servers(self, server_works: List[LightningWork]):
AutoScaler uses this method to increase/decrease the number of works.
"""
old_servers = set(self.servers)
old_servers = set(self._servers)
server_urls: List[str] = [server.url for server in server_works if server.url]
new_servers = set(server_urls)

@@ -384,7 +501,7 @@ def _get_endpoint_info_page(self) -> Optional["APIAccessFrontend"]:  # noqa: F82
else:
url = f"http://localhost:{self.port}{self.endpoint}"

frontend_objects = {"name": self._work_name, "url": url, "method": "POST", "request": None, "response": None}
frontend_objects = {"name": self._api_name, "url": url, "method": "POST", "request": None, "response": None}
code_samples = self.get_code_sample(url)
if code_samples:
frontend_objects["code_samples"] = code_samples
@@ -416,6 +533,7 @@ class AutoScaler(LightningFlow):
timeout_batching: (auto-batching) The number of seconds to wait before sending the requests to process.
input_type: Input type.
output_type: Output type.
cold_start_proxy: If provided, the proxy will be used while the worker machines are warming up.
.. testcode::
@@ -477,6 +595,7 @@ def __init__(
endpoint: str = "api/predict",
input_type: Type[BaseModel] = Dict,
output_type: Type[BaseModel] = Dict,
cold_start_proxy: Union[ColdStartProxy, str, None] = None,
*work_args: Any,
**work_kwargs: Any,
) -> None:
@@ -511,7 +630,8 @@ def __init__(
timeout_batching=timeout_batching,
cache_calls=True,
parallel=True,
work_name=self._work_cls.__name__,
api_name=self._work_cls.__name__,
cold_start_proxy=cold_start_proxy,
)
for _ in range(min_replicas):
work = self.create_work()
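
Putting the pieces together, a minimal sketch of wiring `cold_start_proxy` into an app; `MyServer`, its placeholder model, and the proxy URL are assumptions for illustration, not part of this commit.

```python
# app.py -- minimal sketch, assuming a user-defined PythonServer subclass
from lightning_app import LightningApp
from lightning_app.components.serve import AutoScaler, ColdStartProxy, PythonServer


class MyServer(PythonServer):
    def setup(self):
        self._model = lambda x: x  # placeholder model

    def predict(self, request):
        return {"prediction": self._model(request.dict())}


app = LightningApp(
    AutoScaler(
        MyServer,
        min_replicas=0,  # replicas may scale to zero, which is when the proxy kicks in
        endpoint="api/predict",
        # a plain URL string also works; it is wrapped in the default ColdStartProxy
        cold_start_proxy=ColdStartProxy(proxy_url="https://example.com/warm-predict"),
    )
)
```
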
