diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index eb38b4c263fa8..5522e0be9a4db 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -24,6 +24,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added private work attributed `_start_method` to customize how to start the works ([#15923](https://github.com/Lightning-AI/lightning/pull/15923)) +- Added a `configure_layout` method to the `LightningWork` which can be used to control how the work is handled in the layout of a parent flow ([#15926](https://github.com/Lightning-AI/lightning/pull/15926)) + ### Changed diff --git a/src/lightning_app/components/serve/gradio.py b/src/lightning_app/components/serve/gradio.py index 6e9b1d8777f67..15e314dfbb82f 100644 --- a/src/lightning_app/components/serve/gradio.py +++ b/src/lightning_app/components/serve/gradio.py @@ -78,3 +78,6 @@ def run(self, *args, **kwargs): server_port=self.port, enable_queue=self.enable_queue, ) + + def configure_layout(self) -> str: + return self.url diff --git a/src/lightning_app/components/serve/python_server.py b/src/lightning_app/components/serve/python_server.py index 99d51ac1cf4fc..04107e9094edc 100644 --- a/src/lightning_app/components/serve/python_server.py +++ b/src/lightning_app/components/serve/python_server.py @@ -8,7 +8,6 @@ from fastapi import FastAPI from lightning_utilities.core.imports import module_available from pydantic import BaseModel -from starlette.staticfiles import StaticFiles from lightning_app.core.queues import MultiProcessQueue from lightning_app.core.work import LightningWork @@ -222,49 +221,30 @@ def predict_fn(request: input_type): # type: ignore fastapi_app.post("/predict", response_model=output_type)(predict_fn) - def _attach_frontend(self, fastapi_app: FastAPI) -> None: - from lightning_api_access import APIAccessFrontend - - class_name = self.__class__.__name__ - url = self._future_url if self._future_url else self.url - if not url: - # if the url is still empty, point it to localhost - url = f"http://127.0.0.1:{self.port}" - url = f"{url}/predict" - datatype_parse_error = False - try: - request = self._get_sample_dict_from_datatype(self.configure_input_type()) - except TypeError: - datatype_parse_error = True - - try: - response = self._get_sample_dict_from_datatype(self.configure_output_type()) - except TypeError: - datatype_parse_error = True - - if datatype_parse_error: - - @fastapi_app.get("/") - def index() -> str: - return ( - "Automatic generation of the UI is only supported for simple, " - "non-nested datatype with types string, integer, float and boolean" - ) - - return - - frontend = APIAccessFrontend( - apis=[ - { - "name": class_name, - "url": url, - "method": "POST", - "request": request, - "response": response, - } - ] - ) - fastapi_app.mount("/", StaticFiles(directory=frontend.serve_dir, html=True), name="static") + def configure_layout(self) -> None: + if module_available("lightning_api_access"): + from lightning_api_access import APIAccessFrontend + + class_name = self.__class__.__name__ + url = f"{self.url}/predict" + + try: + request = self._get_sample_dict_from_datatype(self.configure_input_type()) + response = self._get_sample_dict_from_datatype(self.configure_output_type()) + except TypeError: + return None + + return APIAccessFrontend( + apis=[ + { + "name": class_name, + "url": url, + "method": "POST", + "request": request, + "response": response, + } + ] + ) def run(self, *args: Any, **kwargs: Any) -> Any: """Run method takes care of configuring and setting up a FastAPI server behind the scenes. @@ -275,7 +255,6 @@ def run(self, *args: Any, **kwargs: Any) -> Any: fastapi_app = FastAPI() self._attach_predict_fn(fastapi_app) - self._attach_frontend(fastapi_app) logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}") uvicorn.run(app=fastapi_app, host=self.host, port=self.port, log_level="error") diff --git a/src/lightning_app/components/serve/serve.py b/src/lightning_app/components/serve/serve.py index 150ca522e591b..8b6f35364cc38 100644 --- a/src/lightning_app/components/serve/serve.py +++ b/src/lightning_app/components/serve/serve.py @@ -10,7 +10,6 @@ import uvicorn from fastapi import FastAPI from fastapi.responses import JSONResponse -from starlette.responses import RedirectResponse from lightning_app.components.serve.types import _DESERIALIZER, _SERIALIZER from lightning_app.core.work import LightningWork @@ -37,10 +36,6 @@ async def run(self, data) -> Any: return self.serialize(self.predict(self.deserialize(data))) -async def _redirect(): - return RedirectResponse("/docs") - - class ModelInferenceAPI(LightningWork, abc.ABC): def __init__( self, @@ -121,7 +116,6 @@ def run(self): def _populate_app(self, fastapi_service: FastAPI): self._model = self.build_model() - fastapi_service.get("/")(_redirect) fastapi_service.post("/predict", response_class=JSONResponse)( _InferenceCallable( deserialize=_DESERIALIZER[self.input] if self.input else self.deserialize, @@ -134,6 +128,9 @@ def _launch_server(self, fastapi_service: FastAPI): logger.info(f"Your app has started. View it in your browser: http://{self.host}:{self.port}") uvicorn.run(app=fastapi_service, host=self.host, port=self.port, log_level="error") + def configure_layout(self) -> str: + return f"{self.url}/docs" + def _maybe_create_instance() -> Optional[ModelInferenceAPI]: """This function tries to re-create the user `ModelInferenceAPI` if the environment associated with multi diff --git a/src/lightning_app/components/serve/streamlit.py b/src/lightning_app/components/serve/streamlit.py index ed543bd1de7b8..1a325d60fecee 100644 --- a/src/lightning_app/components/serve/streamlit.py +++ b/src/lightning_app/components/serve/streamlit.py @@ -63,6 +63,9 @@ def on_exit(self) -> None: if self._process is not None: self._process.kill() + def configure_layout(self) -> str: + return self.url + class _PatchedWork: """The ``_PatchedWork`` is used to emulate a work instance from a subprocess. This is acheived by patching the diff --git a/src/lightning_app/core/flow.py b/src/lightning_app/core/flow.py index 56947b0d2cbef..321e594499089 100644 --- a/src/lightning_app/core/flow.py +++ b/src/lightning_app/core/flow.py @@ -10,7 +10,7 @@ from lightning_app.frontend import Frontend from lightning_app.storage import Path from lightning_app.storage.drive import _maybe_create_drive, Drive -from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name +from lightning_app.utilities.app_helpers import _is_json_serializable, _LightningAppRef, _set_child_name, is_overridden from lightning_app.utilities.component import _sanitize_state from lightning_app.utilities.exceptions import ExitAppException from lightning_app.utilities.introspection import _is_init_context, _is_run_context @@ -777,4 +777,6 @@ def run(self): self.work.run() def configure_layout(self): - return [{"name": "Main", "content": self.work}] + if is_overridden("configure_layout", self.work): + return [{"name": "Main", "content": self.work}] + return [] diff --git a/src/lightning_app/core/work.py b/src/lightning_app/core/work.py index 857cbc9447ff1..60d1ea62d8afb 100644 --- a/src/lightning_app/core/work.py +++ b/src/lightning_app/core/work.py @@ -3,7 +3,7 @@ import warnings from copy import deepcopy from functools import partial, wraps -from typing import Any, Callable, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, TYPE_CHECKING, Union from deepdiff import DeepHash, Delta @@ -33,6 +33,9 @@ ) from lightning_app.utilities.proxies import Action, LightningWorkSetAttrProxy, ProxyWorkRun, unwrap, WorkRunExecutor +if TYPE_CHECKING: + from lightning_app.frontend import Frontend + class LightningWork: @@ -629,3 +632,45 @@ def apply_flow_delta(self, delta: Delta): property_object.fset(self, value) else: self._default_setattr(name, value) + + def configure_layout(self) -> Union[None, str, "Frontend"]: + """Configure the UI of this LightningWork. + + You can either + + 1. Return a single :class:`~lightning_app.frontend.frontend.Frontend` object to serve a user interface + for this Work. + 2. Return a string containing a URL to act as the user interface for this Work. + 3. Return ``None`` to indicate that this Work doesn't currently have a user interface. + + **Example:** Serve a static directory (with at least a file index.html inside). + + .. code-block:: python + + from lightning_app.frontend import StaticWebFrontend + + + class Work(LightningWork): + def configure_layout(self): + return StaticWebFrontend("path/to/folder/to/serve") + + **Example:** Arrange the UI of my children in tabs (default UI by Lightning). + + .. code-block:: python + + class Work(LightningWork): + def configure_layout(self): + return [ + dict(name="First Tab", content=self.child0), + dict(name="Second Tab", content=self.child1), + dict(name="Lightning", content="https://lightning.ai"), + ] + + If you don't implement ``configure_layout``, Lightning will use ``self.url``. + + Note: + This hook gets called at the time of app creation and then again as part of the loop. If desired, a + returned URL can depend on the state. This is not the case if the work returns a + :class:`~lightning_app.frontend.frontend.Frontend`. These need to be provided at the time of app creation + in order for the runtime to start the server. + """ diff --git a/src/lightning_app/utilities/layout.py b/src/lightning_app/utilities/layout.py index 15079fcb6964b..11f26019cb406 100644 --- a/src/lightning_app/utilities/layout.py +++ b/src/lightning_app/utilities/layout.py @@ -4,7 +4,7 @@ import lightning_app from lightning_app.frontend.frontend import Frontend -from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable +from lightning_app.utilities.app_helpers import _MagicMockJsonSerializable, is_overridden from lightning_app.utilities.cloud import is_running_in_cloud @@ -45,9 +45,9 @@ def _collect_layout(app: "lightning_app.LightningApp", flow: "lightning_app.Ligh app.frontends.setdefault(flow.name, "mock") return flow._layout elif isinstance(layout, dict): - layout = _collect_content_layout([layout], flow) + layout = _collect_content_layout([layout], app, flow) elif isinstance(layout, (list, tuple)) and all(isinstance(item, dict) for item in layout): - layout = _collect_content_layout(layout, flow) + layout = _collect_content_layout(layout, app, flow) else: lines = _add_comment_to_literal_code(flow.configure_layout, contains="return", comment=" <------- this guy") m = f""" @@ -76,7 +76,9 @@ def configure_layout(self): return layout -def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFlow") -> List[Dict]: +def _collect_content_layout( + layout: List[Dict], app: "lightning_app.LightningApp", flow: "lightning_app.LightningFlow" +) -> Union[List[Dict], Dict]: """Process the layout returned by the ``configure_layout()`` method if the returned format represents an aggregation of child layouts.""" for entry in layout: @@ -102,12 +104,43 @@ def _collect_content_layout(layout: List[Dict], flow: "lightning_app.LightningFl entry["content"] = entry["content"].name elif isinstance(entry["content"], lightning_app.LightningWork): - if entry["content"].url and not entry["content"].url.startswith("/"): - entry["content"] = entry["content"].url - entry["target"] = entry["content"] - else: + work = entry["content"] + work_layout = _collect_work_layout(work) + + if work_layout is None: entry["content"] = "" - entry["target"] = "" + elif isinstance(work_layout, str): + entry["content"] = work_layout + entry["target"] = work_layout + elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)): + if len(layout) > 1: + lines = _add_comment_to_literal_code( + flow.configure_layout, contains="return", comment=" <------- this guy" + ) + m = f""" + The return value of configure_layout() in `{flow.__class__.__name__}` is an + unsupported format: + \n{lines} + + The tab containing a `{work.__class__.__name__}` must be the only tab in the + layout of this flow. + + (see the docs for `LightningWork.configure_layout`). + """ + raise TypeError(m) + + if isinstance(work_layout, Frontend): + # If the work returned a frontend, treat it as belonging to the flow. + # NOTE: This could evolve in the future to run the Frontend directly in the work machine. + frontend = work_layout + frontend.flow = flow + elif isinstance(work_layout, _MagicMockJsonSerializable): + # The import was mocked, we set a dummy `Frontend` so that `is_headless` knows there is a UI. + frontend = "mock" + + app.frontends.setdefault(flow.name, frontend) + return flow._layout + elif isinstance(entry["content"], _MagicMockJsonSerializable): # The import was mocked, we just record dummy content so that `is_headless` knows there is a UI entry["content"] = "mock" @@ -126,3 +159,43 @@ def configure_layout(self): """ raise ValueError(m) return layout + + +def _collect_work_layout(work: "lightning_app.LightningWork") -> Union[None, str, Frontend, _MagicMockJsonSerializable]: + """Check if ``configure_layout`` is overridden on the given work and return the work layout (either a string, a + ``Frontend`` object, or an instance of a mocked import). + + Args: + work: The work to collect the layout for. + + Raises: + TypeError: If the value returned by ``configure_layout`` is not of a supported format. + """ + if is_overridden("configure_layout", work): + work_layout = work.configure_layout() + else: + work_layout = work.url + + if work_layout is None: + return None + elif isinstance(work_layout, str): + url = work_layout + # The URL isn't fully defined yet. Looks something like ``self.work.url + /something``. + if url and not url.startswith("/"): + return url + return "" + elif isinstance(work_layout, (Frontend, _MagicMockJsonSerializable)): + return work_layout + else: + m = f""" + The value returned by `{work.__class__.__name__}.configure_layout()` is of an unsupported type. + + {repr(work_layout)} + + Return a `Frontend` or a URL string, for example: + + class {work.__class__.__name__}(LightningWork): + def configure_layout(self): + return MyFrontend() OR 'http://some/url' + """ + raise TypeError(m) diff --git a/tests/tests_app/utilities/test_app_helpers.py b/tests/tests_app/utilities/test_app_helpers.py index 791d2011f7651..2241e262cd381 100644 --- a/tests/tests_app/utilities/test_app_helpers.py +++ b/tests/tests_app/utilities/test_app_helpers.py @@ -1,4 +1,5 @@ import os +from functools import partial from unittest import mock import pytest @@ -10,6 +11,8 @@ ) from lightning_app import LightningApp, LightningFlow, LightningWork +from lightning_app.core.flow import _RootFlow +from lightning_app.frontend import StaticWebFrontend from lightning_app.utilities.app_helpers import ( _handle_is_headless, _is_headless, @@ -119,14 +122,9 @@ def configure_layout(self): return {"name": "test", "content": "https://appurl"} -class FlowWithWorkLayout(Flow): - def __init__(self): - super().__init__() - - self.work = Work() - +class FlowWithFrontend(Flow): def configure_layout(self): - return {"name": "test", "content": self.work} + return StaticWebFrontend(".") class FlowWithMockedFrontend(Flow): @@ -153,16 +151,62 @@ def __init__(self): self.flow = FlowWithURLLayout() +class WorkWithStringLayout(Work): + def configure_layout(self): + return "http://appurl" + + +class WorkWithMockedFrontendLayout(Work): + def configure_layout(self): + return _MagicMockJsonSerializable() + + +class WorkWithFrontendLayout(Work): + def configure_layout(self): + return StaticWebFrontend(".") + + +class WorkWithNoneLayout(Work): + def configure_layout(self): + return None + + +class FlowWithWorkLayout(Flow): + def __init__(self, work): + super().__init__() + + self.work = work() + + def configure_layout(self): + return {"name": "test", "content": self.work} + + +class WorkClassRootFlow(_RootFlow): + """A ``_RootFlow`` which takes a work class rather than the work itself.""" + + def __init__(self, work): + super().__init__(work()) + + @pytest.mark.parametrize( "flow,expected", [ (Flow, True), (FlowWithURLLayout, False), - (FlowWithWorkLayout, False), + (FlowWithFrontend, False), (FlowWithMockedFrontend, False), (FlowWithMockedContent, False), (NestedFlow, True), (NestedFlowWithURLLayout, False), + (partial(WorkClassRootFlow, WorkWithStringLayout), False), + (partial(WorkClassRootFlow, WorkWithMockedFrontendLayout), False), + (partial(WorkClassRootFlow, WorkWithFrontendLayout), False), + (partial(WorkClassRootFlow, WorkWithNoneLayout), True), + (partial(FlowWithWorkLayout, Work), False), + (partial(FlowWithWorkLayout, WorkWithStringLayout), False), + (partial(FlowWithWorkLayout, WorkWithMockedFrontendLayout), False), + (partial(FlowWithWorkLayout, WorkWithFrontendLayout), False), + (partial(FlowWithWorkLayout, WorkWithNoneLayout), True), ], ) def test_is_headless(flow, expected): diff --git a/tests/tests_app/utilities/test_layout.py b/tests/tests_app/utilities/test_layout.py new file mode 100644 index 0000000000000..98921e3d0000e --- /dev/null +++ b/tests/tests_app/utilities/test_layout.py @@ -0,0 +1,143 @@ +import pytest + +from lightning_app.core.flow import LightningFlow +from lightning_app.core.work import LightningWork +from lightning_app.frontend.web import StaticWebFrontend +from lightning_app.utilities.layout import _collect_layout + + +class _MockApp: + def __init__(self) -> None: + self.frontends = {} + + +class FlowWithFrontend(LightningFlow): + def configure_layout(self): + return StaticWebFrontend(".") + + +class WorkWithFrontend(LightningWork): + def run(self): + pass + + def configure_layout(self): + return StaticWebFrontend(".") + + +class FlowWithWorkWithFrontend(LightningFlow): + def __init__(self): + super().__init__() + + self.work = WorkWithFrontend() + + def configure_layout(self): + return {"name": "work", "content": self.work} + + +class FlowWithUrl(LightningFlow): + def configure_layout(self): + return {"name": "test", "content": "https://test"} + + +class WorkWithUrl(LightningWork): + def run(self): + pass + + def configure_layout(self): + return "https://test" + + +class FlowWithWorkWithUrl(LightningFlow): + def __init__(self): + super().__init__() + + self.work = WorkWithUrl() + + def configure_layout(self): + return {"name": "test", "content": self.work} + + +@pytest.mark.parametrize( + "flow,expected_layout,expected_frontends", + [ + (FlowWithFrontend, {}, [("root", StaticWebFrontend)]), + (FlowWithWorkWithFrontend, {}, [("root", StaticWebFrontend)]), + (FlowWithUrl, [{"name": "test", "content": "https://test", "target": "https://test"}], []), + (FlowWithWorkWithUrl, [{"name": "test", "content": "https://test", "target": "https://test"}], []), + ], +) +def test_collect_layout(flow, expected_layout, expected_frontends): + app = _MockApp() + flow = flow() + layout = _collect_layout(app, flow) + + assert layout == expected_layout + assert set(app.frontends.keys()) == {key for key, _ in expected_frontends} + for key, frontend_type in expected_frontends: + assert isinstance(app.frontends[key], frontend_type) + + +class FlowWithBadLayout(LightningFlow): + def configure_layout(self): + return 100 + + +class FlowWithBadLayoutDict(LightningFlow): + def configure_layout(self): + return {"this_key_should_not_be_here": "http://appurl"} + + +class FlowWithBadContent(LightningFlow): + def configure_layout(self): + return {"content": 100} + + +class WorkWithBadLayout(LightningWork): + def run(self): + pass + + def configure_layout(self): + return 100 + + +class FlowWithWorkWithBadLayout(LightningFlow): + def __init__(self): + super().__init__() + + self.work = WorkWithBadLayout() + + def configure_layout(self): + return {"name": "test", "content": self.work} + + +class FlowWithMultipleWorksWithFrontends(LightningFlow): + def __init__(self): + super().__init__() + + self.work1 = WorkWithFrontend() + self.work2 = WorkWithFrontend() + + def configure_layout(self): + return [{"name": "test1", "content": self.work1}, {"name": "test2", "content": self.work2}] + + +@pytest.mark.parametrize( + "flow,error_type,match", + [ + (FlowWithBadLayout, TypeError, "is an unsupported layout format"), + (FlowWithBadLayoutDict, ValueError, "missing a key 'content'."), + (FlowWithBadContent, ValueError, "contains an unsupported entry."), + (FlowWithWorkWithBadLayout, TypeError, "is of an unsupported type."), + ( + FlowWithMultipleWorksWithFrontends, + TypeError, + "The tab containing a `WorkWithFrontend` must be the only tab", + ), + ], +) +def test_collect_layout_errors(flow, error_type, match): + app = _MockApp() + flow = flow() + + with pytest.raises(error_type, match=match): + _collect_layout(app, flow)