-
Notifications
You must be signed in to change notification settings - Fork 4k
/
example.py
42 lines (30 loc) · 1.18 KB
/
example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import tempfile
from pprint import pprint
import mlflow
def save_text(path, text):
    """Write *text* to the file at *path*, replacing any existing content.

    The file is written as UTF-8 so the bytes on disk do not depend on the
    platform's locale default encoding.

    Args:
        path: Destination file path; its parent directory must already exist.
        text: String content to write.
    """
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)
# NOTE: ensure the tracking server has been started with --serve-artifacts to enable
# MLflow artifact serving functionality.
def main():
    """Log small artifacts to a new MLflow run, then download and list them.

    Creates ``a.txt`` (content ``"0"``) and ``dir/b.txt`` (content ``"1"``) in a
    temporary directory. It logs ``a.txt`` at the run's artifact root and the
    ``dir`` folder under the ``"dir"`` artifact path. It then downloads and
    lists those artifacts through ``MlflowClient``, printing each result.

    Raises:
        RuntimeError: if the ``MLFLOW_TRACKING_URI`` environment variable is unset.
    """
    # Explicit check rather than ``assert``: assertions are removed under ``python -O``.
    if "MLFLOW_TRACKING_URI" not in os.environ:
        raise RuntimeError(
            "MLFLOW_TRACKING_URI must be set to the MLflow tracking server URI"
        )

    # Upload artifacts
    with mlflow.start_run() as run, tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path_a = os.path.join(tmp_dir, "a.txt")
        save_text(tmp_path_a, "0")

        tmp_sub_dir = os.path.join(tmp_dir, "dir")
        os.makedirs(tmp_sub_dir)
        tmp_path_b = os.path.join(tmp_sub_dir, "b.txt")
        save_text(tmp_path_b, "1")

        mlflow.log_artifact(tmp_path_a)
        mlflow.log_artifacts(tmp_sub_dir, artifact_path="dir")

    run_id = run.info.run_id
    client = mlflow.tracking.MlflowClient()

    # Download artifacts. Each download goes into its own TemporaryDirectory,
    # so nothing is left behind on disk. Without dst_path the client picks a
    # download location itself and this script never removes it.
    for artifact_path in ("", "dir"):
        with tempfile.TemporaryDirectory() as dst_dir:
            local_path = client.download_artifacts(
                run_id, artifact_path, dst_path=dst_dir
            )
            # os.listdir order is arbitrary; sort for reproducible output.
            pprint(sorted(os.listdir(local_path)))

    # List artifacts
    pprint(client.list_artifacts(run_id))
    pprint(client.list_artifacts(run_id, "dir"))
# Script entry point: run the example only when executed directly, not on import.
if __name__ == "__main__":
    main()