From 6475e5cc6aea74c9949d19ce02ef32bd4bc0dba9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Saugat=20Pachhai=20=28=E0=A4=B8=E0=A5=8C=E0=A4=97=E0=A4=BE?= =?UTF-8?q?=E0=A4=A4=29?= Date: Tue, 10 May 2022 17:40:17 +0545 Subject: [PATCH] fs.callbacks: simplify, ensure `None` does not break them, lazy richcallbacks --- dvc/data/checkout.py | 1 - dvc/data/stage.py | 2 +- dvc/fs/_callback.py | 114 +++++++++++++++++++++------------------ dvc/objects/db.py | 1 - dvc/output.py | 1 - dvc/repo/get.py | 1 - dvc/stage/cache.py | 1 - dvc/ui/_rich_progress.py | 10 +++- dvc/utils/fs.py | 2 +- tests/func/test_fs.py | 36 ++++++++++++- 10 files changed, 109 insertions(+), 60 deletions(-) diff --git a/dvc/data/checkout.py b/dvc/data/checkout.py index 05c094d866..df65cb271e 100644 --- a/dvc/data/checkout.py +++ b/dvc/data/checkout.py @@ -139,7 +139,6 @@ def __call__(self, cache, from_path, to_fs, to_path, callback=None): callback, desc=cache.fs.path.name(from_path), bytes=True, - total=-1, ) as cb: transfer( cache.fs, diff --git a/dvc/data/stage.py b/dvc/data/stage.py index 65d041b5e2..a5d5837966 100644 --- a/dvc/data/stage.py +++ b/dvc/data/stage.py @@ -39,7 +39,7 @@ def _upload_file(from_fs_path, fs, odb, upload_odb, callback=None): callback, desc=fs_path.name(from_fs_path), bytes=True, - total=size, + size=size, ) as cb: upload_odb.fs.put_file(stream, tmp_info, size=size, callback=cb) diff --git a/dvc/fs/_callback.py b/dvc/fs/_callback.py index d08990ce24..8eff340dd8 100644 --- a/dvc/fs/_callback.py +++ b/dvc/fs/_callback.py @@ -59,6 +59,14 @@ def __exit__(self, *exc_args): def close(self): """Handle here on exit.""" + def relative_update(self, inc: int = 1) -> None: + inc = inc if inc is not None else 0 + return super().relative_update(inc) + + def absolute_update(self, value: int) -> None: + value = value if value is not None else self.value + return super().absolute_update(value) + @classmethod def as_callback( cls, maybe_callback: Optional["FsspecCallback"] = None @@ -97,21 +105,29 @@ class NoOpCallback(FsspecCallback, fsspec.callbacks.NoOpCallback): class TqdmCallback(FsspecCallback): - def __init__(self, progress_bar: "Tqdm" = None, **tqdm_kwargs): + def __init__( + self, + size: Optional[int] = None, + value: int = 0, + progress_bar: "Tqdm" = None, + **tqdm_kwargs, + ): + tqdm_kwargs["total"] = size or -1 self._tqdm_kwargs = tqdm_kwargs self._progress_bar = progress_bar self._stack = ExitStack() - super().__init__() + super().__init__(size=size, value=value) @cached_property def progress_bar(self): from dvc.progress import Tqdm - return self._stack.enter_context( + progress_bar = ( self._progress_bar if self._progress_bar is not None else Tqdm(**self._tqdm_kwargs) ) + return self._stack.enter_context(progress_bar) def __enter__(self): return self @@ -120,18 +136,13 @@ def close(self): self._stack.close() def set_size(self, size): - if size is not None: - self.progress_bar.total = size - self.progress_bar.refresh() - super().set_size(size) - - def relative_update(self, inc=1): - self.progress_bar.update(inc) - super().relative_update(inc) + # Tqdm tries to be smart when to refresh, + # so we try to force it to re-render. + super().set_size(size) + self.progress_bar.refresh() - def absolute_update(self, value): - self.progress_bar.update_to(value) - super().absolute_update(value) + def call(self, hook_name=None, **kwargs): + self.progress_bar.update_to(self.value, total=self.size) def branch( self, @@ -140,72 +151,73 @@ def branch( kwargs, child: Optional[FsspecCallback] = None, ): - child = child or TqdmCallback(bytes=True, total=-1, desc=path_1) + child = child or TqdmCallback(bytes=True, desc=path_1) return super().branch(path_1, path_2, kwargs, child=child) class RichCallback(FsspecCallback): def __init__( self, + size: Optional[int] = None, + value: int = 0, progress: "RichTransferProgress" = None, desc: str = None, - total: int = None, bytes: bool = False, # pylint: disable=redefined-builtin unit: str = None, disable: bool = False, ) -> None: + self._progress = progress + self.disable = disable + self._task_kwargs = { + "description": desc or "", + "bytes": bytes, + "unit": unit, + "total": size or 0, + "visible": False, + "progress_type": None if bytes else "summary", + } + self._stack = ExitStack() + super().__init__(size=size, value=value) + + @cached_property + def progress(self): from dvc.ui import ui from dvc.ui._rich_progress import RichTransferProgress - self.progress = progress or RichTransferProgress( + if self._progress is not None: + return self._progress + + progress = RichTransferProgress( transient=True, - disable=disable, + disable=self.disable, console=ui.error_console, ) - self.visible = not disable - self._newly_created = progress is None - total = 0 if total is None or total < 0 else total - self.task = self.progress.add_task( - description=desc or "", - total=total, - bytes=bytes, - visible=False, - unit=unit, - progress_type=None if bytes else "summary", - ) - super().__init__() + return self._stack.enter_context(progress) + + @cached_property + def task(self): + return self.progress.add_task(**self._task_kwargs) def __enter__(self): - if self._newly_created: - self.progress.__enter__() return self def close(self): - if self._newly_created: - self.progress.stop() - try: - self.progress.remove_task(self.task) - except KeyError: - pass - - def set_size(self, size: int = None) -> None: - if size is not None: - self.progress.update(self.task, total=size, visible=self.visible) - super().set_size(size) - - def relative_update(self, inc: int = 1) -> None: - self.progress.update(self.task, advance=inc) - super().relative_update(inc) + self.progress.clear_task(self.task) + self._stack.close() - def absolute_update(self, value: int) -> None: - self.progress.update(self.task, completed=value) - super().absolute_update(value) + def call(self, hook_name=None, **kwargs): + self.progress.update( + self.task, + completed=self.value, + total=self.size, + visible=not self.disable, + ) def branch( self, path_1, path_2, kwargs, child: Optional[FsspecCallback] = None ): child = child or RichCallback( - self.progress, desc=path_1, bytes=True, total=-1 + progress=self.progress, desc=path_1, bytes=True ) return super().branch(path_1, path_2, kwargs, child=child) diff --git a/dvc/objects/db.py b/dvc/objects/db.py index 66f5f2f850..d583968fce 100644 --- a/dvc/objects/db.py +++ b/dvc/objects/db.py @@ -134,7 +134,6 @@ def add( callback, desc=fs.path.name(fs_path), bytes=True, - total=-1, ) as cb: self._add_file( fs, diff --git a/dvc/output.py b/dvc/output.py index b2ac7d7ec5..8e1ce94fa7 100644 --- a/dvc/output.py +++ b/dvc/output.py @@ -701,7 +701,6 @@ def download(self, to, jobs=None): from dvc.fs._callback import FsspecCallback with FsspecCallback.as_tqdm_callback( - total=-1, desc=f"Downloading {self.fs.path.name(self.fs_path)}", unit="files", ) as cb: diff --git a/dvc/repo/get.py b/dvc/repo/get.py index 6373f429cb..0a1224479d 100644 --- a/dvc/repo/get.py +++ b/dvc/repo/get.py @@ -61,7 +61,6 @@ def get(url, path, out=None, rev=None, jobs=None): fs_path = fs.from_os_path(path) with FsspecCallback.as_tqdm_callback( - total=-1, desc=f"Downloading {fs.path.name(path)}", unit="files", ) as cb: diff --git a/dvc/stage/cache.py b/dvc/stage/cache.py index e4e9225469..c85cb5eff0 100644 --- a/dvc/stage/cache.py +++ b/dvc/stage/cache.py @@ -239,7 +239,6 @@ def transfer(self, from_odb, to_odb): with FsspecCallback.as_tqdm_callback( desc=src_name, bytes=True, - total=-1, ) as cb: func(from_fs, src, to_fs, dst, callback=cb) ret.append((parent_name, src_name)) diff --git a/dvc/ui/_rich_progress.py b/dvc/ui/_rich_progress.py index afd185a804..0373d7da1e 100644 --- a/dvc/ui/_rich_progress.py +++ b/dvc/ui/_rich_progress.py @@ -20,7 +20,15 @@ def render(self, task): return ret.append(f" {unit}") if unit else ret -class RichTransferProgress(Progress): +class RichProgress(Progress): + def clear_task(self, task): + try: + self.remove_task(task) + except KeyError: + pass + + +class RichTransferProgress(RichProgress): SUMMARY_COLS = ( TextColumn("[magenta]{task.description}[bold green]"), MofNCompleteColumnWithUnit(), diff --git a/dvc/utils/fs.py b/dvc/utils/fs.py index 7a71fb06e6..f7cb65ab6b 100644 --- a/dvc/utils/fs.py +++ b/dvc/utils/fs.py @@ -205,7 +205,7 @@ def copyfile(src, dest, callback=None, no_progress_bar=False, name=None): with open(src, "rb") as fsrc, open(dest, "wb+") as fdest: with FsspecCallback.as_tqdm_callback( callback, - total=total, + size=total, bytes=True, disable=no_progress_bar, desc=name, diff --git a/tests/func/test_fs.py b/tests/func/test_fs.py index 5610e433c2..90e0143b16 100644 --- a/tests/func/test_fs.py +++ b/tests/func/test_fs.py @@ -3,8 +3,10 @@ from operator import itemgetter from os.path import join +import pytest + from dvc.fs import get_cloud_fs -from dvc.fs._callback import FsspecCallback +from dvc.fs._callback import DEFAULT_CALLBACK, FsspecCallback from dvc.fs.local import LocalFileSystem from dvc.repo import Repo @@ -323,3 +325,35 @@ def test_callback_on_repo_fs(tmp_dir, dvc, scm, mocker): assert branch.call_count == 1 assert branch.spy_return.size == size assert branch.spy_return.value == size + + +@pytest.mark.parametrize( + "api", ["set_size", "relative_update", "absolute_update"] +) +@pytest.mark.parametrize( + "callback_factory, kwargs", + [ + (FsspecCallback.as_callback, {}), + (FsspecCallback.as_tqdm_callback, {"desc": "test"}), + (FsspecCallback.as_rich_callback, {"desc": "test"}), + ], +) +def test_callback_with_none(request, api, callback_factory, kwargs, mocker): + """ + Test that callback don't fail if they receive None. + + The callbacks should not receive None, but there may be some + filesystems that are not compliant, we may want to maintain + maximum compatibility, and not break UI in these edge-cases. + See https://github.com/iterative/dvc/issues/7704. + """ + callback = callback_factory(**kwargs) + request.addfinalizer(callback.close) + + call_mock = mocker.spy(callback, "call") + method = getattr(callback, api) + method(None) + call_mock.assert_called_once_with() + if callback is not DEFAULT_CALLBACK: + assert callback.size is None + assert callback.value == 0