From 4debdd349dc27f25ae67a143c71f46987f25a2a8 Mon Sep 17 00:00:00 2001 From: thomas chaton Date: Wed, 14 Dec 2022 18:09:30 +0100 Subject: [PATCH] [App] Improve lightning connect experience (#16035) (cherry picked from commit e522a12d177fe5d3a733b67c44d2a558a18aa116) --- examples/app_installation_commands/app.py | 13 +- src/lightning_app/CHANGELOG.md | 2 + src/lightning_app/cli/commands/connection.py | 220 +++++++++++------- src/lightning_app/cli/lightning_cli.py | 2 - .../components/database/server.py | 6 +- src/lightning_app/runners/multiprocess.py | 2 +- src/lightning_app/utilities/cli_helpers.py | 4 + src/lightning_app/utilities/commands/base.py | 12 +- src/lightning_app/utilities/frontend.py | 1 + tests/tests_app/cli/test_connect.py | 58 +---- .../public/test_commands_and_api.py | 2 +- 11 files changed, 171 insertions(+), 151 deletions(-) diff --git a/examples/app_installation_commands/app.py b/examples/app_installation_commands/app.py index 087d84b1335b2..f69df99ad9e82 100644 --- a/examples/app_installation_commands/app.py +++ b/examples/app_installation_commands/app.py @@ -13,9 +13,14 @@ def run(self): print("lmdb successfully installed") print("accessing a module in a Work or Flow body works!") - @property - def ready(self) -> bool: - return True + +class RootFlow(L.LightningFlow): + def __init__(self, work): + super().__init__() + self.work = work + + def run(self): + self.work.run() print(f"accessing an object in main code body works!: version={lmdb.version()}") @@ -24,4 +29,4 @@ def ready(self) -> bool: # run on a cloud machine compute = L.CloudCompute("cpu") worker = YourComponent(cloud_compute=compute) -app = L.LightningApp(worker) +app = L.LightningApp(RootFlow(worker)) diff --git a/src/lightning_app/CHANGELOG.md b/src/lightning_app/CHANGELOG.md index f93f4a0a8d0fd..7ebe58fd6cb46 100644 --- a/src/lightning_app/CHANGELOG.md +++ b/src/lightning_app/CHANGELOG.md @@ -11,6 +11,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Added `Lightning{Flow,Work}.lightningignores` attributes to programmatically ignore files before uploading to the cloud ([#15818](https://github.com/Lightning-AI/lightning/pull/15818)) +- Added a progres bar while connecting to an app through the CLI ([#16035](https://github.com/Lightning-AI/lightning/pull/16035)) + ### Changed diff --git a/src/lightning_app/cli/commands/connection.py b/src/lightning_app/cli/commands/connection.py index ee0bf7edc5d67..f5600947ab427 100644 --- a/src/lightning_app/cli/commands/connection.py +++ b/src/lightning_app/cli/commands/connection.py @@ -8,6 +8,7 @@ import click import psutil from lightning_utilities.core.imports import package_available +from rich.progress import Progress from lightning_app.utilities.cli_helpers import _LightningAppOpenAPIRetriever from lightning_app.utilities.cloud import _get_project @@ -16,15 +17,33 @@ from lightning_app.utilities.network import LightningClient _HOME = os.path.expanduser("~") -_PPID = str(psutil.Process(os.getpid()).ppid()) +_PPID = os.getenv("LIGHTNING_CONNECT_PPID", str(psutil.Process(os.getpid()).ppid())) _LIGHTNING_CONNECTION = os.path.join(_HOME, ".lightning", "lightning_connection") _LIGHTNING_CONNECTION_FOLDER = os.path.join(_LIGHTNING_CONNECTION, _PPID) @click.argument("app_name_or_id", required=True) -@click.option("-y", "--yes", required=False, is_flag=True, help="Whether to download the commands automatically.") -def connect(app_name_or_id: str, yes: bool = False): - """Connect to a Lightning App.""" +def connect(app_name_or_id: str): + """Connect your local terminal to a running lightning app. + + After connecting, the lightning CLI will respond to commands exposed by the app. + + Example: + + \b + # connect to an app named pizza-cooker-123 + lightning connect pizza-cooker-123 + \b + # this will now show the commands exposed by pizza-cooker-123 + lightning --help + \b + # while connected, you can run the cook-pizza command exposed + # by pizza-cooker-123.BTW, this should arguably generate an exception :-) + lightning cook-pizza --flavor pineapple + \b + # once done, disconnect and go back to the standard lightning CLI commands + lightning disconnect + """ from lightning_app.utilities.commands.base import _download_command _clean_lightning_connection() @@ -47,51 +66,64 @@ def connect(app_name_or_id: str, yes: bool = False): click.echo(f"You are already connected to the cloud Lightning App: {app_name_or_id}.") else: disconnect() - connect(app_name_or_id, yes) + connect(app_name_or_id) elif app_name_or_id.startswith("localhost"): - if app_name_or_id != "localhost": - raise Exception("You need to pass localhost to connect to the local Lightning App.") + with Progress() as progress_bar: + connecting = progress_bar.add_task("[magenta]Setting things up for you...", total=1.0) - retriever = _LightningAppOpenAPIRetriever(None) + if app_name_or_id != "localhost": + raise Exception("You need to pass localhost to connect to the local Lightning App.") - if retriever.api_commands is None: - raise Exception(f"The commands weren't found. Is your app {app_name_or_id} running ?") + retriever = _LightningAppOpenAPIRetriever(None) - commands_folder = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands") - if not os.path.exists(commands_folder): - os.makedirs(commands_folder) + if retriever.api_commands is None: + raise Exception(f"Connection wasn't successful. Is your app {app_name_or_id} running?") - _write_commands_metadata(retriever.api_commands) + increment = 1 / (1 + len(retriever.api_commands)) - with open(os.path.join(commands_folder, "openapi.json"), "w") as f: - json.dump(retriever.openapi, f) + progress_bar.update(connecting, advance=increment) - _install_missing_requirements(retriever, yes) + commands_folder = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands") + if not os.path.exists(commands_folder): + os.makedirs(commands_folder) - for command_name, metadata in retriever.api_commands.items(): - if "cls_path" in metadata: - target_file = os.path.join(commands_folder, f"{command_name.replace(' ','_')}.py") - _download_command( - command_name, - metadata["cls_path"], - metadata["cls_name"], - None, - target_file=target_file, - ) - repr_command_name = command_name.replace("_", " ") - click.echo(f"Storing `{repr_command_name}` at {target_file}") - else: - with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f: - f.write(command_name) + _write_commands_metadata(retriever.api_commands) + + with open(os.path.join(commands_folder, "openapi.json"), "w") as f: + json.dump(retriever.openapi, f) - click.echo(f"You can review all the downloaded commands at {commands_folder}") + _install_missing_requirements(retriever) + + for command_name, metadata in retriever.api_commands.items(): + if "cls_path" in metadata: + target_file = os.path.join(commands_folder, f"{command_name.replace(' ','_')}.py") + _download_command( + command_name, + metadata["cls_path"], + metadata["cls_name"], + None, + target_file=target_file, + ) + else: + with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f: + f.write(command_name) + + progress_bar.update(connecting, advance=increment) with open(connected_file, "w") as f: f.write(app_name_or_id + "\n") - click.echo("You are connected to the local Lightning App.") + click.echo("The lightning CLI now responds to app commands. Use 'lightning --help' to see them.") + click.echo(" ") + + Popen( + f"LIGHTNING_CONNECT_PPID={_PPID} {sys.executable} -m lightning --help", + shell=True, + stdout=sys.stdout, + stderr=sys.stderr, + ).wait() elif matched_connection_path: @@ -101,40 +133,39 @@ def connect(app_name_or_id: str, yes: bool = False): commands = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands") shutil.copytree(matched_commands, commands) shutil.copy(matched_connected_file, connected_file) - copied_files = [el for el in os.listdir(commands) if os.path.splitext(el)[1] == ".py"] - click.echo("Found existing connection, reusing cached commands") - for target_file in copied_files: - pretty_command_name = os.path.splitext(target_file)[0].replace("_", " ") - click.echo(f"Storing `{pretty_command_name}` at {os.path.join(commands, target_file)}") - click.echo(f"You can review all the commands at {commands}") + click.echo("The lightning CLI now responds to app commands. Use 'lightning --help' to see them.") click.echo(" ") - click.echo(f"You are connected to the cloud Lightning App: {app_name_or_id}.") - else: + Popen( + f"LIGHTNING_CONNECT_PPID={_PPID} {sys.executable} -m lightning --help", + shell=True, + stdout=sys.stdout, + stderr=sys.stderr, + ).wait() - retriever = _LightningAppOpenAPIRetriever(app_name_or_id) + else: + with Progress() as progress_bar: + connecting = progress_bar.add_task("[magenta]Setting things up for you...", total=1.0) + + retriever = _LightningAppOpenAPIRetriever(app_name_or_id) + + if not retriever.api_commands: + client = LightningClient() + project = _get_project(client) + apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project.project_id) + click.echo( + "We didn't find a matching App. Here are the available Apps that you can" + f"connect to {[app.name for app in apps.lightningapps]}." + ) + return - if not retriever.api_commands: - client = LightningClient() - project = _get_project(client) - apps = client.lightningapp_instance_service_list_lightningapp_instances(project_id=project.project_id) - click.echo( - "We didn't find a matching App. Here are the available Apps that could be " - f"connected to {[app.name for app in apps.lightningapps]}." - ) - return + increment = 1 / (1 + len(retriever.api_commands)) - _install_missing_requirements(retriever, yes) + progress_bar.update(connecting, advance=increment) - if not yes: - yes = click.confirm( - f"The Lightning App `{app_name_or_id}` provides a command-line (CLI). " - "Do you want to proceed and install its CLI ?" - ) - click.echo(" ") + _install_missing_requirements(retriever) - if yes: commands_folder = os.path.join(_LIGHTNING_CONNECTION_FOLDER, "commands") if not os.path.exists(commands_folder): os.makedirs(commands_folder) @@ -151,26 +182,25 @@ def connect(app_name_or_id: str, yes: bool = False): retriever.app_id, target_file=target_file, ) - pretty_command_name = command_name.replace("_", " ") - click.echo(f"Storing `{pretty_command_name}` at {target_file}") else: with open(os.path.join(commands_folder, f"{command_name}.txt"), "w") as f: f.write(command_name) - click.echo(f"You can review all the downloaded commands at {commands_folder}") - - click.echo(" ") - click.echo("The client interface has been successfully installed. ") - click.echo("You can now run the following commands:") - for command in retriever.api_commands: - pretty_command_name = command.replace("_", " ") - click.echo(f" lightning {pretty_command_name}") + progress_bar.update(connecting, advance=increment) with open(connected_file, "w") as f: f.write(retriever.app_name + "\n") f.write(retriever.app_id + "\n") + + click.echo("The lightning CLI now responds to app commands. Use 'lightning --help' to see them.") click.echo(" ") - click.echo(f"You are connected to the cloud Lightning App: {app_name_or_id}.") + + Popen( + f"LIGHTNING_CONNECT_PPID={_PPID} {sys.executable} -m lightning --help", + shell=True, + stdout=sys.stdout, + stderr=sys.stderr, + ).wait() def disconnect(logout: bool = False): @@ -244,22 +274,37 @@ def _list_app_commands(echo: bool = True) -> List[str]: click.echo("The current Lightning App doesn't have commands.") return [] + app_info = metadata[command_names[0]].get("app_info", None) + + title, description, on_connect_end = "Lightning", None, None + if app_info: + title = app_info.get("title") + description = app_info.get("description") + on_connect_end = app_info.get("on_connect_end") + if echo: - click.echo("Usage: lightning [OPTIONS] COMMAND [ARGS]...") - click.echo("") - click.echo(" --help Show this message and exit.") + click.echo(f"{title} App") + if description: + click.echo("") + click.echo("Description:") + if description.endswith("\n"): + description = description[:-2] + click.echo(f" {description}") click.echo("") - click.echo("Lightning App Commands") + click.echo("Commands:") max_length = max(len(n) for n in command_names) for command_name in command_names: padding = (max_length + 1 - len(command_name)) * " " click.echo(f" {command_name}{padding}{metadata[command_name].get('description', '')}") + if "LIGHTNING_CONNECT_PPID" in os.environ and on_connect_end: + if on_connect_end.endswith("\n"): + on_connect_end = on_connect_end[:-2] + click.echo(on_connect_end) return command_names def _install_missing_requirements( retriever: _LightningAppOpenAPIRetriever, - yes_global: bool = False, fail_if_missing: bool = False, ): requirements = set() @@ -281,20 +326,15 @@ def _install_missing_requirements( sys.exit(0) for req in missing_requirements: - if not yes_global: - yes = click.confirm( - f"The Lightning App CLI `{retriever.app_id}` requires `{req}`. Do you want to install it ?" - ) - else: - print(f"Installing missing `{req}` requirement.") - yes = yes_global - if yes: - std_out_out = get_logfile("output.log") - with open(std_out_out, "wb") as stdout: - Popen( - f"{sys.executable} -m pip install {req}", shell=True, stdout=stdout, stderr=sys.stderr - ).wait() - print() + std_out_out = get_logfile("output.log") + with open(std_out_out, "wb") as stdout: + Popen( + f"{sys.executable} -m pip install {req}", + shell=True, + stdout=stdout, + stderr=stdout, + ).wait() + os.remove(std_out_out) def _clean_lightning_connection(): diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 4696745ada95f..002f8c267c8ab 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -76,8 +76,6 @@ def main() -> None: else: message = f"You are connected to the cloud Lightning App: {app_name}." - click.echo(" ") - if (len(sys.argv) > 1 and sys.argv[1] in ["-h", "--help"]) or len(sys.argv) == 1: _list_app_commands() else: diff --git a/src/lightning_app/components/database/server.py b/src/lightning_app/components/database/server.py index 6d187e4cda133..0a93bf94f985b 100644 --- a/src/lightning_app/components/database/server.py +++ b/src/lightning_app/components/database/server.py @@ -14,6 +14,7 @@ from lightning_app.components.database.utilities import _create_database, _Delete, _Insert, _SelectAll, _Update from lightning_app.core.work import LightningWork from lightning_app.storage import Drive +from lightning_app.utilities.app_helpers import Logger from lightning_app.utilities.imports import _is_sqlmodel_available from lightning_app.utilities.packaging.build_config import BuildConfig @@ -23,6 +24,9 @@ SQLModel = object +logger = Logger(__name__) + + # Required to avoid Uvicorn Server overriding Lightning App signal handlers. # Discussions: https://github.com/encode/uvicorn/discussions/1708 class _DatabaseUvicornServer(uvicorn.Server): @@ -167,7 +171,7 @@ def store_database(self): drive = Drive("lit://database", component_name=self.name, root_folder=tmpdir) drive.put(os.path.basename(tmp_db_filename)) - print("Stored the database to the Drive.") + logger.debug("Stored the database to the Drive.") except Exception: print(traceback.print_exc()) diff --git a/src/lightning_app/runners/multiprocess.py b/src/lightning_app/runners/multiprocess.py index 673e8601043d7..e5d34fb76800f 100644 --- a/src/lightning_app/runners/multiprocess.py +++ b/src/lightning_app/runners/multiprocess.py @@ -82,7 +82,7 @@ def dispatch(self, *args: Any, open_ui: bool = True, **kwargs: Any): if is_overridden("configure_commands", self.app.root): commands = _prepare_commands(self.app) - apis += _commands_to_api(commands) + apis += _commands_to_api(commands, info=self.app.info) kwargs = dict( apis=apis, diff --git a/src/lightning_app/utilities/cli_helpers.py b/src/lightning_app/utilities/cli_helpers.py index caa414e163ffc..0ec6eabd3022c 100644 --- a/src/lightning_app/utilities/cli_helpers.py +++ b/src/lightning_app/utilities/cli_helpers.py @@ -69,6 +69,7 @@ def _get_metadata_from_openapi(paths: Dict, path: str): cls_name = paths[path]["post"].get("cls_name", None) description = paths[path]["post"].get("description", None) requirements = paths[path]["post"].get("requirements", None) + app_info = paths[path]["post"].get("app_info", None) metadata = {"tag": tag, "parameters": {}} @@ -84,6 +85,9 @@ def _get_metadata_from_openapi(paths: Dict, path: str): if description: metadata["requirements"] = requirements + if app_info: + metadata["app_info"] = app_info + if not parameters: return metadata diff --git a/src/lightning_app/utilities/commands/base.py b/src/lightning_app/utilities/commands/base.py index 53aefe5725194..947456d8e4e1c 100644 --- a/src/lightning_app/utilities/commands/base.py +++ b/src/lightning_app/utilities/commands/base.py @@ -5,6 +5,7 @@ import shutil import sys import traceback +from dataclasses import asdict from getpass import getuser from importlib.util import module_from_spec, spec_from_file_location from tempfile import gettempdir @@ -16,6 +17,7 @@ from lightning_app.api.http_methods import Post from lightning_app.api.request_types import _APIRequest, _CommandRequest, _RequestResponse +from lightning_app.utilities import frontend from lightning_app.utilities.app_helpers import is_overridden, Logger from lightning_app.utilities.cloud import _get_project from lightning_app.utilities.network import LightningClient @@ -250,7 +252,7 @@ def _process_requests(app, requests: List[Union[_APIRequest, _CommandRequest]]) app.api_response_queue.put(responses) -def _collect_open_api_extras(command) -> Dict: +def _collect_open_api_extras(command, info) -> Dict: if not isinstance(command, ClientCommand): if command.__doc__ is not None: return {"description": command.__doc__} @@ -263,10 +265,14 @@ def _collect_open_api_extras(command) -> Dict: } if command.requirements: extras.update({"requirements": command.requirements}) + if info: + extras.update({"app_info": asdict(info)}) return extras -def _commands_to_api(commands: List[Dict[str, Union[Callable, ClientCommand]]]) -> List: +def _commands_to_api( + commands: List[Dict[str, Union[Callable, ClientCommand]]], info: Optional[frontend.AppInfo] = None +) -> List: """Convert user commands to API endpoint.""" api = [] for command in commands: @@ -278,7 +284,7 @@ def _commands_to_api(commands: List[Dict[str, Union[Callable, ClientCommand]]]) v.method if isinstance(v, ClientCommand) else v, method_name=k, tags=["app_client_command"] if isinstance(v, ClientCommand) else ["app_command"], - openapi_extra=_collect_open_api_extras(v), + openapi_extra=_collect_open_api_extras(v, info), ) ) return api diff --git a/src/lightning_app/utilities/frontend.py b/src/lightning_app/utilities/frontend.py index 315c119935b6f..470036436a63c 100644 --- a/src/lightning_app/utilities/frontend.py +++ b/src/lightning_app/utilities/frontend.py @@ -12,6 +12,7 @@ class AppInfo: image: Optional[str] = None # ensure the meta tags are correct or the UI might fail to load. meta_tags: Optional[List[str]] = None + on_connect_end: Optional[str] = None def update_index_file(ui_root: str, info: Optional[AppInfo] = None, root_path: str = "") -> None: diff --git a/tests/tests_app/cli/test_connect.py b/tests/tests_app/cli/test_connect.py index a8924ab375db2..adbd55385815e 100644 --- a/tests/tests_app/cli/test_connect.py +++ b/tests/tests_app/cli/test_connect.py @@ -1,6 +1,5 @@ import json import os -import sys from unittest.mock import MagicMock import click @@ -37,11 +36,8 @@ def monkeypatch_connection(monkeypatch, tmpdir, ppid): def test_connect_disconnect_local(tmpdir, monkeypatch): disconnect() - ppid = str(psutil.Process(os.getpid()).ppid()) - connection_path = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid) - - with pytest.raises(Exception, match="The commands weren't found. Is your app localhost running ?"): - connect("localhost", True) + with pytest.raises(Exception, match="Connection wasn't successful. Is your app localhost running ?"): + connect("localhost") with open(os.path.join(os.path.dirname(__file__), "jsons/connect_1.json")) as f: data = json.load(f) @@ -64,23 +60,14 @@ def fn(msg): response.status_code = 200 response.json.return_value = data monkeypatch.setattr(cli_helpers.requests, "get", MagicMock(return_value=response)) - connect("localhost", True) + connect("localhost") assert _retrieve_connection_to_an_app() == ("localhost", None) command_path = _resolve_command_path("nested_command") assert not os.path.exists(command_path) command_path = _resolve_command_path("command_with_client") assert os.path.exists(command_path) - s = "/" if sys.platform != "win32" else "\\" - command_folder_path = f"{connection_path}{s}commands" - expected = [ - f"Storing `command with client` at {command_folder_path}{s}command_with_client.py", - f"You can review all the downloaded commands at {command_folder_path}", - "You are connected to the local Lightning App.", - ] - assert messages == expected - messages = [] - connect("localhost", True) + connect("localhost") assert messages == ["You are connected to the local Lightning App."] messages = [] @@ -101,8 +88,6 @@ def test_connect_disconnect_cloud(tmpdir, monkeypatch): ppid_1 = str(psutil.Process(os.getpid()).ppid()) ppid_2 = "222" - connection_path = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1) - target_file = _resolve_command_path("command_with_client") if os.path.exists(target_file): @@ -153,7 +138,7 @@ def fn(msg): with open(data["paths"]["/command/command_with_client"]["post"]["cls_path"], "rb") as f: response.content = f.read() - connect("example", True) + connect("example") assert _retrieve_connection_to_an_app() == ("example", "1234") commands = _list_app_commands() assert commands == ["command with client", "command without client", "nested command"] @@ -161,50 +146,25 @@ def fn(msg): assert not os.path.exists(command_path) command_path = _resolve_command_path("command_with_client") assert os.path.exists(command_path) - s = "/" if sys.platform != "win32" else "\\" - command_folder_path = f"{connection_path}{s}commands" - expected = [ - f"Storing `command with client` at {command_folder_path}{s}command_with_client.py", - f"You can review all the downloaded commands at {command_folder_path}", - " ", - "The client interface has been successfully installed. ", - "You can now run the following commands:", - " lightning command without client", - " lightning command with client", - " lightning nested command", - " ", - "You are connected to the cloud Lightning App: example.", - "Usage: lightning [OPTIONS] COMMAND [ARGS]...", - "", - " --help Show this message and exit.", - "", - "Lightning App Commands", - " command with client A command with a client.", - " command without client A command without a client.", - " nested command A nested command.", - ] - assert messages == expected - messages = [] - connect("example", True) + connect("example") assert messages == ["You are already connected to the cloud Lightning App: example."] _ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_2) messages = [] - connect("example", True) - assert messages[0] == "Found existing connection, reusing cached commands" + connect("example") + assert "The lightning CLI now responds to app commands" in messages[0] messages = [] disconnect() - print(messages) assert messages == ["You are disconnected from the cloud Lightning App: example."] _ = monkeypatch_connection(monkeypatch, tmpdir, ppid=ppid_1) messages = [] disconnect() - assert messages == ["You are disconnected from the cloud Lightning App: example."] + assert "You aren't connected to any Lightning App" in messages[0] messages = [] disconnect() diff --git a/tests/tests_examples_app/public/test_commands_and_api.py b/tests/tests_examples_app/public/test_commands_and_api.py index a6a015d02a84a..290cf0d49b4cb 100644 --- a/tests/tests_examples_app/public/test_commands_and_api.py +++ b/tests/tests_examples_app/public/test_commands_and_api.py @@ -25,7 +25,7 @@ def test_commands_and_api_example_cloud() -> None: # 2: Connect to the App and send the first & second command with the client # Requires to be run within the same process. - cmd_1 = f"python -m lightning connect {app_id} -y" + cmd_1 = f"python -m lightning connect {app_id}" cmd_2 = "python -m lightning command with client --name=this" cmd_3 = "python -m lightning command without client --name=is" cmd_4 = "lightning disconnect"