/
lightning_cli_create.py
109 lines (100 loc) · 3.52 KB
/
lightning_cli_create.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
from pathlib import Path
from typing import Any, Optional, Union
import click
from lightning_cloud.openapi.rest import ApiException
from lightning_app.cli.cmd_clusters import _check_cluster_name_is_valid, AWSClusterManager
from lightning_app.cli.cmd_ssh_keys import _SSHKeyManager
@click.group("create")
def create() -> None:
    """Create Lightning AI self-managed resources (clusters, etc…)"""
    # Group entry point only: the docstring above doubles as the group's --help
    # text, and the actual work is done by the subcommands registered on it.
@create.command("cluster")
@click.argument("cluster_name", callback=_check_cluster_name_is_valid)
@click.option("--provider", "provider", type=str, default="aws", help="cloud provider to be used for your cluster")
@click.option("--external-id", "external_id", type=str, required=True)
@click.option(
    "--role-arn", "role_arn", type=str, required=True, help="AWS role ARN attached to the associated resources."
)
@click.option(
    "--region",
    "region",
    type=str,
    required=False,
    default="us-east-1",
    help="AWS region that is used to host the associated resources.",
    hidden=True,
)
@click.option(
    "--enable-performance",
    "enable_performance",
    type=bool,
    required=False,
    default=False,
    is_flag=True,
    # Plain concatenated literal: the help text must not start with a stray quote character.
    help=(
        "Use this flag to ensure that the cluster is created with a profile that is optimized for performance. "
        "This makes runs more expensive but start-up times decrease."
    ),
)
@click.option(
    "--edit-before-creation",
    default=False,
    is_flag=True,
    help="Edit the cluster specs before submitting them to the API server.",
)
@click.option(
    "--wait",
    "wait",
    type=bool,
    required=False,
    default=False,
    is_flag=True,
    help="Enabling this flag makes the CLI wait until the cluster is running.",
)
def create_cluster(
    cluster_name: str,
    region: str,
    role_arn: str,
    external_id: str,
    provider: str,
    edit_before_creation: bool,
    enable_performance: bool,
    wait: bool,
    **kwargs: Any,  # not read here; kept so the signature stays compatible
) -> None:
    """Create a Lightning AI BYOC compute cluster with your cloud provider credentials."""
    # Only AWS is implemented. Any other provider (case-insensitive comparison)
    # prints a notice and returns without creating anything.
    if provider.lower() != "aws":
        click.echo("Only AWS is supported for now. But support for more providers is coming soon.")
        return

    cluster_manager = AWSClusterManager()
    cluster_manager.create(
        cluster_name=cluster_name,
        region=region,
        role_arn=role_arn,
        external_id=external_id,
        edit_before_creation=edit_before_creation,
        # The CLI exposes an opt-in "performance" flag; the manager takes the
        # inverse setting, so cost savings are on unless performance is requested.
        cost_savings=not enable_performance,
        wait=wait,
    )
# Substring found in the PEM / OpenSSH armour line of private keys, e.g.
# "-----BEGIN OPENSSH PRIVATE KEY-----" or "-----BEGIN RSA PRIVATE KEY-----".
_PRIVATE_KEY_MARKER = "PRIVATE KEY-----"


def _find_key_file(public_key: Union[str, "os.PathLike[str]"]) -> Optional[str]:
    """Return the path of the file named by ``public_key``, or None if it names no file.

    The argument is tried verbatim first (the historical behaviour) and then with
    ``~`` expanded, so a quoted ``~/.ssh/id_rsa.pub`` still resolves. None means the
    argument should be treated as the key text itself.
    """
    raw = str(public_key)
    for candidate in (raw, os.path.expanduser(raw)):
        if os.path.isfile(candidate):
            return candidate
    return None


def _sibling_pub_path(key_path: str) -> Optional[str]:
    """Return ``<key_path>.pub`` if that file exists, else None."""
    pub_path = f"{key_path}.pub"
    return pub_path if os.path.isfile(pub_path) else None


@create.command("ssh-key")
@click.option("--name", "key_name", default=None, help="name of ssh key")
@click.option("--comment", "comment", default="", help="comment detailing your SSH key")
@click.option(
    "--public-key",
    "public_key",
    help="public key or path to public key file",
    required=True,
)
def add_ssh_key(
    public_key: Union[str, "os.PathLike[str]"], key_name: Optional[str] = None, comment: Optional[str] = None
) -> None:
    """Add a new Lightning AI ssh-key to your account."""
    # ``public_key`` is either the key text itself or a path to a key file.
    ssh_key_manager = _SSHKeyManager()

    key_path = _find_key_file(public_key)
    if key_path is None:
        key_text = str(public_key)
        pub_path: Optional[str] = None
    else:
        key_text = Path(key_path).read_text(encoding="utf-8")
        pub_path = _sibling_pub_path(key_path)

    if pub_path is not None and _PRIVATE_KEY_MARKER in key_text:
        # The user pointed at a private key that has a matching ``.pub`` next to it.
        # Upload the public half directly so private key material is never sent to the API.
        ssh_key_manager.add_key(
            name=key_name, comment=comment, public_key=Path(pub_path).read_text(encoding="utf-8")
        )
        return

    try:
        ssh_key_manager.add_key(name=key_name, comment=comment, public_key=key_text)
    except ApiException:
        # The API rejected the key. The file may still be a private key in a format
        # not recognised above, so retry once with the sibling ``.pub`` if there is one.
        if pub_path is None:
            raise
        ssh_key_manager.add_key(
            name=key_name, comment=comment, public_key=Path(pub_path).read_text(encoding="utf-8")
        )