-
Notifications
You must be signed in to change notification settings - Fork 4k
/
uri.py
318 lines (263 loc) · 12.3 KB
/
uri.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import sys
import posixpath
import urllib.parse
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE
from mlflow.store.db.db_types import DATABASE_ENGINES
from mlflow.store.tracking import DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH, DEFAULT_ARTIFACTS_URI
from mlflow.utils.validation import _validate_db_type_string
from mlflow.utils.logging_utils import eprint
# Suffix appended to the "Invalid database URI" error raised by extract_db_type_from_uri;
# it points the user at the storage-format documentation.
_INVALID_DB_URI_MSG = (
"Please refer to https://mlflow.org/docs/latest/tracking.html#storage for "
"format specifications."
)
# Prefix placed in front of the path when dbfs_hdfs_uri_to_fuse_path builds a FUSE-style path.
_DBFS_FUSE_PREFIX = "/dbfs/"
# Prefix a DBFS URI must start with before dbfs_hdfs_uri_to_fuse_path will convert it.
_DBFS_HDFS_URI_PREFIX = "dbfs:/"
def is_local_uri(uri):
    """
    Return True if ``uri`` is a local file path, i.e. it has no scheme (``/foo``) or uses
    the ``file`` scheme (``file:/foo``).

    The bare string ``"databricks"`` also parses with an empty scheme, so it is excluded
    explicitly.
    """
    if uri == "databricks":
        return False
    return urllib.parse.urlparse(uri).scheme in ("", "file")
def is_http_uri(uri):
    """Return True if the scheme of ``uri`` is ``http`` or ``https``."""
    return urllib.parse.urlparse(uri).scheme in ("http", "https")
def is_databricks_uri(uri):
    """
    Return True if ``uri`` is a Databricks URI: either the bare string ``databricks``
    (default profile), or a URI with the ``databricks`` scheme such as
    ``databricks://profile`` or ``databricks://secret_scope:secret_key_prefix``.
    """
    if uri == "databricks":
        return True
    return urllib.parse.urlparse(uri).scheme == "databricks"
def construct_db_uri_from_profile(profile):
    """
    Build the ``databricks://<profile>`` URI for ``profile``.

    :param profile: The profile name, represented as a string.
    :return: The profile URI, or ``None`` when ``profile`` is empty or ``None``.
    """
    if not profile:
        return None
    return "databricks://" + profile
# Neither the scope nor the key prefix may contain characters that are special in URIs,
# such as '/' and ':' (spaces are rejected as well).
def validate_db_scope_prefix_info(scope, prefix):
    """
    Check that a Databricks profile scope and optional key prefix are usable in a URI.

    :param scope: The profile name / secret scope, represented as a string.
    :param prefix: The key prefix, or ``None`` when no prefix was supplied.
    :raises MlflowException: If ``scope`` or ``prefix`` contains ``/``, ``:`` or a space,
                             or if ``prefix`` is given but is empty or only whitespace.
    """
    # For each forbidden character the scope is checked before the prefix, so the first
    # offending character (in this order) determines which message is reported.
    for forbidden in ("/", ":", " "):
        if forbidden in scope:
            message = (
                "Unsupported Databricks profile name: %s." % scope
                + " Profile names cannot contain '%s'." % forbidden
            )
            raise MlflowException(message)
        if prefix and forbidden in prefix:
            message = (
                "Unsupported Databricks profile key prefix: %s." % prefix
                + " Key prefixes cannot contain '%s'." % forbidden
            )
            raise MlflowException(message)
    # A prefix of None means "no prefix" and is fine; an explicitly empty one is not.
    if prefix is not None and not prefix.strip():
        raise MlflowException(
            "Unsupported Databricks profile key prefix: '%s'." % prefix
            + " Key prefixes cannot be empty."
        )
def get_db_info_from_uri(uri):
    """
    Extract the Databricks profile information encoded in a tracking URI.

    :param uri: The tracking URI, represented as a string.
    :return: A ``(scope, key_prefix)`` tuple. ``key_prefix`` is ``None`` when the netloc
             contains no colon. ``(None, None)`` is returned when the URI does not use
             the ``databricks`` scheme.
    :raises MlflowException: If the URI uses the ``databricks`` scheme but has an empty
                             netloc, or if the parsed scope / key prefix fails validation.
    """
    parsed_uri = urllib.parse.urlparse(uri)
    if parsed_uri.scheme != "databricks":
        return None, None

    netloc = parsed_uri.netloc
    # An empty netloc means the URI is malformed (e.g. "databricks:/profile").
    if netloc == "":
        raise MlflowException(
            "URI is formatted incorrectly: no netloc in URI '%s'." % uri
            + " This may be the case if there is only one slash in the URI."
        )

    # Everything before the first colon is the scope; everything after it (including any
    # further colons) is the key prefix. Without a colon there is no key prefix at all.
    scope, colon, remainder = netloc.partition(":")
    key_prefix = remainder if colon else None
    validate_db_scope_prefix_info(scope, key_prefix)
    return scope, key_prefix
