Skip to content

Commit

Permalink
fs.callbacks: simplify, ensure None does not break them, lazy richc…
Browse files Browse the repository at this point in the history
…allbacks
  • Loading branch information
skshetry committed May 10, 2022
1 parent f10dbef commit 6475e5c
Show file tree
Hide file tree
Showing 10 changed files with 109 additions and 60 deletions.
1 change: 0 additions & 1 deletion dvc/data/checkout.py
Expand Up @@ -139,7 +139,6 @@ def __call__(self, cache, from_path, to_fs, to_path, callback=None):
callback,
desc=cache.fs.path.name(from_path),
bytes=True,
total=-1,
) as cb:
transfer(
cache.fs,
Expand Down
2 changes: 1 addition & 1 deletion dvc/data/stage.py
Expand Up @@ -39,7 +39,7 @@ def _upload_file(from_fs_path, fs, odb, upload_odb, callback=None):
callback,
desc=fs_path.name(from_fs_path),
bytes=True,
total=size,
size=size,
) as cb:
upload_odb.fs.put_file(stream, tmp_info, size=size, callback=cb)

Expand Down
114 changes: 63 additions & 51 deletions dvc/fs/_callback.py
Expand Up @@ -59,6 +59,14 @@ def __exit__(self, *exc_args):
def close(self):
"""Handle here on exit."""

def relative_update(self, inc: int = 1) -> None:
    """Advance the callback by ``inc``, tolerating a ``None`` increment.

    A ``None`` increment is replaced by ``0`` before the call is delegated
    to the parent class' ``relative_update``.
    """
    if inc is None:
        # Non-compliant callers may pass None; count it as "no progress".
        inc = 0
    return super().relative_update(inc)

def absolute_update(self, value: int) -> None:
    """Set the callback's position to ``value``, tolerating ``None``.

    A ``None`` value is replaced by the current ``self.value`` before the
    call is delegated to the parent class' ``absolute_update``.
    """
    if value is None:
        # Non-compliant callers may pass None; keep the current position.
        value = self.value
    return super().absolute_update(value)

