forked from mlflow/mlflow
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ModelRegistryMlflowClientTest.java
146 lines (119 loc) · 5.69 KB
/
ModelRegistryMlflowClientTest.java
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
package org.mlflow.tracking;
import com.google.common.collect.Lists;
import org.apache.commons.io.FileUtils;
import org.mlflow.api.proto.ModelRegistry;
import org.mlflow.api.proto.ModelRegistry.ModelVersion;
import org.mlflow.api.proto.Service.RunInfo;
import org.mockito.Mockito;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.testng.Assert;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
import java.io.File;
import java.io.FilenameFilter;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.util.List;
import java.util.UUID;
import static org.mlflow.tracking.TestUtils.createExperimentName;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doReturn;
/**
 * Integration tests for the model-registry methods of {@link MlflowClient}.
 *
 * <p>{@link #before()} starts a SQLite-backed server through {@link TestClientProvider}. It then
 * registers one model with a single version, whose source is a temporary directory holding one
 * {@code .txt} file with the contents {@link #content}. All tests refer to that version as "1".
 *
 * <p>NOTE(review): the test methods share this one model version, and
 * {@link #testGetLatestModelVersions()} transitions it from stage "None" to "Staging". Tests that
 * look up stage "None" (e.g. {@link #testDownloadLatestModelVersion()}) therefore only pass if
 * they run before it. This presumably relies on TestNG's default method ordering — confirm.
 */
public class ModelRegistryMlflowClientTest {
  private static final Logger logger = LoggerFactory.getLogger(ModelRegistryMlflowClientTest.class);
  private static final MlflowProtobufMapper mapper = new MlflowProtobufMapper();
  private static final String content = "Hello, Worldz!";

  private final TestClientProvider testClientProvider = new TestClientProvider();
  private MlflowClient client;
  private String modelName;
  private File tempDir;
  private File tempFile;

  // As only a single `.txt` is stored as a model version artifact, this filter is used to
  // extract the written file.
  final FilenameFilter filter = new FilenameFilter() {
    @Override
    public boolean accept(File f, String name) {
      return name.endsWith(".txt");
    }
  };

  /**
   * Sets up the shared fixture for all tests.
   *
   * <p>Starts a SQLite-backed tracking server and creates an experiment with one run. Writes
   * {@link #content} to a single {@code .txt} file in a fresh temporary directory. Finally,
   * creates a uniquely named registered model and a model version whose source is that directory.
   *
   * @throws IOException if the temporary directory or file cannot be created or written
   */
  @BeforeTest
  public void before() throws IOException {
    client = testClientProvider.initializeClientAndSqlLiteBasedServer();
    modelName = "Model-" + UUID.randomUUID().toString();
    String expName = createExperimentName();
    String expId = client.createExperiment(expName);
    RunInfo runCreated = client.createRun(expId);
    String runId = runCreated.getRunUuid();

    tempDir = Files.createTempDirectory("tempDir").toFile();
    tempFile = Files.createTempFile(tempDir.toPath(), "file", ".txt").toFile();
    FileUtils.writeStringToFile(tempFile, content, StandardCharsets.UTF_8);

    client.sendPost("registered-models/create", mapper.makeCreateModel(modelName));
    client.sendPost("model-versions/create",
        mapper.makeCreateModelVersion(modelName, runId, tempDir.getAbsolutePath()));
  }

  /**
   * Tears down the fixture created in {@link #before()}.
   *
   * <p>Cleans up the client and server through {@link TestClientProvider}, then deletes the
   * temporary source directory. The directory is removed even if the server cleanup throws.
   *
   * @throws InterruptedException propagated from the server cleanup
   */
  @AfterTest
  public void after() throws InterruptedException {
    try {
      testClientProvider.cleanupClientAndServer();
    } finally {
      // Without this, every run leaves its temporary source directory behind.
      // deleteQuietly is a no-op when tempDir is null (e.g. before() failed early).
      FileUtils.deleteQuietly(tempDir);
    }
  }