def get_databricks_profile_uri_from_artifact_uri(uri):
    """
    Convert the authority (netloc) of an artifact URI into a ``databricks://`` profile URI
    when it is a Databricks profile specification, e.g. ``profile@databricks`` or
    ``secret_scope:key_prefix@databricks``.

    :param uri: The artifact URI, represented as a string.
    :return: ``None`` if the URI has no netloc or its host is not ``databricks``;
             ``"databricks"`` if the host is ``databricks`` but no profile is given;
             otherwise ``databricks://<scope>`` or ``databricks://<scope>:<key_prefix>``.
    :raises MlflowException: If the scope or key prefix fails validation.
    """
    parsed = urllib.parse.urlparse(uri)
    points_at_databricks = parsed.netloc and parsed.hostname == "databricks"
    if not points_at_databricks:
        return None

    scope = parsed.username
    if not scope:
        # No profile or scope:key given -> the default tracking/registry URI.
        return "databricks"

    key_prefix = parsed.password
    validate_db_scope_prefix_info(scope, key_prefix)
    if key_prefix:
        return "databricks://" + scope + ":" + key_prefix
    return "databricks://" + scope
def remove_databricks_profile_info_from_artifact_uri(artifact_uri):
    """
    Strip the authority (netloc) from an artifact URI, but only when that authority is a
    Databricks profile specification, e.g. ``profile@databricks`` or
    ``secret_scope:key_prefix@databricks``.

    :param artifact_uri: The artifact URI, represented as a string.
    :return: The URI with its netloc removed if the host is ``databricks``; otherwise the
             input URI unchanged.
    """
    parsed = urllib.parse.urlparse(artifact_uri)
    if parsed.netloc and parsed.hostname == "databricks":
        without_authority = parsed._replace(netloc="")
        return urllib.parse.urlunparse(without_authority)
    return artifact_uri
def add_databricks_profile_info_to_artifact_uri(artifact_uri, databricks_profile_uri):
    """
    Embed the profile from ``databricks_profile_uri`` into the authority (netloc) of
    ``artifact_uri``.

    The artifact URI is returned unchanged when ``databricks_profile_uri`` is empty or is
    not a Databricks URI, when the artifact URI already has a netloc, or when its scheme
    is not one of ``dbfs``, ``runs`` or ``models``.

    :param artifact_uri: The artifact URI, represented as a string.
    :param databricks_profile_uri: ``"databricks"`` or a ``databricks://...`` profile URI.
    :return: The artifact URI, with ``databricks`` or ``<scope>[:<key_prefix>]@databricks``
             as its netloc when the conditions above are met.
    :raises MlflowException: If ``databricks_profile_uri`` is a ``databricks://`` URI
                             that is not valid.
    """
    if not databricks_profile_uri or not is_databricks_uri(databricks_profile_uri):
        return artifact_uri

    parsed = urllib.parse.urlparse(artifact_uri)
    # Do not overwrite the authority section if there is already one.
    if parsed.netloc:
        return artifact_uri
    if parsed.scheme not in ("dbfs", "runs", "models"):
        return artifact_uri

    if databricks_profile_uri == "databricks":
        authority = "databricks"
    else:
        scope, key_prefix = get_db_info_from_uri(databricks_profile_uri)
        key_suffix = ":" + key_prefix if key_prefix else ""
        authority = scope + key_suffix + "@databricks"
    return urllib.parse.urlunparse(parsed._replace(netloc=authority))
def extract_db_type_from_uri(db_uri):
    """
    Parse the specified DB URI to extract the database type from its scheme.

    The scheme is either ``<db_type>`` or ``<db_type>+<driver>``; the driver part, if
    present, is discarded. The extracted type is then passed to
    ``_validate_db_type_string`` (presumably rejecting unsupported types - confirm against
    that helper).

    :param db_uri: The database URI, represented as a string.
    :return: The database type portion of the URI scheme.
    :raises MlflowException: With ``INVALID_PARAMETER_VALUE`` if the scheme contains more
                             than one ``+``.
    """
    scheme = urllib.parse.urlparse(db_uri).scheme
    if scheme.count("+") > 1:
        error_msg = "Invalid database URI: '%s'. %s" % (db_uri, _INVALID_DB_URI_MSG)
        raise MlflowException(error_msg, INVALID_PARAMETER_VALUE)

    # With zero or one '+', the text before the (optional) '+' is the database type.
    db_type = scheme.partition("+")[0]
    _validate_db_type_string(db_type)
    return db_type
