Skip to content

Commit

Permalink
[App] Implement ready for components (#16129)
Browse files Browse the repository at this point in the history
  • Loading branch information
ethanwharris authored and carmocca committed Dec 20, 2022
1 parent cf3934b commit dc5260c
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 6 deletions.
7 changes: 7 additions & 0 deletions src/lightning_app/components/serve/auto_scaler.py
Expand Up @@ -212,6 +212,8 @@ def __init__(
else:
raise ValueError("cold_start_proxy must be of type ColdStartProxy or str")

self.ready = False

async def send_batch(self, batch: List[Tuple[str, _BatchRequestModel]], server_url: str):
request_data: List[_LoadBalancer._input_type] = [b[1] for b in batch]
batch_request_data = _BatchRequestModel(inputs=request_data)
Expand Down Expand Up @@ -410,6 +412,7 @@ async def balance_api(inputs: input_type):
)

logger.info(f"Your load balancer has started. The endpoint is 'http://{self.host}:{self.port}{self.endpoint}'")
self.ready = True

uvicorn.run(
fastapi_app,
Expand Down Expand Up @@ -641,6 +644,10 @@ def __init__(
def workers(self) -> List[LightningWork]:
return [self.get_work(i) for i in range(self.num_replicas)]

@property
def ready(self) -> bool:
return self.load_balancer.ready

def create_work(self) -> LightningWork:
"""Replicates a LightningWork instance with args and kwargs provided via ``__init__``."""
cloud_compute = self._work_kwargs.get("cloud_compute", None)
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/components/serve/gradio.py
Expand Up @@ -42,6 +42,8 @@ def __init__(self, *args, **kwargs):
assert self.outputs
self._model = None

self.ready = False

@property
def model(self):
return self._model
Expand All @@ -62,6 +64,7 @@ def run(self, *args, **kwargs):
self._model = self.build_model()
fn = partial(self.predict, *args, **kwargs)
fn.__name__ = self.predict.__name__
self.ready = True
gradio.Interface(
fn=fn,
inputs=self.inputs,
Expand Down
3 changes: 3 additions & 0 deletions src/lightning_app/components/serve/python_server.py
Expand Up @@ -193,6 +193,8 @@ def predict(self, request):
self._input_type = input_type
self._output_type = output_type

self.ready = False

def setup(self, *args, **kwargs) -> None:
"""This method is called before the server starts. Override this if you need to download the model or
initialize the weights, setting up pipelines etc.
Expand Down Expand Up @@ -300,6 +302,7 @@ def run(self, *args: Any, **kwargs: Any) -> Any:
fastapi_app = FastAPI()
self._attach_predict_fn(fastapi_app)

self.ready = True
logger.info(
f"Your {self.__class__.__qualname__} has started. View it in your browser: http://{self.host}:{self.port}"
)
Expand Down
4 changes: 4 additions & 0 deletions src/lightning_app/components/serve/serve.py
Expand Up @@ -64,6 +64,8 @@ def __init__(
self.workers = workers
self._model = None

self.ready = False

@property
def model(self):
return self._model
Expand Down Expand Up @@ -108,9 +110,11 @@ def run(self):
"serve:fastapi_service",
]
process = subprocess.Popen(command, env=env, cwd=os.path.dirname(__file__))
self.ready = True
process.wait()
else:
self._populate_app(fastapi_service)
self.ready = True
self._launch_server(fastapi_service)

def _populate_app(self, fastapi_service: FastAPI):
Expand Down
2 changes: 1 addition & 1 deletion src/lightning_app/core/flow.py
Expand Up @@ -797,7 +797,7 @@ def __init__(self, work):
@property
def ready(self) -> bool:
ready = getattr(self.work, "ready", None)
if ready:
if ready is not None:
return ready
return self.work.url != ""

Expand Down
16 changes: 11 additions & 5 deletions tests/tests_app/core/test_lightning_flow.py
Expand Up @@ -12,7 +12,7 @@

import lightning_app
from lightning_app import CloudCompute, LightningApp
from lightning_app.core.flow import LightningFlow
from lightning_app.core.flow import _RootFlow, LightningFlow
from lightning_app.core.work import LightningWork
from lightning_app.runners import MultiProcessRuntime
from lightning_app.storage import Path
Expand Down Expand Up @@ -868,10 +868,10 @@ def test_lightning_flow_flows_and_works():
class WorkReady(LightningWork):
def __init__(self):
super().__init__(parallel=True)
self.counter = 0
self.ready = False

def run(self):
self.counter += 1
self.ready = True


class FlowReady(LightningFlow):
Expand All @@ -890,7 +890,13 @@ def run(self):
self._exit()


def test_flow_ready():
class RootFlowReady(_RootFlow):
def __init__(self):
super().__init__(WorkReady())


@pytest.mark.parametrize("flow", [FlowReady, RootFlowReady])
def test_flow_ready(flow):
"""This test validates that the app status queue is populated correctly."""

mock_queue = _MockQueue("api_publish_state_queue")
Expand All @@ -910,7 +916,7 @@ def lagged_run_once(method):
state["done"] = new_done
return False

app = LightningApp(FlowReady())
app = LightningApp(flow())
app._run = partial(run_patch, method=app._run)
app.run_once = partial(lagged_run_once, method=app.run_once)
MultiProcessRuntime(app, start_server=False).dispatch()
Expand Down

0 comments on commit dc5260c

Please sign in to comment.