Skip to content

Commit

Permalink
[App] Introduce auto scaler (#15769)
Browse files Browse the repository at this point in the history
* Exlucde __pycache__ in setuptools

* Add load balancer example

* wip

* Update example

* rename

* remove prints

* _LoadBalancer -> LoadBalancer

* AutoScaler(work)

* change var name

* remove locust

* Update docs

* include autoscaler in api ref

* docs typo

* docs typo

* docs typo

* docs typo

* remove unused loadtest

* remove unused device_type

* clean up

* clean up

* clean up

* Add docstring

* type

* env vars to args

* expose an API for users to override to customise autoscaling logic

* update example

* comment

* udpate var name

* fix scale mechanism and clean up

* Update exampl

* ignore mypy

* Add test file

* .

* update impl and update tests

* Update changlog

* .

* revert docs

* update test

* update state to keep calling 'flow.run()'

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>

* Add aiohttp to base requirements

* Update docs

Co-authored-by: Luca Antiga <luca.antiga@gmail.com>

* Use deserializer utility

* fake trigger

* wip: protect /system/* with basic auth

* read password at runtime

* Change env var name

* import torch as optional

* Don't overcreate works

* simplify imports

* Update example

* aiohttp

* Add work_args work_kwargs

* More docs

* remove FIXME

* Apply Jirka's suggestions

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* clean example device

* add comment on init threshold value

* bad merge

* nit: logging format

* {in,out}put_schema -> {in,out}put_type

* lowercase

* docs on seconds

* process_time -> processing_time

* Dont modify work state from flow

* Update tests

* worker_url -> endpoint

* fix exampl

* Fix default scale logic

* Fix default scale logic

* Fix num_pending_works

* Update num_pending_works

* Fix bug creating too many works

* Remove up/downscale_threshold args

* Update example

* Add typing

* Fix example in docstring

* Fix default scale logic

* Update src/lightning_app/components/auto_scaler.py

Co-authored-by: Noha Alon <nohalon@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* rename method

* rename locvar

* Add todo

* docs ci

* docs ci

* asdfafsdasdf pls docs

* Apply suggestions from code review

Co-authored-by: Ethan Harris <ethanwharris@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* .

* doc

* Update src/lightning_app/components/auto_scaler.py

Co-authored-by: Noha Alon <nohalon@gmail.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks"

This reverts commit 24983a0.

* Revert "Update src/lightning_app/components/auto_scaler.py"

This reverts commit 56ea78b.

* Remove redefinition

* Remove load balancer run blocker

* raise RuntimeError

* remove has_sent

* lower the default timeout_batching from 10 to 1

* remove debug

* update the default timeout_batching

* .

* tighten condition

* fix endpoint

* typo in runtimeerror cond

* async lock update severs

* add a test

* {in,out}put_type typing

* Update examples/app_server_with_auto_scaler/app.py

Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>

* Update .actions/setup_tools.py

Co-authored-by: Aniket Maurya <theaniketmaurya@gmail.com>
Co-authored-by: Luca Antiga <luca.antiga@gmail.com>
Co-authored-by: Jirka Borovec <6035284+Borda@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Noha Alon <nohalon@gmail.com>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
Co-authored-by: Akihiro Nitta <aki@pop-os.localdomain>
Co-authored-by: thomas chaton <thomas@grid.ai>

(cherry picked from commit 64b19fb)
  • Loading branch information
akihironitta authored and Borda committed Dec 7, 2022
1 parent d90f624 commit 6bdb7cb
Show file tree
Hide file tree
Showing 9 changed files with 754 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/source-app/api_references.rst
Expand Up @@ -37,6 +37,7 @@ ___________________
~training.LightningTrainerScript
~serve.gradio.ServeGradio
~serve.serve.ModelInferenceAPI
~auto_scaler.AutoScaler

----

Expand Down
86 changes: 86 additions & 0 deletions examples/app_server_with_auto_scaler/app.py
@@ -0,0 +1,86 @@
from typing import Any, List

import torch
import torchvision
from pydantic import BaseModel

import lightning as L


class RequestModel(BaseModel):
image: str # bytecode


class BatchRequestModel(BaseModel):
inputs: List[RequestModel]


class BatchResponse(BaseModel):
outputs: List[Any]


class PyTorchServer(L.app.components.PythonServer):
def __init__(self, *args, **kwargs):
super().__init__(
port=L.app.utilities.network.find_free_network_port(),
input_type=BatchRequestModel,
output_type=BatchResponse,
cloud_compute=L.CloudCompute("gpu"),
)

def setup(self):
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self._model = torchvision.models.resnet18(pretrained=True).to(self._device)

def predict(self, requests: BatchRequestModel):
transforms = torchvision.transforms.Compose(
[
torchvision.transforms.Resize(224),
torchvision.transforms.ToTensor(),
torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
)
images = []
for request in requests.inputs:
image = L.app.components.serve.types.image.Image.deserialize(request.image)
image = transforms(image).unsqueeze(0)
images.append(image)
images = torch.cat(images)
images = images.to(self._device)
predictions = self._model(images)
results = predictions.argmax(1).cpu().numpy().tolist()
return BatchResponse(outputs=[{"prediction": pred} for pred in results])


class MyAutoScaler(L.app.components.AutoScaler):
def scale(self, replicas: int, metrics: dict) -> int:
"""The default scaling logic that users can override."""
# scale out if the number of pending requests exceeds max batch size.
max_requests_per_work = self.max_batch_size
pending_requests_per_running_or_pending_work = metrics["pending_requests"] / (
replicas + metrics["pending_works"]
)
if pending_requests_per_running_or_pending_work >= max_requests_per_work:
return replicas + 1

# scale in if the number of pending requests is below 25% of max_requests_per_work
min_requests_per_work = max_requests_per_work * 0.25
pending_requests_per_running_work = metrics["pending_requests"] / replicas
if pending_requests_per_running_work < min_requests_per_work:
return replicas - 1

return replicas


app = L.LightningApp(
MyAutoScaler(
PyTorchServer,
min_replicas=2,
max_replicas=4,
autoscale_interval=10,
endpoint="predict",
input_type=RequestModel,
output_type=Any,
timeout_batching=1,
)
)
1 change: 1 addition & 0 deletions pyproject.toml
Expand Up @@ -80,6 +80,7 @@ module = [
"lightning_app.components.serve.types.type",
"lightning_app.components.serve.python_server",
"lightning_app.components.training",
"lightning_app.components.auto_scaler",
"lightning_app.core.api",
"lightning_app.core.app",
"lightning_app.core.flow",
Expand Down
1 change: 1 addition & 0 deletions requirements/app/base.txt
Expand Up @@ -13,3 +13,4 @@ inquirer>=2.10.0
psutil<5.9.4
click<=8.1.3
s3fs>=2022.5.0, <2022.8.3
aiohttp>=3.8.0, <=3.8.3
2 changes: 1 addition & 1 deletion src/lightning_app/CHANGELOG.md
Expand Up @@ -13,8 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added the CLI command `lightning delete app` to delete a lightning app on the cloud ([#15783](https://github.com/Lightning-AI/lightning/pull/15783))
- Added a CloudMultiProcessBackend which enables running a child App from within the Flow in the cloud ([#15800](https://github.com/Lightning-AI/lightning/pull/15800))
- Apps without UIs no longer activate the "Open App" button when running in the cloud ([#15875](https://github.com/Lightning-AI/lightning/pull/15875))
- Added `AutoScaler` component ([#15769](https://github.com/Lightning-AI/lightning/pull/15769))
- Added the property `ready` of the LightningFlow to inform when the `Open App` should be visible ([#15921](https://github.com/Lightning-AI/lightning/pull/15921))

- Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923))


Expand Down
2 changes: 2 additions & 0 deletions src/lightning_app/components/__init__.py
@@ -1,3 +1,4 @@
from lightning_app.components.auto_scaler import AutoScaler
from lightning_app.components.database.client import DatabaseClient
from lightning_app.components.database.server import Database
from lightning_app.components.multi_node import (
Expand All @@ -15,6 +16,7 @@
from lightning_app.components.training import LightningTrainerScript, PyTorchLightningScriptRunner

__all__ = [
"AutoScaler",
"DatabaseClient",
"Database",
"PopenPythonScript",
Expand Down

0 comments on commit 6bdb7cb

Please sign in to comment.