Skip to content

Commit

Permalink
output/remote: get rid of NamedCache (#6008)
Browse files Browse the repository at this point in the history
* objects: make HashInfo & HashFile hashable

* stage/output: collect objects (HashFiles) instead of NamedCache

* remote/data_cloud: use objs instead of NamedCache

* repo: use objects instead of NamedCache

* rerun black after rebase

* remove NamedCache usage from run-cache

* drop NamedCache

* update HashInfo.__hash__

* update unit tests

* revert remote_name change

* start cleaning up test failures

* fix tree obj expansion

* update missing cache logging

* fix external dict handling

* update naive odb.gc

* update state save behavior for pull

* update test_data_cloud

* update test_remote

* fix gc with multiple repos

* objects.tree: support filtering w/preserved hash_info

* fix stage cache restore

* drop fs.path_info from HashFile.__hash__

* objects: add optional name attribute

* revert missing cache logging name/suffix changes

* revert name changes in remote tests

* review fixes

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update dvc/remote/base.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: Ruslan Kuprieiev <kupruser@gmail.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Jun 2, 2021
1 parent 5fb181f commit eb46c03
Show file tree
Hide file tree
Showing 21 changed files with 402 additions and 344 deletions.
85 changes: 53 additions & 32 deletions dvc/data_cloud.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
"""Manages dvc remotes that user can use with push/pull/status commands."""

import logging
from typing import TYPE_CHECKING, Iterable, Optional

if TYPE_CHECKING:
from dvc.objects.file import HashFile
from dvc.remote.base import Remote

logger = logging.getLogger(__name__)

Expand All @@ -19,7 +24,11 @@ class DataCloud:
def __init__(self, repo):
self.repo = repo

def get_remote(self, name=None, command="<command>"):
def get_remote(
self,
name: Optional[str] = None,
command: str = "<command>",
) -> "Remote":
from dvc.config import NoRemoteError

if not name:
Expand Down Expand Up @@ -48,76 +57,88 @@ def _init_remote(self, name):

return get_remote(self.repo, name=name)

def push(
    self,
    objs: Iterable["HashFile"],
    jobs: Optional[int] = None,
    remote: Optional[str] = None,
    show_checksums: bool = False,
):
    """Push data items in a cloud-agnostic way.

    Args:
        objs: objects to push to the cloud.
        jobs: number of jobs that can be running simultaneously.
        remote: optional remote name to push to. By default remote from
            core.remote config option is used.
        show_checksums: show checksums instead of file names in
            information messages.
    """
    # Separate variable so the ``remote`` argument (a name string) is not
    # shadowed by the resolved Remote object.
    remote_obj = self.get_remote(remote, "push")

    return remote_obj.push(
        self.repo.odb.local,
        objs,
        jobs=jobs,
        show_checksums=show_checksums,
    )

def pull(
    self,
    objs: Iterable["HashFile"],
    jobs: Optional[int] = None,
    remote: Optional[str] = None,
    show_checksums: bool = False,
):
    """Pull data items in a cloud-agnostic way.

    Args:
        objs: objects to pull from the cloud.
        jobs: number of jobs that can be running simultaneously.
        remote: optional remote name to pull from. By default remote from
            core.remote config option is used.
        show_checksums: show checksums instead of file names in
            information messages.
    """
    # Separate variable so the ``remote`` argument (a name string) is not
    # shadowed by the resolved Remote object.
    remote_obj = self.get_remote(remote, "pull")

    return remote_obj.pull(
        self.repo.odb.local,
        objs,
        jobs=jobs,
        show_checksums=show_checksums,
    )

def status(
    self,
    objs: Iterable["HashFile"],
    jobs: Optional[int] = None,
    remote: Optional[str] = None,
    show_checksums: bool = False,
    log_missing: bool = True,
):
    """Check status of data items in a cloud-agnostic way.

    Args:
        objs: objects to check status for.
        jobs: number of jobs that can be running simultaneously.
        remote: optional remote name to compare cache to. By default
            remote from core.remote config option is used.
        show_checksums: show checksums instead of file names in
            information messages.
        log_missing: log warning messages if file doesn't exist
            neither in cache, neither in cloud.
    """
    # Separate variable so the ``remote`` argument (a name string) is not
    # shadowed by the resolved Remote object.
    remote_obj = self.get_remote(remote, "status")
    return remote_obj.status(
        self.repo.odb.local,
        objs,
        jobs=jobs,
        show_checksums=show_checksums,
        log_missing=log_missing,
    )

def get_url_for(self, remote, checksum):
    """Return the URL (as a string) of *checksum* on the given remote."""
    # Separate variable so the ``remote`` argument (a name string) is not
    # shadowed by the resolved Remote object.
    remote_obj = self.get_remote(remote)
    return str(remote_obj.odb.hash_to_path_info(checksum))
11 changes: 9 additions & 2 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
from typing import NamedTuple, Optional

from voluptuous import Required

from dvc.path_info import PathInfo

from .base import Dependency


class RepoPair(NamedTuple):
    """A ``(url, rev)`` pair identifying an external repo dependency."""

    url: str
    rev: Optional[str] = None  # None when no revision is given

class RepoDependency(Dependency):
PARAM_REPO = "repo"
PARAM_URL = "url"
Expand All @@ -31,10 +38,10 @@ def is_in_repo(self):
return False

@property
def repo_pair(self) -> RepoPair:
    """Return this dependency's ``(url, rev)`` pair.

    ``rev_lock`` takes priority over ``rev`` when both are present in the
    ``repo`` definition dict.
    """
    d = self.def_repo
    rev = d.get(self.PARAM_REV_LOCK) or d.get(self.PARAM_REV)
    return RepoPair(d[self.PARAM_URL], rev)

def __str__(self):
return "{} ({})".format(self.def_path, self.def_repo[self.PARAM_URL])
Expand Down
3 changes: 3 additions & 0 deletions dvc/hash_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ def __bool__(self):
def __str__(self):
    # e.g. "md5: d41d8cd98f00b204e9800998ecf8427e"
    return f"{self.name}: {self.value}"

def __hash__(self):
    # Hash on (name, value) so HashInfo can be used in sets and as dict
    # keys. NOTE(review): assumes __eq__ compares the same two fields —
    # confirm against the (not shown) equality definition.
    return hash((self.name, self.value))

@classmethod
def from_dict(cls, d):
_d = d.copy() if d else {}
Expand Down
98 changes: 0 additions & 98 deletions dvc/objects/db/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from collections import defaultdict

from dvc.scheme import Schemes


Expand Down Expand Up @@ -80,99 +78,3 @@ def __getattr__(self, name):
def by_scheme(self):
self._init_odb(self.CLOUD_SCHEMES)
yield from self._odb.items()


class NamedCacheItem:
    """A set of names for one checksum plus nested per-checksum children."""

    def __init__(self):
        self.names = set()
        self.children = defaultdict(NamedCacheItem)

    def __eq__(self, other):
        if self.names != other.names:
            return False
        return self.children == other.children

    def child_keys(self):
        """Yield every descendant checksum, depth-first."""
        for checksum, item in self.children.items():
            yield checksum
            yield from item.child_keys()

    def child_names(self):
        """Yield ``(checksum, names)`` for every descendant, depth-first."""
        for checksum, item in self.children.items():
            yield checksum, item.names
            yield from item.child_names()

    def add(self, checksum, item):
        """Merge *item* into the child entry for *checksum*."""
        self.children[checksum].update(item)

    def update(self, item, suffix=""):
        """Merge names and children from *item*, appending *suffix* to names.

        The suffix is applied only to this item's names, not to children.
        """
        names = item.names
        if suffix:
            names = {name + suffix for name in names}
        self.names |= names
        for checksum, child in item.children.items():
            self.children[checksum].update(child)


class NamedCache:
    """Nested mapping scheme -> checksum -> NamedCacheItem, plus a map of
    external (url, rev) repo pairs to file paths."""

    # pylint: disable=protected-access
    def __init__(self):
        self._items = defaultdict(lambda: defaultdict(NamedCacheItem))
        self.external = defaultdict(set)

    @classmethod
    def make(cls, scheme, checksum, name):
        """Build a cache containing a single named checksum."""
        ret = cls()
        ret.add(scheme, checksum, name)
        return ret

    def __getitem__(self, key):
        return self._items[key]

    def add(self, scheme, checksum, name):
        """Add a mapped name for the specified checksum."""
        self._items[scheme][checksum].names.add(name)

    def add_child_cache(self, checksum, cache, suffix=""):
        """Add/update child cache for the specified checksum."""
        for scheme, src in cache._items.items():
            children = self._items[scheme][checksum].children
            for child_checksum, item in src.items():
                children[child_checksum].update(item, suffix=suffix)

        for repo_pair, files in cache.external.items():
            self.external[repo_pair].update(files)

    def add_external(self, url, rev, path):
        """Record *path* as belonging to the external repo (url, rev)."""
        self.external[url, rev].add(path)

    def update(self, cache, suffix=""):
        """Merge another NamedCache into this one, suffixing merged names."""
        for scheme, src in cache._items.items():
            dst = self._items[scheme]
            for checksum, item in src.items():
                dst[checksum].update(item, suffix=suffix)

        for repo_pair, files in cache.external.items():
            self.external[repo_pair].update(files)

    def scheme_keys(self, scheme):
        """Iterate over a flat list of all keys for the specified scheme,
        including children.
        """
        for checksum, item in self._items[scheme].items():
            yield checksum
            yield from item.child_keys()

    def scheme_names(self, scheme):
        """Iterate over a flat list of checksum, names items for the specified
        scheme, including children.
        """
        for checksum, item in self._items[scheme].items():
            yield checksum, item.names
            yield from item.child_names()

    def dir_keys(self, scheme):
        """Yield only the checksums that have children (i.e. directories)."""
        for checksum, item in self._items[scheme].items():
            if item.children:
                yield checksum

    def child_keys(self, scheme, checksum):
        """Return the descendant checksums of one entry."""
        return self._items[scheme][checksum].child_keys()
11 changes: 10 additions & 1 deletion dvc/objects/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from dvc.objects.errors import ObjectFormatError
from dvc.objects.file import HashFile
from dvc.objects.tree import Tree
from dvc.progress import Tqdm

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -317,14 +318,22 @@ def _remove_unpacked_dir(self, hash_):
pass

def gc(self, used, jobs=None):
used_hashes = set()
for obj in used:
used_hashes.add(obj.hash_info.value)
if isinstance(obj, Tree):
used_hashes.update(
entry_obj.hash_info.value for _, entry_obj in obj
)

removed = False
# hashes must be sorted to ensure we always remove .dir files first
for hash_ in sorted(
self.all(jobs, str(self.path_info)),
key=self.fs.is_dir_hash,
reverse=True,
):
if hash_ in used:
if hash_ in used_hashes:
continue
path_info = self.hash_to_path_info(hash_)
if self.fs.is_dir_hash(hash_):
Expand Down
24 changes: 23 additions & 1 deletion dvc/objects/file.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,30 @@
import errno
import logging
import os
from typing import TYPE_CHECKING, Optional

from .errors import ObjectFormatError

if TYPE_CHECKING:
from dvc.fs.base import BaseFileSystem
from dvc.hash_info import HashInfo
from dvc.types import DvcPath

logger = logging.getLogger(__name__)


class HashFile:
def __init__(self, path_info, fs, hash_info):
def __init__(
    self,
    path_info: Optional["DvcPath"],
    fs: Optional["BaseFileSystem"],
    hash_info: "HashInfo",
    name: Optional[str] = None,
):
    """Store an object's location (path_info on fs) and hash identity.

    NOTE(review): ``name`` appears to be an optional label for the
    object (added as "optional name attribute") — confirm its exact
    use with callers.
    """
    self.path_info = path_info
    self.fs = fs
    self.hash_info = hash_info
    self.name = name

@property
def size(self):
Expand All @@ -37,6 +50,15 @@ def __eq__(self, other):
and self.hash_info == other.hash_info
)

def __hash__(self):
    # Hash on the object's identity (hash_info) and location (path_info
    # and fs scheme) so HashFile can be collected into sets/dicts.
    # fs is Optional, hence the None guard before reading .scheme.
    return hash(
        (
            self.hash_info,
            self.path_info,
            self.fs.scheme if self.fs else None,
        )
    )

def check(self, odb, check_hash=True):
from .stage import get_file_hash

Expand Down

0 comments on commit eb46c03

Please sign in to comment.