@classmethod
def as_callback(
cls, maybe_callback: Optional["FsspecCallback"] = None
Expand Down Expand Up @@ -97,21 +105,29 @@ class NoOpCallback(FsspecCallback, fsspec.callbacks.NoOpCallback):


class TqdmCallback(FsspecCallback):
def __init__(self, progress_bar: "Tqdm" = None, **tqdm_kwargs):
def __init__(
self,
size: Optional[int] = None,
value: int = 0,
progress_bar: "Tqdm" = None,
**tqdm_kwargs,
):
tqdm_kwargs["total"] = size or -1
self._tqdm_kwargs = tqdm_kwargs
self._progress_bar = progress_bar
self._stack = ExitStack()
super().__init__()
super().__init__(size=size, value=value)

@cached_property
def progress_bar(self):
from dvc.progress import Tqdm

return self._stack.enter_context(
progress_bar = (
self._progress_bar
if self._progress_bar is not None
else Tqdm(**self._tqdm_kwargs)
)
return self._stack.enter_context(progress_bar)

def __enter__(self):
return self
Expand All @@ -120,18 +136,13 @@ def close(self):
self._stack.close()

def set_size(self, size):
if size is not None:
self.progress_bar.total = size
self.progress_bar.refresh()
super().set_size(size)

def relative_update(self, inc=1):
self.progress_bar.update(inc)
super().relative_update(inc)
# Tqdm tries to be smart when to refresh,
# so we try to force it to re-render.
super().set_size(size)
self.progress_bar.refresh()

def absolute_update(self, value):
self.progress_bar.update_to(value)
super().absolute_update(value)
def call(self, hook_name=None, **kwargs):
    """Push the callback's current ``value`` and ``size`` to the progress bar.

    ``hook_name`` and ``kwargs`` are accepted but not used here.
    """
    bar = self.progress_bar
    bar.update_to(self.value, total=self.size)

def branch(
self,
Expand All @@ -140,72 +151,73 @@ def branch(
kwargs,
child: Optional[FsspecCallback] = None,
):
child = child or TqdmCallback(bytes=True, total=-1, desc=path_1)
child = child or TqdmCallback(bytes=True, desc=path_1)
return super().branch(path_1, path_2, kwargs, child=child)


class RichCallback(FsspecCallback):
def __init__(
self,
size: Optional[int] = None,
value: int = 0,
progress: "RichTransferProgress" = None,
desc: str = None,
total: int = None,
bytes: bool = False, # pylint: disable=redefined-builtin
unit: str = None,
disable: bool = False,
) -> None:
self._progress = progress
self.disable = disable
self._task_kwargs = {
"description": desc or "",
"bytes": bytes,
"unit": unit,
"total": size or 0,
"visible": False,
"progress_type": None if bytes else "summary",
}
self._stack = ExitStack()
super().__init__(size=size, value=value)

@cached_property
def progress(self):
from dvc.ui import ui
from dvc.ui._rich_progress import RichTransferProgress

self.progress = progress or RichTransferProgress(
if self._progress is not None:
return self._progress

progress = RichTransferProgress(
transient=True,
disable=disable,
disable=self.disable,
console=ui.error_console,
)
self.visible = not disable
self._newly_created = progress is None
total = 0 if total is None or total < 0 else total
self.task = self.progress.add_task(
description=desc or "",
total=total,
bytes=bytes,
visible=False,
unit=unit,
progress_type=None if bytes else "summary",
)
super().__init__()
return self._stack.enter_context(progress)

@cached_property
def task(self):
    """Register a task on ``self.progress`` and return ``add_task``'s result.

    The task is created from ``self._task_kwargs``; being a cached property,
    this runs on first access only.
    """
    add_task = self.progress.add_task
    return add_task(**self._task_kwargs)

def __enter__(self):
if self._newly_created:
self.progress.__enter__()
return self

def close(self):
if self._newly_created:
self.progress.stop()
try:
self.progress.remove_task(self.task)
except KeyError:
pass

def set_size(self, size: int = None) -> None:
if size is not None:
self.progress.update(self.task, total=size, visible=self.visible)
super().set_size(size)

def relative_update(self, inc: int = 1) -> None:
self.progress.update(self.task, advance=inc)
super().relative_update(inc)
self.progress.clear_task(self.task)
self._stack.close()

def absolute_update(self, value: int) -> None:
self.progress.update(self.task, completed=value)
super().absolute_update(value)
def call(self, hook_name=None, **kwargs):
    """Sync the progress task with the callback's current state.

    Sends the current ``value`` (as ``completed``) and ``size`` (as
    ``total``) to ``self.progress`` for ``self.task``, and makes the task
    visible unless the callback is disabled. ``hook_name`` and ``kwargs``
    are accepted but not used here.
    """
    progress = self.progress
    task_id = self.task
    progress.update(
        task_id,
        completed=self.value,
        total=self.size,
        visible=not self.disable,
    )

def branch(
self, path_1, path_2, kwargs, child: Optional[FsspecCallback] = None
):
child = child or RichCallback(
self.progress, desc=path_1, bytes=True, total=-1
progress=self.progress, desc=path_1, bytes=True
)
return super().branch(path_1, path_2, kwargs, child=child)

Expand Down
1 change: 0 additions & 1 deletion dvc/objects/db.py
Expand Up @@ -134,7 +134,6 @@ def add(
callback,
desc=fs.path.name(fs_path),
bytes=True,
total=-1,
) as cb:
self._add_file(
fs,
Expand Down
1 change: 0 additions & 1 deletion dvc/output.py
Expand Up @@ -701,7 +701,6 @@ def download(self, to, jobs=None):
from dvc.fs._callback import FsspecCallback

with FsspecCallback.as_tqdm_callback(
total=-1,
desc=f"Downloading {self.fs.path.name(self.fs_path)}",
unit="files",
) as cb:
Expand Down
1 change: 0 additions & 1 deletion dvc/repo/get.py
Expand Up @@ -61,7 +61,6 @@ def get(url, path, out=None, rev=None, jobs=None):
fs_path = fs.from_os_path(path)

with FsspecCallback.as_tqdm_callback(
total=-1,
desc=f"Downloading {fs.path.name(path)}",
unit="files",
) as cb:
Expand Down
1 change: 0 additions & 1 deletion dvc/stage/cache.py
Expand Up @@ -239,7 +239,6 @@ def transfer(self, from_odb, to_odb):
with FsspecCallback.as_tqdm_callback(
desc=src_name,
bytes=True,
total=-1,
) as cb:
func(from_fs, src, to_fs, dst, callback=cb)
ret.append((parent_name, src_name))
Expand Down
10 changes: 9 additions & 1 deletion dvc/ui/_rich_progress.py
Expand Up @@ -20,7 +20,15 @@ def render(self, task):
return ret.append(f" {unit}") if unit else ret


class RichTransferProgress(Progress):
class RichProgress(Progress):
    """``Progress`` subclass that adds a forgiving task-removal helper."""

    def clear_task(self, task):
        """Remove ``task`` via ``remove_task``, ignoring a ``KeyError``."""
        try:
            self.remove_task(task)
        except KeyError:
            # Best-effort cleanup: an unknown task is not an error here.
            return None


class RichTransferProgress(RichProgress):
SUMMARY_COLS = (
TextColumn("[magenta]{task.description}[bold green]"),
MofNCompleteColumnWithUnit(),
Expand Down
2 changes: 1 addition & 1 deletion dvc/utils/fs.py
Expand Up @@ -205,7 +205,7 @@ def copyfile(src, dest, callback=None, no_progress_bar=False, name=None):
with open(src, "rb") as fsrc, open(dest, "wb+") as fdest:
with FsspecCallback.as_tqdm_callback(
callback,
total=total,
size=total,
bytes=True,
disable=no_progress_bar,
desc=name,
Expand Down
36 changes: 35 additions & 1 deletion tests/func/test_fs.py
Expand Up @@ -3,8 +3,10 @@
from operator import itemgetter
from os.path import join

import pytest

from dvc.fs import get_cloud_fs
from dvc.fs._callback import FsspecCallback
from dvc.fs._callback import DEFAULT_CALLBACK, FsspecCallback
from dvc.fs.local import LocalFileSystem
from dvc.repo import Repo

Expand Down Expand Up @@ -323,3 +325,35 @@ def test_callback_on_repo_fs(tmp_dir, dvc, scm, mocker):
assert branch.call_count == 1
assert branch.spy_return.size == size
assert branch.spy_return.value == size


@pytest.mark.parametrize(
    "api", ["set_size", "relative_update", "absolute_update"]
)
@pytest.mark.parametrize(
    "callback_factory, kwargs",
    [
        (FsspecCallback.as_callback, {}),
        (FsspecCallback.as_tqdm_callback, {"desc": "test"}),
        (FsspecCallback.as_rich_callback, {"desc": "test"}),
    ],
)
def test_callback_with_none(request, api, callback_factory, kwargs, mocker):
    """
    Test that callbacks don't fail if they receive None.

    Callbacks should not receive None, but some filesystems may not be
    compliant. We want to maintain maximum compatibility and not break
    the UI in these edge cases.
    See https://github.com/iterative/dvc/issues/7704.
    """
    callback = callback_factory(**kwargs)
    request.addfinalizer(callback.close)

    spy = mocker.spy(callback, "call")
    getattr(callback, api)(None)
    spy.assert_called_once_with()

    # State is only checked when the factory did not return DEFAULT_CALLBACK.
    if callback is DEFAULT_CALLBACK:
        return
    assert callback.size is None
    assert callback.value == 0

0 comments on commit 6475e5c

Please sign in to comment.