Skip to content

Commit

Permalink
[CLI] fix ssh listing stopped components (#15810)
Browse files Browse the repository at this point in the history
* [CLI] fix ssh listing stopped components
* update CHANGELOG
  • Loading branch information
nicolai86 committed Nov 28, 2022
1 parent 657bfc5 commit c786b3d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 6 deletions.
3 changes: 1 addition & 2 deletions src/lightning_app/CHANGELOG.md
Expand Up @@ -36,8 +36,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed

- Fixed debugging with VSCode IDE ([#15747](https://github.com/Lightning-AI/lightning/pull/15747))


- Fixed SSH CLI command listing stopped components ([#15810](https://github.com/Lightning-AI/lightning/pull/15810))
- Fixed the work not stopped when successful when passed directly to the LightningApp ([#15801](https://github.com/Lightning-AI/lightning/pull/15801))

- Fixed the PyTorch Inference locally on GPU ([#15813](https://github.com/Lightning-AI/lightning/pull/15813))
Expand Down
8 changes: 6 additions & 2 deletions src/lightning_app/cli/cmd_apps.py
Expand Up @@ -54,9 +54,13 @@ def list_apps(
apps = apps + resp.lightningapps
return apps

def list_components(self, app_id: str) -> List[Externalv1Lightningwork]:
def list_components(self, app_id: str, phase_in: List[str] = []) -> List[Externalv1Lightningwork]:
project = _get_project(self.api_client)
resp = self.api_client.lightningwork_service_list_lightningwork(project_id=project.project_id, app_id=app_id)
resp = self.api_client.lightningwork_service_list_lightningwork(
project_id=project.project_id,
app_id=app_id,
phase_in=phase_in,
)
return resp.lightningworks

def list(self, cluster_id: str = None, limit: int = 100) -> None:
Expand Down
4 changes: 2 additions & 2 deletions src/lightning_app/cli/lightning_cli.py
Expand Up @@ -8,7 +8,7 @@
import click
import inquirer
import rich
from lightning_cloud.openapi import Externalv1LightningappInstance, V1LightningappInstanceState
from lightning_cloud.openapi import Externalv1LightningappInstance, V1LightningappInstanceState, V1LightningworkState
from lightning_cloud.openapi.rest import ApiException
from lightning_utilities.core.imports import RequirementCache
from requests.exceptions import ConnectionError
Expand Down Expand Up @@ -420,7 +420,7 @@ def ssh(app_name: str = None, component_name: str = None) -> None:
except ApiException:
raise click.ClickException("failed fetching app instance")

components = app_manager.list_components(app_id=app_id)
components = app_manager.list_components(app_id=app_id, phase_in=[V1LightningworkState.RUNNING])
available_component_names = [work.name for work in components] + ["flow"]
if component_name is None:
available_components = [
Expand Down
32 changes: 32 additions & 0 deletions tests/tests_app/cli/test_cmd_apps.py
Expand Up @@ -7,7 +7,9 @@
V1LightningappInstanceSpec,
V1LightningappInstanceState,
V1LightningappInstanceStatus,
V1LightningworkState,
V1ListLightningappInstancesResponse,
V1ListLightningworkResponse,
V1ListMembershipsResponse,
V1Membership,
)
Expand Down Expand Up @@ -97,6 +99,36 @@ def test_list_all_apps(list_memberships: mock.MagicMock, list_instances: mock.Ma
list_instances.assert_called_once_with(project_id="default-project", limit=100, phase_in=[])


@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.lightningwork_service_list_lightningwork")
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
def test_list_components(list_memberships: mock.MagicMock, list_components: mock.MagicMock):
list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
list_components.return_value = V1ListLightningworkResponse(lightningworks=[])

cluster_manager = _AppManager()
cluster_manager.list_components(app_id="cheese")

list_memberships.assert_called_once()
list_components.assert_called_once_with(project_id="default-project", app_id="cheese", phase_in=[])


@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.lightningwork_service_list_lightningwork")
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
def test_list_components_with_phase(list_memberships: mock.MagicMock, list_components: mock.MagicMock):
list_memberships.return_value = V1ListMembershipsResponse(memberships=[V1Membership(project_id="default-project")])
list_components.return_value = V1ListLightningworkResponse(lightningworks=[])

cluster_manager = _AppManager()
cluster_manager.list_components(app_id="cheese", phase_in=[V1LightningworkState.RUNNING])

list_memberships.assert_called_once()
list_components.assert_called_once_with(
project_id="default-project", app_id="cheese", phase_in=[V1LightningworkState.RUNNING]
)


@mock.patch("lightning_cloud.login.Auth.authenticate", MagicMock())
@mock.patch("lightning_app.utilities.network.LightningClient.lightningapp_instance_service_list_lightningapp_instances")
@mock.patch("lightning_app.utilities.network.LightningClient.projects_service_list_memberships")
Expand Down

0 comments on commit c786b3d

Please sign in to comment.