forked from iterative/dvc
-
Notifications
You must be signed in to change notification settings - Fork 1
/
__init__.py
532 lines (423 loc) · 15.7 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
import logging
import os
from collections import defaultdict
from contextlib import contextmanager
from functools import wraps
from typing import TYPE_CHECKING, Callable, Optional
from funcy import cached_property
from dvc.exceptions import FileMissingError
from dvc.exceptions import IsADirectoryError as DvcIsADirectoryError
from dvc.exceptions import NotDvcRepoError, OutputNotFoundError
from dvc.ignore import DvcIgnoreFilter
from dvc.utils import env2bool
from dvc.utils.fs import path_isin
if TYPE_CHECKING:
from dvc.fs import FileSystem
from dvc.repo.scm_context import SCMContext
from dvc.scm import Base
logger = logging.getLogger(__name__)
@contextmanager
def lock_repo(repo: "Repo"):
    """Run the ``with`` body while holding ``repo.lock``, re-entrantly.

    ``repo._lock_depth`` counts the ``lock_repo`` contexts currently active
    for ``repo``. Only the outermost context (depth 0 on entry) takes
    ``repo.lock``; it calls ``repo._reset()`` right after acquiring the lock
    and once more after the body completes without raising. Nested contexts
    just run their body. The depth observed on entry is always restored on
    exit, including when the body raises.
    """
    # pylint: disable=protected-access
    outer_depth = repo._lock_depth
    repo._lock_depth = outer_depth + 1

    try:
        if outer_depth <= 0:
            with repo.lock:
                repo._reset()
                yield
                # Graph cache is no longer valid after we release the repo.lock
                repo._reset()
        else:
            # An enclosing ``lock_repo`` context is already active, so there
            # is nothing to acquire or reset here.
            yield
    finally:
        repo._lock_depth = outer_depth
def locked(f):
    """Decorate ``f`` so that every call runs inside ``lock_repo(repo)``.

    The first positional argument of the decorated callable is treated as
    the repo: it is handed to ``lock_repo`` and then passed on to ``f``
    together with the remaining arguments. The value returned by ``f`` is
    returned unchanged.
    """

    @wraps(f)
    def _run_locked(repo, *args, **kwargs):
        with lock_repo(repo):
            result = f(repo, *args, **kwargs)
        return result

    return _run_locked
