diff --git a/src/lightning_app/core/api.py b/src/lightning_app/core/api.py index 568098f9fb95b..51b57b2c732a2 100644 --- a/src/lightning_app/core/api.py +++ b/src/lightning_app/core/api.py @@ -1,10 +1,12 @@ import asyncio +import json import os import queue import sys import traceback from copy import deepcopy from multiprocessing import Queue +from pathlib import Path from tempfile import TemporaryDirectory from threading import Event, Lock, Thread from time import sleep @@ -68,6 +70,7 @@ class SessionMiddleware: app_spec: Optional[List] = None app_status: Optional[AppStatus] = None +app_annotations: Optional[List] = None # In the future, this would be abstracted to support horizontal scaling. responses_store = {} @@ -345,6 +348,13 @@ async def get_status() -> AppStatus: return app_status +@fastapi_service.get("/api/v1/annotations", response_class=JSONResponse) +async def get_annotations() -> Union[List, Dict]: + """Get the annotations associated with this app.""" + global app_annotations + return app_annotations or [] + + @fastapi_service.get("/healthz", status_code=200) async def healthz(response: Response): """Health check endpoint used in the cloud FastAPI servers to check the status periodically.""" @@ -440,6 +450,7 @@ def start_server( global api_app_delta_queue global global_app_state_store global app_spec + global app_annotations app_spec = spec api_app_delta_queue = api_delta_queue @@ -449,6 +460,12 @@ def start_server( global_app_state_store.add(TEST_SESSION_UUID) + # Load annotations + annotations_path = Path("lightning-annotations.json").resolve() + if annotations_path.exists(): + with open(annotations_path) as f: + app_annotations = json.load(f) + refresher = UIRefresher(api_publish_state_queue, api_response_queue) refresher.setDaemon(True) refresher.start() diff --git a/tests/tests_app/core/test_lightning_api.py b/tests/tests_app/core/test_lightning_api.py index 057716555d718..3003ecdd62e2d 100644 --- a/tests/tests_app/core/test_lightning_api.py +++ b/tests/tests_app/core/test_lightning_api.py @@ -5,6 +5,7 @@ import sys from copy import deepcopy from multiprocessing import Process +from pathlib import Path from time import sleep, time from unittest import mock @@ -562,3 +563,36 @@ def test_configure_api(): time_left -= 0.1 assert process.exitcode == 0 process.kill() + + +@pytest.mark.anyio +@mock.patch("lightning_app.core.api.UIRefresher", mock.MagicMock()) +async def test_get_annotations(tmpdir): + cwd = os.getcwd() + os.chdir(tmpdir) + + Path("lightning-annotations.json").write_text('[{"test": 3}]') + + try: + app = AppStageTestingApp(FlowA(), log_level="debug") + app._update_layout() + app.stage = AppStage.BLOCKING + change_state_queue = _MockQueue("change_state_queue") + has_started_queue = _MockQueue("has_started_queue") + api_response_queue = _MockQueue("api_response_queue") + spec = extract_metadata_from_app(app) + start_server( + None, + change_state_queue, + api_response_queue, + has_started_queue=has_started_queue, + uvicorn_run=False, + spec=spec, + ) + + async with AsyncClient(app=fastapi_service, base_url="http://test") as client: + response = await client.get("/api/v1/annotations") + assert response.json() == [{"test": 3}] + finally: + # Cleanup + os.chdir(cwd)