-
Notifications
You must be signed in to change notification settings - Fork 4k
/
client.py
454 lines (391 loc) · 19.8 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
"""
Internal package providing a Python CRUD interface to MLflow experiments and runs.
This is a lower level API than the :py:mod:`mlflow.tracking.fluent` module, and is
exposed in the :py:mod:`mlflow.tracking` module.
"""
import time
import os
from mlflow.store.tracking import SEARCH_MAX_RESULTS_DEFAULT
from mlflow.tracking._tracking_service import utils
from mlflow.utils.validation import (
_validate_param_name,
_validate_tag_name,
_validate_run_id,
_validate_experiment_artifact_location,
_validate_experiment_name,
_validate_metric,
_validate_param_keys_unique,
PARAM_VALIDATION_MSG,
)
from mlflow.entities import Param, Metric, RunStatus, RunTag, ViewType, ExperimentTag
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, ErrorCode
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.utils.mlflow_tags import MLFLOW_USER
from mlflow.utils.string_utils import is_string_type
from mlflow.utils.uri import add_databricks_profile_info_to_artifact_uri
from collections import OrderedDict
class TrackingServiceClient:
    """
    Client of an MLflow Tracking Server that creates and manages experiments and runs.

    Artifact repositories resolved for runs are memoized in the class-level
    ``_artifact_repos_cache``, which is shared by every instance of this class.
    """

    # Maps run ID -> artifact repository. Insertion-ordered so that the oldest entry can be
    # evicted first once the cache grows too large (see ``_get_artifact_repo``).
    _artifact_repos_cache = OrderedDict()

    def __init__(self, tracking_uri):
        """
        :param tracking_uri: Address of local or remote tracking server.
        """
        self.tracking_uri = tracking_uri
        # NB: Fetch the tracking store (`self.store`) upon client initialization to ensure that
        # the tracking URI is valid and the store can be properly resolved. We define `store` as a
        # property method to ensure that the client is serializable, even if the store is not.
        self.store  # pylint: disable=pointless-statement

    @property
    def store(self):
        """
        The tracking store for ``self.tracking_uri``, looked up through
        ``utils._get_store`` each time the property is accessed.
        """
        return utils._get_store(self.tracking_uri)

    def get_run(self, run_id):
        """
        Fetch the run from backend store. The resulting :py:class:`Run <mlflow.entities.Run>`
        contains a collection of run metadata -- :py:class:`RunInfo <mlflow.entities.RunInfo>`,
        as well as a collection of run parameters, tags, and metrics --
        :py:class:`RunData <mlflow.entities.RunData>`. In the case where multiple metrics with the
        same key are logged for the run, the :py:class:`RunData <mlflow.entities.RunData>` contains
        the most recently logged value at the largest step for each metric.

        :param run_id: Unique identifier for the run.

        :return: A single :py:class:`mlflow.entities.Run` object, if the run exists. Otherwise,
                 raises an exception.
        """
        _validate_run_id(run_id)
        return self.store.get_run(run_id)

    def get_metric_history(self, run_id, key):
        """
        Return a list of metric objects corresponding to all values logged for a given metric.

        :param run_id: Unique identifier for run
        :param key: Metric name within the run

        :return: A list of :py:class:`mlflow.entities.Metric` entities if logged, else empty list
        """
        return self.store.get_metric_history(run_id=run_id, metric_key=key)

    def create_run(self, experiment_id, start_time=None, tags=None):
        """
        Create a :py:class:`mlflow.entities.Run` object that can be associated with
        metrics, parameters, artifacts, etc.
        Unlike :py:func:`mlflow.projects.run`, creates objects but does not run code.
        Unlike :py:func:`mlflow.start_run`, does not change the "active run" used by
        :py:func:`mlflow.log_param`.

        :param experiment_id: The ID of the experiment to create a run in.
        :param start_time: If not provided, use the current timestamp.
        :param tags: A dictionary of key-value pairs that are converted into
                     :py:class:`mlflow.entities.RunTag` objects.
        :return: :py:class:`mlflow.entities.Run` that was created.
        """
        tags = tags or {}

        # Extract user from tags
        # This logic is temporary; the user_id attribute of runs is deprecated and will be removed
        # in a later release.
        user_id = tags.get(MLFLOW_USER, "unknown")

        # Only fall back to "now" when no start time was supplied at all, so that an explicit
        # timestamp of 0 is honored (same convention as ``log_metric``).
        if start_time is None:
            start_time = int(time.time() * 1000)

        return self.store.create_run(
            experiment_id=experiment_id,
            user_id=user_id,
            start_time=start_time,
            tags=[RunTag(key, value) for key, value in tags.items()],
        )

    def list_run_infos(
        self,
        experiment_id,
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    ):
        """
        Return run information for runs which belong to the experiment_id.

        :param experiment_id: The experiment id which to search
        :param run_view_type: ACTIVE_ONLY, DELETED_ONLY, or ALL runs
        :param max_results: Maximum number of results desired.
        :param order_by: List of order_by clauses. Currently supported values are
            ``metric.key``, ``parameter.key``, ``tag.key``, ``attribute.key``.
            For example, ``order_by=["tag.release ASC", "metric.click_rate DESC"]``.
        :param page_token: Pagination token forwarded unchanged to the tracking store.

        :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
            :py:class:`RunInfo <mlflow.entities.RunInfo>` objects that satisfy the search
            expressions. If the underlying tracking store supports pagination, the token for the
            next page may be obtained via the ``token`` attribute of the returned object.
        """
        return self.store.list_run_infos(
            experiment_id, run_view_type, max_results, order_by, page_token
        )

    def list_experiments(self, view_type=ViewType.ACTIVE_ONLY, max_results=None, page_token=None):
        """
        :param view_type: Qualify requested type of experiments.
        :param max_results: If passed, specifies the maximum number of experiments desired.
                            If not passed, all experiments will be returned for the File and
                            SQLAlchemy backends. For the REST backend, the server will determine
                            an appropriate number of experiments to return.
        :param page_token: Token specifying the next page of results. It should be obtained from
                           a ``list_experiments`` call.
        :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
                 :py:class:`Experiment <mlflow.entities.Experiment>` objects. The pagination token
                 for the next page can be obtained via the ``token`` attribute of the object.
        """
        return self.store.list_experiments(
            view_type=view_type, max_results=max_results, page_token=page_token
        )

    def get_experiment(self, experiment_id):
        """
        :param experiment_id: The experiment ID returned from ``create_experiment``.
        :return: :py:class:`mlflow.entities.Experiment`
        """
        return self.store.get_experiment(experiment_id)

    def get_experiment_by_name(self, name):
        """
        :param name: The experiment name.
        :return: :py:class:`mlflow.entities.Experiment`
        """
        return self.store.get_experiment_by_name(name)

    def create_experiment(self, name, artifact_location=None, tags=None):
        """Create an experiment.

        :param name: The experiment name. Must be unique.
        :param artifact_location: The location to store run artifacts.
                                  If not provided, the server picks an appropriate default.
        :param tags: A dictionary of key-value pairs that are converted into
                     :py:class:`mlflow.entities.ExperimentTag` objects.
        :return: Integer ID of the created experiment.
        """
        _validate_experiment_name(name)
        _validate_experiment_artifact_location(artifact_location)

        return self.store.create_experiment(
            name=name,
            artifact_location=artifact_location,
            tags=[ExperimentTag(key, value) for key, value in tags.items()] if tags else [],
        )

    def delete_experiment(self, experiment_id):
        """
        Delete an experiment from the backend store.

        :param experiment_id: The experiment ID returned from ``create_experiment``.
        """
        self.store.delete_experiment(experiment_id)

    def restore_experiment(self, experiment_id):
        """
        Restore a deleted experiment unless permanently deleted.

        :param experiment_id: The experiment ID returned from ``create_experiment``.
        """
        self.store.restore_experiment(experiment_id)

    def rename_experiment(self, experiment_id, new_name):
        """
        Update an experiment's name. The new name must be unique.

        :param experiment_id: The experiment ID returned from ``create_experiment``.
        :param new_name: The new name for the experiment.
        """
        self.store.rename_experiment(experiment_id, new_name)

    def log_metric(self, run_id, key, value, timestamp=None, step=None):
        """
        Log a metric against the run ID.

        :param run_id: The run id to which the metric should be logged.
        :param key: Metric name (string). This string may only contain alphanumerics,
                    underscores (_), dashes (-), periods (.), spaces ( ), and slashes (/).
                    All backend stores will support keys up to length 250, but some may
                    support larger keys.
        :param value: Metric value (float). Note that some special values such
                      as +/- Infinity may be replaced by other values depending on the store. For
                      example, the SQLAlchemy store replaces +/- Inf with max / min float values.
                      All backend stores will support values up to length 5000, but some
                      may support larger values.
        :param timestamp: Time when this metric was calculated. Defaults to the current system time.
        :param step: Training step (iteration) at which was the metric calculated. Defaults to 0.
        """
        timestamp = timestamp if timestamp is not None else int(time.time() * 1000)
        step = step if step is not None else 0
        _validate_metric(key, value, timestamp, step)
        metric = Metric(key, value, timestamp, step)
        self.store.log_metric(run_id, metric)

    def log_param(self, run_id, key, value):
        """
        Log a parameter against the run ID. Value is converted to a string.

        :param run_id: The run id to which the param should be logged.
        :param key: Parameter name; checked with ``_validate_param_name``.
        :param value: Parameter value (converted to a string).

        If the store rejects the parameter with an ``INVALID_PARAMETER_VALUE`` error, an
        :py:class:`MlflowException` with the same error code is raised whose message is the
        store's message followed by ``PARAM_VALIDATION_MSG``. Any other
        :py:class:`MlflowException` is re-raised unchanged.
        """
        _validate_param_name(key)
        param = Param(key, str(value))
        try:
            self.store.log_param(run_id, param)
        except MlflowException as e:
            if e.error_code == ErrorCode.Name(INVALID_PARAMETER_VALUE):
                # Append the generic guidance text to the store's message; chain the original
                # exception so its traceback is not lost.
                msg = f"{e.message}{PARAM_VALIDATION_MSG}"
                raise MlflowException(msg, INVALID_PARAMETER_VALUE) from e
            raise

    def set_experiment_tag(self, experiment_id, key, value):
        """
        Set a tag on the experiment with the specified ID. Value is converted to a string.

        :param experiment_id: String ID of the experiment.
        :param key: Name of the tag.
        :param value: Tag value (converted to a string).
        """
        _validate_tag_name(key)
        tag = ExperimentTag(key, str(value))
        self.store.set_experiment_tag(experiment_id, tag)

    def set_tag(self, run_id, key, value):
        """
        Set a tag on the run with the specified ID. Value is converted to a string.

        :param run_id: String ID of the run.
        :param key: Tag name (string). This string may only contain alphanumerics, underscores
                    (_), dashes (-), periods (.), spaces ( ), and slashes (/).
                    All backend stores will support keys up to length 250, but some may
                    support larger keys.
        :param value: Tag value (string, but will be string-ified if not).
                      All backend stores will support values up to length 5000, but some
                      may support larger values.
        """
        _validate_tag_name(key)
        tag = RunTag(key, str(value))
        self.store.set_tag(run_id, tag)

    def delete_tag(self, run_id, key):
        """
        Delete a tag from a run. This is irreversible.

        :param run_id: String ID of the run
        :param key: Name of the tag
        """
        self.store.delete_tag(run_id, key)

    def log_batch(self, run_id, metrics=(), params=(), tags=()):
        """
        Log multiple metrics, params, and/or tags.

        :param run_id: String ID of the run
        :param metrics: If provided, List of Metric(key, value, timestamp) instances.
        :param params: If provided, List of Param(key, value) instances.
        :param tags: If provided, List of RunTag(key, value) instances.

        Raises an MlflowException if any errors occur.
        :return: None
        """
        # Nothing to log: skip validation and the call to the store entirely.
        if len(metrics) == 0 and len(params) == 0 and len(tags) == 0:
            return
        # Key uniqueness is only checked when there is more than one param.
        if len(params) > 1:
            _validate_param_keys_unique(params)
        for metric in metrics:
            _validate_metric(metric.key, metric.value, metric.timestamp, metric.step)
        for param in params:
            _validate_param_name(param.key)
        for tag in tags:
            _validate_tag_name(tag.key)
        self.store.log_batch(run_id=run_id, metrics=metrics, params=params, tags=tags)

    def _record_logged_model(self, run_id, mlflow_model):
        """
        Record a logged model against the run through the tracking store.

        :param run_id: String ID of the run.
        :param mlflow_model: The model to record; must be an ``mlflow.models.Model`` instance,
                             otherwise a ``TypeError`` is raised.
        """
        # Imported at call time rather than at module level.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from mlflow.models import Model

        if not isinstance(mlflow_model, Model):
            raise TypeError(
                "Argument 'mlflow_model' should be of type mlflow.models.Model but was "
                f"{type(mlflow_model)}"
            )
        self.store.record_logged_model(run_id, mlflow_model)

    def _get_artifact_repo(self, run_id):
        """
        Return the artifact repository for the run, using the class-level cache when possible.

        :param run_id: String ID of the run.
        :return: The cached repository for ``run_id`` if there is one; otherwise a repository
                 built from the run's artifact URI, which is cached before being returned.
        """
        cache = TrackingServiceClient._artifact_repos_cache

        # Attempt to fetch the artifact repo from a local cache
        cached_repo = cache.get(run_id)
        if cached_repo is not None:
            return cached_repo

        run = self.get_run(run_id)
        artifact_uri = add_databricks_profile_info_to_artifact_uri(
            run.info.artifact_uri, self.tracking_uri
        )
        artifact_repo = get_artifact_repository(artifact_uri)
        # Cache the artifact repo to avoid a future network call, removing the oldest
        # entry in the cache if there are too many elements
        if len(cache) > 1024:
            cache.popitem(last=False)
        cache[run_id] = artifact_repo
        return artifact_repo

    def log_artifact(self, run_id, local_path, artifact_path=None):
        """
        Write a local file or directory to the remote ``artifact_uri``.

        :param run_id: String ID of the run whose artifact repository is written to.
        :param local_path: Path to the file or directory to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        artifact_repo = self._get_artifact_repo(run_id)
        if os.path.isdir(local_path):
            # A directory is logged under its own base name, nested below ``artifact_path``
            # when one is given. normpath drops any trailing separator so basename is non-empty.
            dir_name = os.path.basename(os.path.normpath(local_path))
            path_name = (
                os.path.join(artifact_path, dir_name) if artifact_path is not None else dir_name
            )
            artifact_repo.log_artifacts(local_path, path_name)
        else:
            artifact_repo.log_artifact(local_path, artifact_path)

    def log_artifacts(self, run_id, local_dir, artifact_path=None):
        """
        Write a directory of files to the remote ``artifact_uri``.

        :param run_id: String ID of the run whose artifact repository is written to.
        :param local_dir: Path to the directory of files to write.
        :param artifact_path: If provided, the directory in ``artifact_uri`` to write to.
        """
        self._get_artifact_repo(run_id).log_artifacts(local_dir, artifact_path)

    def list_artifacts(self, run_id, path=None):
        """
        List the artifacts for a run.

        :param run_id: The run to list artifacts from.
        :param path: The run's relative artifact path to list from. By default it is set to None
                     or the root artifact path.
        :return: List of :py:class:`mlflow.entities.FileInfo`
        """
        return self._get_artifact_repo(run_id).list_artifacts(path)

    def download_artifacts(self, run_id, path, dst_path=None):
        """
        Download an artifact file or directory from a run to a local directory if applicable,
        and return a local path for it.

        :param run_id: The run to download artifacts from.
        :param path: Relative source path to the desired artifact.
        :param dst_path: Absolute path of the local filesystem destination directory to which to
                         download the specified artifacts. This directory must already exist.
                         If unspecified, the artifacts will either be downloaded to a new
                         uniquely-named directory on the local filesystem or will be returned
                         directly in the case of the LocalArtifactRepository.
        :return: Local path of desired artifact.
        """
        return self._get_artifact_repo(run_id).download_artifacts(path, dst_path)

    def set_terminated(self, run_id, status=None, end_time=None):
        """Set a run's status to terminated.

        :param run_id: String ID of the run to update.
        :param status: A string value of :py:class:`mlflow.entities.RunStatus`.
                       Defaults to "FINISHED".
        :param end_time: If not provided, defaults to the current time.
        """
        # Only substitute "now" when no end time was supplied, so that an explicit timestamp
        # of 0 is honored (same convention as ``log_metric``).
        if end_time is None:
            end_time = int(time.time() * 1000)
        status = status if status else RunStatus.to_string(RunStatus.FINISHED)
        self.store.update_run_info(
            run_id, run_status=RunStatus.from_string(status), end_time=end_time
        )

    def delete_run(self, run_id):
        """
        Deletes a run with the given ID.
        """
        self.store.delete_run(run_id)

    def restore_run(self, run_id):
        """
        Restores a deleted run with the given ID.
        """
        self.store.restore_run(run_id)

    def search_runs(
        self,
        experiment_ids,
        filter_string="",
        run_view_type=ViewType.ACTIVE_ONLY,
        max_results=SEARCH_MAX_RESULTS_DEFAULT,
        order_by=None,
        page_token=None,
    ):
        """
        Search for runs that fit the search criteria.

        :param experiment_ids: List of experiment IDs, or a single int or string id.
        :param filter_string: Filter query string, defaults to searching all runs.
        :param run_view_type: one of enum values ACTIVE_ONLY, DELETED_ONLY, or ALL runs
                              defined in :py:class:`mlflow.entities.ViewType`.
        :param max_results: Maximum number of runs desired.
        :param order_by: List of columns to order by (e.g., "metrics.rmse"). The ``order_by`` column
                     can contain an optional ``DESC`` or ``ASC`` value. The default is ``ASC``.
                     The default ordering is to sort by ``start_time DESC``, then ``run_id``.
        :param page_token: Token specifying the next page of results. It should be obtained from
            a ``search_runs`` call.

        :return: A :py:class:`PagedList <mlflow.store.entities.PagedList>` of
            :py:class:`Run <mlflow.entities.Run>` objects that satisfy the search expressions.
            If the underlying tracking store supports pagination, the token for the next page may
            be obtained via the ``token`` attribute of the returned object.
        """
        # A single int or string ID is wrapped in a list before being handed to the store.
        if isinstance(experiment_ids, int) or is_string_type(experiment_ids):
            experiment_ids = [experiment_ids]
        return self.store.search_runs(
            experiment_ids=experiment_ids,
            filter_string=filter_string,
            run_view_type=run_view_type,
            max_results=max_results,
            order_by=order_by,
            page_token=page_token,
        )