  /**
   * Checks that {@code getLatestVersions} reports version "1" in stage "None", both with an
   * explicit stage list and for all stages. It then checks that the same call reports "Staging"
   * after a stage transition.
   */
  @Test
  public void testGetLatestModelVersions() throws IOException {
    // a list of stages
    List<ModelVersion> versions = client.getLatestVersions(modelName,
        Lists.newArrayList("None"));
    Assert.assertEquals(versions.size(), 1);
    validateDetailedModelVersion(versions.get(0), modelName, "None", "1");

    client.sendPatch("model-versions/update", mapper.makeUpdateModelVersion(modelName, "1"));

    // get the latest version of all stages
    List<ModelVersion> modelVersion = client.getLatestVersions(modelName);
    Assert.assertEquals(modelVersion.size(), 1);
    validateDetailedModelVersion(modelVersion.get(0), modelName, "None", "1");

    // NOTE: this mutates the shared model version; see the class-level note on test ordering.
    client.sendPost("model-versions/transition-stage",
        mapper.makeTransitionModelVersionStage(modelName, "1", "Staging"));
    modelVersion = client.getLatestVersions(modelName);
    Assert.assertEquals(modelVersion.size(), 1);
    validateDetailedModelVersion(modelVersion.get(0), modelName, "Staging", "1");
  }

  /** Checks that the download URI of version "1" is the source directory it was created from. */
  @Test
  public void testGetModelVersionDownloadUri() {
    String downloadUri = client.getModelVersionDownloadUri(modelName, "1");
    // TestNG's Assert.assertEquals takes (actual, expected).
    Assert.assertEquals(downloadUri, tempDir.getAbsolutePath());
  }

  /** Checks that downloading version "1" yields the single text artifact written in setup. */
  @Test
  public void testDownloadModelVersion() throws IOException {
    File downloadDir = client.downloadModelVersion(modelName, "1");
    assertContainsSingleTextArtifact(downloadDir);
  }

  /**
   * Checks that downloading the latest version in stage "None" yields the single text artifact
   * written in setup.
   */
  @Test
  public void testDownloadLatestModelVersion() throws IOException {
    File downloadDir = client.downloadLatestModelVersion(modelName, "None");
    assertContainsSingleTextArtifact(downloadDir);
  }

  /**
   * Checks that {@code downloadLatestModelVersion} throws {@link MlflowClientException} when
   * {@code getLatestVersions} (stubbed on a spy) returns more than one version.
   */
  @Test(expectedExceptions = MlflowClientException.class)
  public void testDownloadLatestModelVersionWhenMoreThanOneVersionIsReturned() {
    MlflowClient mockedClient = Mockito.spy(client);
    List<ModelVersion> modelVersions = Lists.newArrayList();
    modelVersions.add(ModelVersion.newBuilder().build());
    modelVersions.add(ModelVersion.newBuilder().build());
    doReturn(modelVersions).when(mockedClient).getLatestVersions(any(), any());
    mockedClient.downloadLatestModelVersion(modelName, "None");
  }

  /**
   * Asserts that {@code downloadDir} contains exactly one {@code .txt} file, and that its UTF-8
   * contents equal {@link #content}.
   *
   * <p>The directory is deliberately not deleted here. For a local artifact source the returned
   * path may be the original source directory rather than a copy — TODO confirm.
   *
   * @param downloadDir directory returned by one of the client's download methods
   * @throws IOException if the artifact file cannot be read
   */
  private void assertContainsSingleTextArtifact(File downloadDir) throws IOException {
    File[] textFiles = downloadDir.listFiles(filter);
    // File.listFiles returns null (rather than throwing) when the path is not a readable
    // directory; fail with a clear message instead of a NullPointerException.
    Assert.assertNotNull(textFiles, "Download path is not a readable directory: " + downloadDir);
    Assert.assertEquals(textFiles.length, 1);
    String downloadedContent = FileUtils.readFileToString(textFiles[0], StandardCharsets.UTF_8);
    // TestNG's Assert.assertEquals takes (actual, expected).
    Assert.assertEquals(downloadedContent, content);
  }

  /**
   * Asserts that a model version has the expected stage, registered-model name, and version.
   *
   * @param details the model version returned by the server
   * @param modelName expected registered-model name
   * @param stage expected current stage
   * @param version expected version string
   */
  private void validateDetailedModelVersion(ModelVersion details, String modelName,
      String stage, String version) {
    Assert.assertEquals(details.getCurrentStage(), stage);
    Assert.assertEquals(details.getName(), modelName);
    Assert.assertEquals(details.getVersion(), version);
  }
}