diff --git a/src/lightning_app/cli/lightning_cli.py b/src/lightning_app/cli/lightning_cli.py index 87b0ef91755b10..031dc2c11781a6 100644 --- a/src/lightning_app/cli/lightning_cli.py +++ b/src/lightning_app/cli/lightning_cli.py @@ -207,7 +207,7 @@ def login() -> None: auth.clear() try: - auth._run_server() + auth.authenticate() except ConnectionError: click.echo(f"Unable to connect to {get_lightning_cloud_url()}. Please check your internet connection.") exit(1) diff --git a/src/lightning_app/utilities/login.py b/src/lightning_app/utilities/login.py index 4539ef805eafa7..508aca237b2cf9 100644 --- a/src/lightning_app/utilities/login.py +++ b/src/lightning_app/utilities/login.py @@ -4,6 +4,7 @@ import pathlib from dataclasses import dataclass from enum import Enum +from time import sleep from typing import Optional from urllib.parse import urlencode @@ -44,9 +45,11 @@ def __post_init__(self): setattr(self, key.suffix, os.environ.get(key.value, None)) self._with_env_var = bool(self.user_id and self.api_key) # used by authenticate method - if self.api_key and not self.user_id: + if self._with_env_var: + self.save("", self.user_id, self.api_key, self.user_id) + logger.info("Credentials loaded from environment variables") + elif self.api_key or self.user_id: raise ValueError( - f"{Keys.USER_ID.value} is missing from env variables. " "To use env vars for authentication both " f"{Keys.USER_ID.value} and {Keys.API_KEY.value} should be set." ) @@ -135,7 +138,8 @@ def authenticate(self) -> Optional[str]: class AuthServer: - def get_auth_url(self, port: int) -> str: + @staticmethod + def get_auth_url(port: int) -> str: redirect_uri = f"http://localhost:{port}/login-complete" params = urlencode(dict(redirectTo=redirect_uri)) return f"{get_lightning_cloud_url()}/sign-in?{params}" @@ -144,6 +148,7 @@ def login_with_browser(self, auth: Auth) -> None: app = FastAPI() port = find_free_network_port() url = self.get_auth_url(port) + try: # check if server is reachable or catch any network errors requests.head(url) @@ -156,32 +161,42 @@ def login_with_browser(self, auth: Auth) -> None: f"An error occurred with the request. Please report this issue to Lightning Team \n{e}" # E501 ) - logger.info(f"login started for lightning.ai, opening {url}") + logger.info( + "\nAttempting to automatically open the login page in your default browser.\n" + "If the browser does not open, navigate to the \"Keys\" tab on your Lightning AI profile page:\n\n" + f"{get_lightning_cloud_url()}/me/keys\n\n" + "Copy the \"Headless CLI Login\" command, and execute it in your terminal.\n" + ) click.launch(url) @app.get("/login-complete") async def save_token(request: Request, token="", key="", user_id: str = Query("", alias="userID")): - if token: - auth.save(token=token, username=user_id, user_id=user_id, api_key=key) - logger.info("Authentication Successful") - else: + async def stop_server_once_request_is_done(): + while not await request.is_disconnected(): + sleep(0.25) + server.should_exit = True + + if not token: logger.warn( - "Authentication Failed. This is most likely because you're using an older version of the CLI. \n" # noqa E501 + "Login Failed. This is most likely because you're using an older version of the CLI. \n" # noqa E501 "Please try to update the CLI or open an issue with this information \n" # E501 f"expected token in {request.query_params.items()}" ) + return RedirectResponse( + url=f"{get_lightning_cloud_url()}/cli-login-failed", + background=BackgroundTask(stop_server_once_request_is_done), + ) + + auth.save(token=token, username=user_id, user_id=user_id, api_key=key) + logger.info("Login Successful") # Include the credentials in the redirect so that UI will also be logged in params = urlencode(dict(token=token, key=key, userID=user_id)) return RedirectResponse( - url=f"{get_lightning_cloud_url()}/me/apps?{params}", - # The response background task is being executed right after the server finished writing the response - background=BackgroundTask(stop_server), + url=f"{get_lightning_cloud_url()}/cli-login-successful?{params}", + background=BackgroundTask(stop_server_once_request_is_done), ) - def stop_server(): - server.should_exit = True - server = uvicorn.Server(config=uvicorn.Config(app, port=port, log_level="error")) server.run()