class Repo:
    """A DVC repository: its directories, config, filesystem, SCM and caches.

    Most commands live in ``dvc.repo.*`` modules and become methods of this
    class by being imported inside the class body below.
    """

    # Name of the directory whose presence marks the root of a DVC repo
    # (see `find_root`).
    DVC_DIR = ".dvc"

    from dvc.repo.add import add  # type: ignore[misc]
    from dvc.repo.checkout import checkout  # type: ignore[misc]
    from dvc.repo.commit import commit  # type: ignore[misc]
    from dvc.repo.destroy import destroy  # type: ignore[misc]
    from dvc.repo.diff import diff  # type: ignore[misc]
    from dvc.repo.fetch import fetch  # type: ignore[misc]
    from dvc.repo.freeze import freeze, unfreeze  # type: ignore[misc]
    from dvc.repo.gc import gc  # type: ignore[misc]
    from dvc.repo.get import get as _get  # type: ignore[misc]
    from dvc.repo.get_url import get_url as _get_url  # type: ignore[misc]
    from dvc.repo.imp import imp  # type: ignore[misc]
    from dvc.repo.imp_url import imp_url  # type: ignore[misc]
    from dvc.repo.install import install  # type: ignore[misc]
    from dvc.repo.ls import ls as _ls  # type: ignore[misc]
    from dvc.repo.move import move  # type: ignore[misc]
    from dvc.repo.pull import pull  # type: ignore[misc]
    from dvc.repo.push import push  # type: ignore[misc]
    from dvc.repo.remove import remove  # type: ignore[misc]
    from dvc.repo.reproduce import reproduce  # type: ignore[misc]
    from dvc.repo.run import run  # type: ignore[misc]
    from dvc.repo.status import status  # type: ignore[misc]
    from dvc.repo.update import update  # type: ignore[misc]

    # These three are exposed as static methods, i.e. they are called
    # without a `Repo` instance being passed as the first argument.
    ls = staticmethod(_ls)
    get = staticmethod(_get)
    get_url = staticmethod(_get_url)

    def _get_repo_dirs(
        self,
        root_dir: str = None,
        fs: "FileSystem" = None,
        uninitialized: bool = False,
        scm: "Base" = None,
    ):
        """Work out the ``(root_dir, dvc_dir, tmp_dir)`` triple for the repo.

        Args:
            root_dir: path the search for a repo root starts from (handed to
                `find_root`).
            fs: filesystem to search in; `localfs` is used for joining paths
                when none is given.
            uninitialized: when true, not finding a DVC repo is tolerated.
            scm: SCM instance consulted for the fallback root; created on
                demand when not given.

        Returns:
            A ``(root_dir, dvc_dir, tmp_dir)`` tuple. ``dvc_dir`` and
            ``tmp_dir`` are ``None`` when no DVC repo was found.

        Raises:
            NotDvcRepoError: no DVC repo was found and ``uninitialized`` is
                false.
        """
        from dvc.fs import localfs
        from dvc.scm import SCM, SCMError

        dvc_dir = tmp_dir = None
        try:
            root_dir = self.find_root(root_dir, fs)
            fs = fs or localfs
            dvc_dir = fs.path.join(root_dir, self.DVC_DIR)
            tmp_dir = fs.path.join(dvc_dir, "tmp")
        except NotDvcRepoError:
            if not uninitialized:
                raise

            if not scm:
                try:
                    scm = SCM(root_dir or os.curdir)
                except SCMError:
                    # Not under SCM control either: use a no-SCM instance
                    # for the current directory.
                    scm = SCM(os.curdir, no_scm=True)

            # The caller's root_dir is kept only when both a filesystem and
            # a root_dir were supplied; otherwise the SCM root is used.
            if not fs or not root_dir:
                root_dir = scm.root_dir

        assert root_dir
        return root_dir, dvc_dir, tmp_dir

    def _get_database_dir(self, db_name):
        """Return the directory in which the ``db_name`` database is kept.

        Without a ``dir`` entry in the ``db_name`` config section this is
        ``self.tmp_dir``. Otherwise a per-repo directory is created (if
        needed) under the configured base directory and returned.
        """
        # NOTE: by default, store SQLite-based remote indexes and state's
        # `links` and `md5s` caches in the repository itself to avoid any
        # possible state corruption in 'shared cache dir' scenario, but allow
        # user to override this through config when, say, the repository is
        # located on a mounted volume — see
        # https://github.com/iterative/dvc/issues/4420
        base_db_dir = self.config.get(db_name, {}).get("dir", None)
        if not base_db_dir:
            return self.tmp_dir

        import hashlib

        from dvc.utils.fs import makedirs

        # The directory name is the repo's basename plus the first seven hex
        # digits of the SHA-224 of its full root path (presumably so that
        # repos sharing a basename do not collide).
        digest = hashlib.sha224(self.root_dir.encode("utf-8")).hexdigest()
        repo_name = os.path.basename(self.root_dir)
        db_dir = os.path.join(
            base_db_dir, self.DVC_DIR, f"{repo_name}-{digest[0:7]}"
        )

        makedirs(db_dir, exist_ok=True)
        return db_dir

    def __init__(
        self,
        root_dir=None,
        fs=None,
        rev=None,
        subrepos=False,
        uninitialized=False,
        config=None,
        url=None,
        repo_factory=None,
        scm=None,
    ):
        """Set up the repo object.

        Args:
            root_dir: where to start looking for the repo root.
            fs: filesystem to read the repo through; `localfs` by default.
            rev: when given without ``fs``, the repo is read through a
                `GitFileSystem` at this revision and ``root_dir`` becomes
                ``"/"``.
            subrepos: stored as ``self.subrepos`` and later passed to
                `DvcFileSystem` (see `dvcfs`).
            uninitialized: tolerate the absence of a DVC repo.
            config: passed through to `Config`.
            url: stored as ``self.url``; preferred by `__str__`.
            repo_factory: stored and later passed to `DvcFileSystem`
                (see `dvcfs`).
            scm: SCM instance to use instead of creating one.
        """
        from dvc.config import Config
        from dvc.data_cloud import DataCloud
        from dvc.fs import GitFileSystem, localfs
        from dvc.lock import LockNoop, make_lock
        from dvc.odbmgr import ODBManager
        from dvc.repo.metrics import Metrics
        from dvc.repo.params import Params
        from dvc.repo.plots import Plots
        from dvc.repo.stage import StageLoad
        from dvc.scm import SCM
        from dvc.stage.cache import StageCache
        from dvc_data.hashfile.state import State, StateNoop

        self.url = url
        self._fs_conf = {"repo_factory": repo_factory}
        # `_fs` is assigned directly rather than through the `fs` setter:
        # the setter calls `_reset()`, which needs attributes that do not
        # exist yet.
        self._fs = fs or localfs
        self._scm = scm

        if rev and not fs:
            self._scm = scm = SCM(root_dir or os.curdir)
            root_dir = "/"
            self._fs = GitFileSystem(scm=self._scm, rev=rev)

        self.root_dir, self.dvc_dir, self.tmp_dir = self._get_repo_dirs(
            root_dir=root_dir,
            fs=self.fs,
            uninitialized=uninitialized,
            scm=scm,
        )

        self.config = Config(self.dvc_dir, fs=self.fs, config=config)
        self._uninitialized = uninitialized

        # used by DvcFileSystem to determine if it should traverse subrepos
        self.subrepos = subrepos

        self.cloud = DataCloud(self)
        self.stage = StageLoad(self)

        if isinstance(self.fs, GitFileSystem) or not self.dvc_dir:
            # Git-backed or uninitialized repo: no-op lock and state, and no
            # tmp dir. NOTE: `stage_cache` is not set on this path.
            self.lock = LockNoop()
            self.state = StateNoop()
            self.odb = ODBManager(self)
            self.tmp_dir = None
        else:
            from dvc.utils.fs import makedirs

            makedirs(self.tmp_dir, exist_ok=True)
            self.lock = make_lock(
                os.path.join(self.tmp_dir, "lock"),
                tmp_dir=self.tmp_dir,
                hardlink_lock=self.config["core"].get("hardlink_lock", False),
                friendly=True,
            )

            state_db_dir = self._get_database_dir("state")
            self.state = State(self.root_dir, state_db_dir, self.dvcignore)
            self.odb = ODBManager(self)

            self.stage_cache = StageCache(self)

            self._ignore()

        self.metrics = Metrics(self)
        self.plots = Plots(self)
        self.params = Params(self)
        self.stage_collection_error_handler: Optional[
            Callable[[str, Exception], None]
        ] = None
        # Nesting depth of active `lock_repo` contexts for this repo.
        self._lock_depth = 0

    def __str__(self):
        """Return ``self.url`` when set, otherwise ``self.root_dir``."""
        return self.url or self.root_dir

    @cached_property
    def index(self):
        """`Index` of this repo; the cached value is dropped by `_reset`."""
        from dvc.repo.index import Index

        return Index(self)

    @staticmethod
    def open(url, *args, **kwargs):
        """Open the repo at ``url`` (the current directory when ``None``).

        An existing local path is first tried as a plain `Repo`. If the path
        does not exist, or is not a DVC repo, the call is delegated to
        `external_repo` with the same arguments.
        """
        if url is None:
            url = os.getcwd()

        if os.path.exists(url):
            try:
                return Repo(url, *args, **kwargs)
            except NotDvcRepoError:
                pass  # fallthrough to external_repo

        from dvc.external_repo import external_repo

        return external_repo(url, *args, **kwargs)

    @cached_property
    def scm(self):
        """SCM object for this repo.

        The instance given to, or created by, `__init__` wins. Otherwise one
        is built for ``root_dir``, honouring the ``core.no_scm`` config
        option. An `SCMError` propagates unless the repo was opened with
        ``uninitialized=True``, in which case a no-SCM instance is returned.
        """
        from dvc.scm import SCM, SCMError

        if self._scm:
            return self._scm

        no_scm = self.config["core"].get("no_scm", False)
        try:
            return SCM(self.root_dir, no_scm=no_scm)
        except SCMError:
            if self._uninitialized:
                # might not be a git/dvc repo at all
                # used in `params/metrics/plots/live` targets
                return SCM(self.root_dir, no_scm=True)
            raise

    @cached_property
    def scm_context(self) -> "SCMContext":
        """`SCMContext` built from this repo's SCM and config."""
        from dvc.repo.scm_context import SCMContext

        return SCMContext(self.scm, self.config)

    @cached_property
    def dvcignore(self) -> DvcIgnoreFilter:
        """Ignore filter for ``root_dir``; the cache is dropped by `_reset`."""
        return DvcIgnoreFilter(self.fs, self.root_dir)

    def get_rev(self):
        """Return the current revision.

        On a local filesystem the SCM is asked (inside
        `map_scm_exception`); for any other filesystem its ``rev``
        attribute is returned.
        """
        from dvc.fs import LocalFileSystem

        assert self.scm
        if not isinstance(self.fs, LocalFileSystem):
            return self.fs.rev

        from dvc.scm import map_scm_exception

        with map_scm_exception():
            return self.scm.get_rev()

    @cached_property
    def experiments(self):
        """`Experiments` object for this repo, created on first access."""
        from dvc.repo.experiments import Experiments

        return Experiments(self)

    @cached_property
    def machine(self):
        """`MachineManager` for this repo, or ``None``.

        A manager is only created when the repo has a tmp dir and either the
        ``feature.machine`` config option or the ``DVC_TEST`` environment
        flag is on.
        """
        from dvc.machine import MachineManager

        if not self.tmp_dir:
            return None

        enabled = self.config["feature"].get("machine", False) or env2bool(
            "DVC_TEST"
        )
        return MachineManager(self) if enabled else None

    @property
    def fs(self) -> "FileSystem":
        """Filesystem the repo is currently read through."""
        return self._fs

    @fs.setter
    def fs(self, fs: "FileSystem"):
        """Switch to another filesystem and drop everything cached for it."""
        self._fs = fs
        # Our graph cache is no longer valid, as it was based on the previous
        # fs.
        self._reset()

    def __repr__(self):
        """Return ``"<ClassName>: '<root_dir>'"``."""
        return f"{self.__class__.__name__}: '{self.root_dir}'"

    @classmethod
    def find_root(cls, root=None, fs=None) -> str:
        """Find the closest directory at or above ``root`` holding `DVC_DIR`.

        Args:
            root: directory to start from; the current directory by default.
            fs: filesystem to search in; `localfs` by default.

        Returns:
            The path of the directory that contains ``cls.DVC_DIR``.

        Raises:
            NotDvcRepoError: ``root`` is not a directory, or no ancestor up
                to the filesystem root (or, on a local filesystem, up to a
                mount point) contains ``cls.DVC_DIR``.
        """
        from dvc.fs import LocalFileSystem, localfs

        fs = fs or localfs
        root = root or os.curdir
        root_dir = fs.path.realpath(root)

        if not fs.isdir(root_dir):
            raise NotDvcRepoError(f"directory '{root}' does not exist")

        is_local = isinstance(fs, LocalFileSystem)
        while not fs.isdir(fs.path.join(root_dir, cls.DVC_DIR)):
            if is_local and os.path.ismount(root_dir):
                break
            parent = fs.path.parent(root_dir)
            if parent == root_dir:
                break
            root_dir = parent
        else:
            # The loop condition turned false without a `break`: `root_dir`
            # contains the DVC directory.
            return root_dir

        msg = "you are not inside of a DVC repository"
        if is_local:
            msg = f"{msg} (checked up to mount point '{root_dir}')"
        raise NotDvcRepoError(msg)

    @classmethod
    def find_dvc_dir(cls, root=None):
        """Return the path of the DVC directory of the repo around ``root``."""
        repo_root = cls.find_root(root)
        return os.path.join(repo_root, cls.DVC_DIR)

    @staticmethod
    def init(root_dir=os.curdir, no_scm=False, force=False, subdir=False):
        """Forward to `dvc.repo.init.init` with the same keyword arguments."""
        from dvc.repo.init import init

        return init(
            root_dir=root_dir, no_scm=no_scm, force=force, subdir=subdir
        )

    def unprotect(self, target):
        """Forward to ``unprotect`` of the local object database."""
        return self.odb.local.unprotect(target)

    def _ignore(self):
        """Ask the SCM context to ignore DVC's local-only files.

        That is the local config file, the tmp dir and, when it lies inside
        the repo, the local cache directory.
        """
        to_ignore = [self.config.files["local"], self.tmp_dir]

        cache_dir = self.odb.local.cache_dir
        if path_isin(cache_dir, self.root_dir):
            to_ignore.append(cache_dir)

        for entry in to_ignore:
            self.scm_context.ignore(entry)

    def brancher(self, *args, **kwargs):
        """Forward to `dvc.repo.brancher.brancher` for this repo."""
        from dvc.repo.brancher import brancher

        return brancher(self, *args, **kwargs)

    def used_objs(
        self,
        targets=None,
        all_branches=False,
        with_deps=False,
        all_tags=False,
        all_commits=False,
        all_experiments=False,
        remote=None,
        force=False,
        jobs=None,
        recursive=False,
        used_run_cache=None,
        revs=None,
    ):
        """Get the stages related to the given target and collect
        the `info` of its outputs.

        This is useful to know what files from the cache are _in use_
        (namely, a file described as an output on a stage).

        The scope is, by default, the working directory, but you can use
        `all_branches`/`all_tags`/`all_commits`/`all_experiments` to expand
        the scope.

        Returns:
            A dict mapping (remote) ODB instances to sets of objects that
            belong to each ODB. If the ODB instance is None, the objects
            are naive and do not belong to a specific remote ODB.
        """
        used = defaultdict(set)

        def _merge(objs_by_odb):
            # Fold one ``{odb: objects}`` mapping into the overall result.
            for odb, objs in objs_by_odb.items():
                used[odb].update(objs)

        for _ in self.brancher(
            revs=revs,
            all_branches=all_branches,
            all_tags=all_tags,
            all_commits=all_commits,
            all_experiments=all_experiments,
        ):
            # `self.index` is looked up afresh for every revision yielded.
            _merge(
                self.index.used_objs(
                    targets,
                    remote=remote,
                    force=force,
                    jobs=jobs,
                    recursive=recursive,
                    with_deps=with_deps,
                )
            )

        if used_run_cache:
            _merge(
                self.stage_cache.get_used_objs(
                    used_run_cache, remote=remote, force=force, jobs=jobs
                )
            )

        return used

    @property
    def stages(self):  # obsolete, only for backward-compatibility
        """Stages of the repo, as reported by ``self.index``."""
        return self.index.stages

    def find_outs_by_path(self, path, outs=None, recursive=False, strict=True):
        """Return the outputs that correspond to ``path``.

        Args:
            path: path to look up; made absolute through ``self.fs``.
            outs: outputs to search; ``self.index.outs_graph`` when empty or
                not given.
            recursive: also select outputs located inside ``path``.
            strict: require a ``local`` output's path to equal ``path``;
                when false, ``isin_or_eq`` of the output's filesystem is
                used for the comparison instead.

        Returns:
            A non-empty list of matching outputs.

        Raises:
            OutputNotFoundError: nothing matched.
        """
        # using `outs_graph` to ensure graph checks are run
        outs = outs or self.index.outs_graph

        fs_path = self.fs.path.abspath(path)

        def _same(one, two):
            return one == two

        def _selects(out):
            compare = _same if strict else out.fs.path.isin_or_eq
            if out.protocol == "local" and compare(fs_path, out.fs_path):
                return True
            return recursive and out.fs.path.isin(out.fs_path, fs_path)

        matched = [out for out in outs if _selects(out)]
        if not matched:
            raise OutputNotFoundError(path, self)

        return matched

    def is_dvc_internal(self, path):
        """Tell whether any component of ``path`` is the DVC directory."""
        normalized = self.fs.path.normpath(path)
        return self.DVC_DIR in normalized.split(self.fs.sep)

    @cached_property
    def datafs(self):
        """`DataFileSystem` bound to this repo, created on first access."""
        from dvc.fs.data import DataFileSystem

        return DataFileSystem(repo=self)

    @cached_property
    def dvcfs(self):
        """`DvcFileSystem` for this repo.

        Built with ``self.subrepos`` and the ``repo_factory`` that was given
        to `__init__`.
        """
        from dvc.fs.dvc import DvcFileSystem

        return DvcFileSystem(
            repo=self, subrepos=self.subrepos, **self._fs_conf
        )

    @cached_property
    def index_db_dir(self):
        """Directory of the ``index`` database (see `_get_database_dir`)."""
        return self._get_database_dir("index")

    @contextmanager
    def open_by_relpath(self, path, remote=None, mode="r", encoding=None):
        """Open ``path`` and yield the resulting file object.

        Absolute paths are opened through a `DataFileSystem` with
        ``workspace="local"``; any other path is converted with
        ``from_os_path`` and opened through a `DvcFileSystem` with subrepos
        enabled.

        Raises:
            FileMissingError: a `FileNotFoundError` was raised.
            DvcIsADirectoryError: an `IsADirectoryError` was raised.

        NOTE: the ``yield`` sits inside the ``try``, so the same two
        exceptions raised by the caller's ``with`` body are translated too.
        """
        from dvc.fs.data import DataFileSystem
        from dvc.fs.dvc import DvcFileSystem

        if os.path.isabs(path):
            fs, fs_path = DataFileSystem(repo=self, workspace="local"), path
        else:
            fs = DvcFileSystem(repo=self, subrepos=True)
            fs_path = fs.from_os_path(path)

        try:
            with fs.open(
                fs_path,
                mode=mode,
                encoding=encoding,
                remote=remote,
            ) as fobj:
                yield fobj
        except FileNotFoundError as exc:
            raise FileMissingError(path) from exc
        except IsADirectoryError as exc:
            raise DvcIsADirectoryError(f"'{path}' is a directory") from exc

    def close(self):
        """Close the SCM and state objects."""
        self.scm.close()
        self.state.close()

    def _reset(self):
        """Close the state, reset the SCM and drop cached index/dvcignore."""
        self.state.close()
        self.scm._reset()  # pylint: disable=protected-access
        # Removing the entries from the instance dict makes the
        # `cached_property` values get recomputed on next access.
        for cached_name in ("index", "dvcignore"):
            self.__dict__.pop(cached_name, None)

    def __enter__(self):
        """Return the repo itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Reset the repo and close its SCM when the ``with`` block ends."""
        self._reset()
        self.scm.close()