def get_uri_scheme(uri_or_path):
    """
    Return the scheme of ``uri_or_path``.

    If the lower-cased scheme starts with one of ``DATABASE_ENGINES``, the result of
    ``extract_db_type_from_uri`` is returned instead, so that a ``+driver`` suffix is
    dropped. Otherwise the scheme is returned as parsed (empty for plain paths).
    """
    scheme = urllib.parse.urlparse(uri_or_path).scheme
    lowered = scheme.lower()
    if any(lowered.startswith(engine) for engine in DATABASE_ENGINES):
        return extract_db_type_from_uri(uri_or_path)
    return scheme
def extract_and_normalize_path(uri):
    """
    Return the path component of ``uri``, normalized with ``posixpath.normpath`` and with
    all leading slashes removed.
    """
    raw_path = urllib.parse.urlparse(uri).path
    return posixpath.normpath(raw_path).lstrip("/")
def append_to_uri_path(uri, *paths):
    """
    Appends the specified POSIX `paths` to the path component of the specified `uri`.

    :param uri: The input URI, represented as a string.
    :param paths: The POSIX paths to append to the specified `uri`'s path component.
    :return: A new URI with a path component consisting of the specified `paths` appended to
             the path component of the specified `uri`. Other components of the URI (such
             as the query string) are kept.

    >>> uri1 = "s3://root/base/path?param=value"
    >>> uri1 = append_to_uri_path(uri1, "some/subpath", "/anotherpath")
    >>> assert uri1 == "s3://root/base/path/some/subpath/anotherpath?param=value"
    >>> uri2 = "a/posix/path"
    >>> uri2 = append_to_uri_path(uri2, "/some", "subpath")
    >>> assert uri2 == "a/posix/path/some/subpath"
    """
    # Collapse all the pieces to append into a single suffix first.
    suffix = ""
    for piece in paths:
        suffix = _join_posixpaths_and_append_absolute_suffixes(suffix, piece)

    parsed_uri = urllib.parse.urlparse(uri)
    if not parsed_uri.scheme:
        # Without a scheme the input is assumed to be a plain POSIX path, so it is joined
        # with the suffix directly.
        return _join_posixpaths_and_append_absolute_suffixes(uri, suffix)

    scheme_prefix = ""
    if not parsed_uri.path.startswith("/"):
        # For certain URI schemes (e.g., "file:"), urlunparse does not preserve a relative
        # path component and may turn it into an absolute one. To avoid that, the scheme
        # is removed before unparsing and re-attached by hand afterwards.
        scheme_prefix = parsed_uri.scheme + ":"
        parsed_uri = parsed_uri._replace(scheme="")

    joined_path = _join_posixpaths_and_append_absolute_suffixes(parsed_uri.path, suffix)
    return scheme_prefix + urllib.parse.urlunparse(parsed_uri._replace(path=joined_path))
def _join_posixpaths_and_append_absolute_suffixes(prefix_path, suffix_path):
"""
Joins the POSIX path `prefix_path` with the POSIX path `suffix_path`. Unlike posixpath.join(),
if `suffix_path` is an absolute path, it is appended to prefix_path.
>>> result1 = _join_posixpaths_and_append_absolute_suffixes("relpath1", "relpath2")
>>> assert result1 == "relpath1/relpath2"
>>> result2 = _join_posixpaths_and_append_absolute_suffixes("relpath", "/absolutepath")
>>> assert result2 == "relpath/absolutepath"
>>> result3 = _join_posixpaths_and_append_absolute_suffixes("/absolutepath", "relpath")
>>> assert result3 == "/absolutepath/relpath"
>>> result4 = _join_posixpaths_and_append_absolute_suffixes("/absolutepath1", "/absolutepath2")
>>> assert result4 == "/absolutepath1/absolutepath2"
"""
if len(prefix_path) == 0:
return suffix_path
# If the specified prefix path is non-empty, we must relativize the suffix path by removing
# the leading slash, if present. Otherwise, posixpath.join() would omit the prefix from the
# joined path
suffix_path = suffix_path.lstrip(posixpath.sep)
return posixpath.join(prefix_path, suffix_path)
def is_databricks_acled_artifacts_uri(artifact_uri):
    """
    Return True if the normalized path of ``artifact_uri`` (leading slashes removed)
    starts with ``databricks/mlflow-tracking/``.
    """
    acled_root = "databricks/mlflow-tracking/"
    return extract_and_normalize_path(artifact_uri).startswith(acled_root)
def is_databricks_model_registry_artifacts_uri(artifact_uri):
    """
    Return True if the normalized path of ``artifact_uri`` (leading slashes removed)
    starts with ``databricks/mlflow-registry/``.
    """
    registry_root = "databricks/mlflow-registry/"
    return extract_and_normalize_path(artifact_uri).startswith(registry_root)
def construct_run_url(hostname, experiment_id, run_id, workspace_id=None):
    """
    Build the URL of a run page from its components.

    :param hostname: The host portion of the URL; used as-is as the start of the result.
    :param experiment_id: ID of the experiment containing the run.
    :param run_id: ID of the run.
    :param workspace_id: Optional workspace ID string. When given and not ``"0"``, it is
                         added as an ``?o=<workspace_id>`` query before the fragment.
    :return: ``<hostname>[?o=<workspace_id>]#mlflow/experiments/<experiment_id>/runs/<run_id>``
    :raises MlflowException: If ``hostname``, ``experiment_id`` or ``run_id`` is missing.
    """
    if not (hostname and experiment_id and run_id):
        raise MlflowException(
            "Hostname, experiment ID, and run ID are all required to construct a run URL"
        )

    workspace_query = ""
    if workspace_id and workspace_id != "0":
        workspace_query = "?o=" + workspace_id
    fragment = "#mlflow/experiments/{experiment_id}/runs/{run_id}".format(
        experiment_id=experiment_id, run_id=run_id
    )
    return hostname + workspace_query + fragment
def is_valid_dbfs_uri(uri):
    """
    Return True if ``uri`` uses the ``dbfs`` scheme and either has no authority (netloc)
    or has one that resolves to a Databricks profile URI.

    A netloc whose profile information fails validation makes the URI invalid rather than
    raising.
    """
    parsed = urllib.parse.urlparse(uri)
    if parsed.scheme != "dbfs":
        return False
    try:
        has_profile = get_databricks_profile_uri_from_artifact_uri(uri) is not None
    except MlflowException:
        # Malformed profile information counts as "no usable profile".
        has_profile = False
    return has_profile or not parsed.netloc
def dbfs_hdfs_uri_to_fuse_path(dbfs_uri):
    """
    Converts the provided DBFS URI into a DBFS FUSE path.

    :param dbfs_uri: A DBFS URI like "dbfs:/my-directory". Can also be a scheme-less URI like
                     "/my-directory" if running in an environment where the default HDFS
                     filesystem is "dbfs:/" (e.g. Databricks).
    :return: A DBFS FUSE-style path, e.g. "/dbfs/my-directory".
    :raises MlflowException: If the (possibly prefixed) URI does not start with "dbfs:/".
    """
    if not is_valid_dbfs_uri(dbfs_uri) and posixpath.abspath(dbfs_uri) == dbfs_uri:
        # Absolute, already-normalized POSIX paths (e.g. "/tmp/mlflow") are turned into
        # DBFS URIs by prepending the "dbfs:" scheme.
        dbfs_uri = "dbfs:" + dbfs_uri

    if dbfs_uri.startswith(_DBFS_HDFS_URI_PREFIX):
        path_after_prefix = dbfs_uri[len(_DBFS_HDFS_URI_PREFIX) :]
        return _DBFS_FUSE_PREFIX + path_after_prefix

    raise MlflowException(
        "Path '%s' did not start with expected DBFS URI prefix '%s'"
        % (dbfs_uri, _DBFS_HDFS_URI_PREFIX),
    )
def resolve_default_artifact_root(
    serve_artifacts, default_artifact_root, backend_store_uri, resolve_to_local=False
):
    """
    Decide which artifact root to use when none (or an empty one) may have been supplied.

    :param serve_artifacts: Whether artifact serving is enabled.
    :param default_artifact_root: The configured artifact root; returned unchanged when set.
    :param backend_store_uri: The backend store URI; reused as the artifact root when it is
                              a local URI and artifacts are not being served.
    :param resolve_to_local: If True, fall back to ``DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH``
                             when the backend store is not local.
    :return: The resolved artifact root.

    If nothing can be resolved, an error is printed via ``eprint`` and the process exits
    with status 1.
    """
    if default_artifact_root:
        return default_artifact_root
    if serve_artifacts:
        return DEFAULT_ARTIFACTS_URI
    if is_local_uri(backend_store_uri):
        return backend_store_uri
    if resolve_to_local:
        return DEFAULT_LOCAL_FILE_AND_ARTIFACT_PATH

    eprint(
        "Option 'default-artifact-root' is required, when backend store is not "
        "local file based."
    )
    sys.exit(1)
    # Only reached if sys.exit() is replaced (e.g. patched in tests).
    return default_artifact_root