diff --git a/.gitignore b/.gitignore index 514e9ba8bb..d2a206fe3d 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ innosetup/config.ini *.exe .coverage +.coverage.* *.swp diff --git a/.mailmap b/.mailmap new file mode 100644 index 0000000000..a8eae84140 --- /dev/null +++ b/.mailmap @@ -0,0 +1,6 @@ +Paweł Redzyński +Dmitry Petrov +Earl Hathaway +Nabanita Dash +Kurian Benoy +Sritanu Chakraborty diff --git a/dvc/logger.py b/dvc/logger.py index 45432eed4b..f9fe77b8d1 100644 --- a/dvc/logger.py +++ b/dvc/logger.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals from dvc.utils.compat import str, StringIO +from dvc.progress import Tqdm import logging import logging.handlers @@ -53,9 +54,6 @@ class ColorFormatter(logging.Formatter): ) def format(self, record): - if self._is_visible(record): - self._progress_aware() - if record.levelname == "INFO": return record.msg @@ -146,20 +144,25 @@ def _parse_exc(self, exc_info): return (exception, stack_trace) - def _progress_aware(self): - """Add a new line if progress bar hasn't finished""" - from dvc.progress import progress - - if not progress.is_finished: - progress._print() - progress.clearln() - class LoggerHandler(logging.StreamHandler): def handleError(self, record): super(LoggerHandler, self).handleError(record) raise LoggingException(record) + def emit(self, record): + """Write to Tqdm's stream so as to not break progressbars""" + try: + msg = self.format(record) + Tqdm.write( + msg, file=self.stream, end=getattr(self, "terminator", "\n") + ) + self.flush() + except RecursionError: + raise + except Exception: + self.handleError(record) + def setup(level=logging.INFO): colorama.init() diff --git a/dvc/output/base.py b/dvc/output/base.py index 8b55bd71f3..48fd5bebbd 100644 --- a/dvc/output/base.py +++ b/dvc/output/base.py @@ -277,6 +277,7 @@ def download(self, to): def checkout(self, force=False, progress_callback=None, tag=None): if not self.use_cache: + progress_callback(str(self.path_info), self.get_files_number()) return if tag: @@ -313,13 +314,10 @@ def move(self, out): self.repo.scm.ignore(self.fspath) def get_files_number(self): - if not self.use_cache or not self.checksum: + if not self.use_cache: return 0 - if self.is_dir_checksum: - return len(self.dir_cache) - - return 1 + return self.cache.get_files_number(self.checksum) def unprotect(self): if self.exists: diff --git a/dvc/progress.py b/dvc/progress.py index 31c882f52b..04d8d953d5 100644 --- a/dvc/progress.py +++ b/dvc/progress.py @@ -1,154 +1,94 @@ """Manages progress bars for dvc repo.""" - from __future__ import print_function -from __future__ import unicode_literals - -from dvc.utils.compat import str - -import sys -import threading +import logging +from tqdm import tqdm +from copy import deepcopy +from concurrent.futures import ThreadPoolExecutor -CLEARLINE_PATTERN = "\r\x1b[K" - -class Progress(object): +class TqdmThreadPoolExecutor(ThreadPoolExecutor): """ - Simple multi-target progress bar. + Ensure worker progressbars are cleared away properly. """ - def __init__(self): - self._n_total = 0 - self._n_finished = 0 - self._lock = threading.Lock() - self._line = None - - def set_n_total(self, total): - """Sets total number of targets.""" - self._n_total = total - self._n_finished = 0 - - @property - def is_finished(self): - """Returns if all targets have finished.""" - return self._n_total == self._n_finished - - def clearln(self): - self._print(CLEARLINE_PATTERN, end="") - - def _writeln(self, line): - self.clearln() - self._print(line, end="") - sys.stdout.flush() - - def reset(self): - with self._lock: - self._n_total = 0 - self._n_finished = 0 - self._line = None - - def refresh(self, line=None): - """Refreshes progress bar.""" - # Just go away if it is locked. Will update next time - if not self._lock.acquire(False): - return - - if line is None: - line = self._line - - if sys.stdout.isatty() and line is not None: - self._writeln(line) - self._line = line - - self._lock.release() - - def update_target(self, name, current, total): - """Updates progress bar for a specified target.""" - self.refresh(self._bar(name, current, total)) - - def finish_target(self, name): - """Finishes progress bar for a specified target.""" - # We have to write a msg about finished target - with self._lock: - pbar = self._bar(name, 100, 100) - - if sys.stdout.isatty(): - self.clearln() - - self._print(pbar) - - self._n_finished += 1 - self._line = None - - def _bar(self, target_name, current, total): + def __enter__(self): """ - Make a progress bar out of info, which looks like: - (1/2): [########################################] 100% master.zip + Creates a blank initial dummy progress bar if needed so that workers + are forced to create "nested" bars. """ - bar_len = 30 - - if total is None: - state = 0 - percent = "?% " - else: - total = int(total) - state = int((100 * current) / total) if current < total else 100 - percent = str(state) + "% " - - if self._n_total > 1: - num = "({}/{}): ".format(self._n_finished + 1, self._n_total) - else: - num = "" + blank_bar = Tqdm(bar_format="Multi-Threaded:", leave=False) + if blank_bar.pos > 0: + # already nested - don't need a placeholder bar + blank_bar.close() + self.bar = blank_bar + return super(TqdmThreadPoolExecutor, self).__enter__() - n_sh = int((state * bar_len) / 100) - n_sp = bar_len - n_sh - pbar = "[" + "#" * n_sh + " " * n_sp + "] " + def __exit__(self, *a, **k): + super(TqdmThreadPoolExecutor, self).__exit__(*a, **k) + self.bar.close() - return num + pbar + percent + target_name - - @staticmethod - def _print(*args, **kwargs): - import logging - - logger = logging.getLogger(__name__) - - if logger.getEffectiveLevel() == logging.CRITICAL: - return - - print(*args, **kwargs) - - def __enter__(self): - self._lock.acquire(True) - if self._line is not None: - self.clearln() - - def __exit__(self, typ, value, tbck): - if self._line is not None: - self.refresh() - self._lock.release() - - def __call__(self, seq, name="", total=None): - if total is None: - total = len(seq) - - self.update_target(name, 0, total) - for done, item in enumerate(seq, start=1): - yield item - self.update_target(name, done, total) - self.finish_target(name) - - -class ProgressCallback(object): - def __init__(self, total): - self.total = total - self.current = 0 - progress.reset() - - def update(self, name, progress_to_add=1): - self.current += progress_to_add - progress.update_target(name, self.current, self.total) - - def finish(self, name): - progress.finish_target(name) +class Tqdm(tqdm): + """ + maximum-compatibility tqdm-based progressbars + """ -progress = Progress() # pylint: disable=invalid-name + def __init__( + self, + iterable=None, + disable=None, + bytes=False, # pylint: disable=W0622 + desc_truncate=None, + leave=None, + **kwargs + ): + """ + bytes : shortcut for + `unit='B', unit_scale=True, unit_divisor=1024, miniters=1` + desc_truncate : like `desc` but will truncate to 10 chars + kwargs : anything accepted by `tqdm.tqdm()` + """ + kwargs = deepcopy(kwargs) + if bytes: + for k, v in dict( + unit="B", unit_scale=True, unit_divisor=1024, miniters=1 + ).items(): + kwargs.setdefault(k, v) + if desc_truncate is not None: + kwargs.setdefault("desc", self.truncate(desc_truncate)) + if disable is None: + disable = ( + logging.getLogger(__name__).getEffectiveLevel() + >= logging.CRITICAL + ) + super(Tqdm, self).__init__( + iterable=iterable, disable=disable, leave=leave, **kwargs + ) + + def update_desc(self, desc, n=1, truncate=True): + """ + Calls `set_description(truncate(desc))` and `update(n)` + """ + self.set_description( + self.truncate(desc) if truncate else desc, refresh=False + ) + self.update(n) + + def update_to(self, current, total=None): + if total: + self.total = total # pylint: disable=W0613,W0201 + self.update(current - self.n) + + @classmethod + def truncate(cls, s, max_len=25, end=True, fill="..."): + """ + Guarantee len(output) < max_lenself. + >>> truncate("hello", 4) + '...o' + """ + if len(s) <= max_len: + return s + if len(fill) > max_len: + return fill[-max_len:] if end else fill[:max_len] + i = max_len - len(fill) + return (fill + s[-i:]) if end else (s[:i] + fill) diff --git a/dvc/remote/azure.py b/dvc/remote/azure.py index aadc86a48b..2527b66bad 100644 --- a/dvc/remote/azure.py +++ b/dvc/remote/azure.py @@ -17,7 +17,7 @@ BlockBlobService = None from dvc.utils.compat import urlparse -from dvc.progress import progress +from dvc.progress import Tqdm from dvc.config import Config from dvc.remote.base import RemoteBASE from dvc.path_info import CloudURLInfo @@ -26,14 +26,6 @@ logger = logging.getLogger(__name__) -class Callback(object): - def __init__(self, name): - self.name = name - - def __call__(self, current, total): - progress.update_target(self.name, current, total) - - class RemoteAZURE(RemoteBASE): scheme = Schemes.AZURE path_cls = CloudURLInfo @@ -123,18 +115,28 @@ def list_cache_paths(self): def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs ): - cb = None if no_progress_bar else Callback(name) - self.blob_service.create_blob_from_path( - to_info.bucket, to_info.path, from_file, progress_callback=cb - ) + with Tqdm( + desc_truncate=name, disable=no_progress_bar, bytes=True + ) as pbar: + self.blob_service.create_blob_from_path( + to_info.bucket, + to_info.path, + from_file, + progress_callback=pbar.update_to, + ) def _download( self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs ): - cb = None if no_progress_bar else Callback(name) - self.blob_service.get_blob_to_path( - from_info.bucket, from_info.path, to_file, progress_callback=cb - ) + with Tqdm( + desc_truncate=name, disable=no_progress_bar, bytes=True + ) as pbar: + self.blob_service.get_blob_to_path( + from_info.bucket, + from_info.path, + to_file, + progress_callback=pbar.update_to, + ) def exists(self, path_info): paths = self._list_paths(path_info.bucket, path_info.path) diff --git a/dvc/remote/base.py b/dvc/remote/base.py index 050c166590..19031dac86 100644 --- a/dvc/remote/base.py +++ b/dvc/remote/base.py @@ -8,9 +8,9 @@ import logging import tempfile import itertools -from functools import partial from operator import itemgetter from multiprocessing import cpu_count +from functools import partial from concurrent.futures import ThreadPoolExecutor import dvc.prompt as prompt @@ -20,7 +20,7 @@ ConfirmRemoveError, DvcIgnoreInCollectedDirError, ) -from dvc.progress import progress, ProgressCallback +from dvc.progress import Tqdm, TqdmThreadPoolExecutor from dvc.utils import LARGE_DIR_SIZE, tmp_fname, move, relpath, makedirs from dvc.state import StateNoop from dvc.path_info import PathInfo, URLInfo @@ -145,20 +145,20 @@ def get_file_checksum(self, path_info): def _calculate_checksums(self, file_infos): file_infos = list(file_infos) - with ThreadPoolExecutor(max_workers=self.checksum_jobs) as executor: + with TqdmThreadPoolExecutor( + max_workers=self.checksum_jobs + ) as executor: tasks = executor.map(self.get_file_checksum, file_infos) if len(file_infos) > LARGE_DIR_SIZE: - msg = ( - "Computing md5 for a large number of files. " - "This is only done once." + logger.info( + ( + "Computing md5 for a large number of files. " + "This is only done once." + ) ) - logger.info(msg) - tasks = progress(tasks, total=len(file_infos)) - - checksums = { - file_infos[index]: task for index, task in enumerate(tasks) - } + tasks = Tqdm(tasks, total=len(file_infos), unit="md5") + checksums = dict(zip(file_infos, tasks)) return checksums def _collect_dir(self, path_info): @@ -439,9 +439,6 @@ def upload(self, from_info, to_info, name=None, no_progress_bar=False): name = name or from_info.name - if not no_progress_bar: - progress.update_target(name, 0, None) - try: self._upload( from_info.fspath, @@ -454,9 +451,6 @@ def upload(self, from_info, to_info, name=None, no_progress_bar=False): logger.exception(msg.format(from_info, to_info)) return 1 # 1 fail - if not no_progress_bar: - progress.finish_target(name) - return 0 def download( @@ -485,11 +479,6 @@ def download( name = name or to_info.name - if not no_progress_bar: - # real progress is not always available, - # lets at least show start and finish - progress.update_target(name, 0, None) - makedirs(to_info.parent, exist_ok=True, mode=dir_mode) tmp_file = tmp_fname(to_info) @@ -504,9 +493,6 @@ def download( move(tmp_file, to_info, mode=file_mode) - if not no_progress_bar: - progress.finish_target(name) - return 0 def open(self, path_info, mode="r", encoding=None): @@ -644,19 +630,18 @@ def cache_exists(self, checksums, jobs=None): if not self.no_traverse: return list(set(checksums) & set(self.all())) - progress_callback = ProgressCallback(len(checksums)) + with Tqdm(total=len(checksums), unit="md5") as pbar: - def exists_with_progress(path_info): - ret = self.exists(path_info) - progress_callback.update(str(path_info)) - return ret + def exists_with_progress(path_info): + ret = self.exists(path_info) + pbar.update() + return ret - with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: - path_infos = [self.checksum_to_path_info(x) for x in checksums] - in_remote = executor.map(exists_with_progress, path_infos) - ret = list(itertools.compress(checksums, in_remote)) - progress_callback.finish("") - return ret + with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: + path_infos = [self.checksum_to_path_info(x) for x in checksums] + in_remote = executor.map(exists_with_progress, path_infos) + ret = list(itertools.compress(checksums, in_remote)) + return ret def already_cached(self, path_info): current = self.get_checksum(path_info) @@ -694,7 +679,7 @@ def _checkout_file( self.state.save_link(path_info) self.state.save(path_info, checksum) if progress_callback: - progress_callback.update(str(path_info)) + progress_callback(str(path_info)) def makedirs(self, path_info): raise NotImplementedError @@ -724,7 +709,7 @@ def _checkout_dir( self.link(entry_cache_info, entry_info) self.state.save(entry_info, entry_checksum) if progress_callback: - progress_callback.update(str(entry_info)) + progress_callback(str(entry_info)) self._remove_redundant_files(path_info, dir_info, force) @@ -752,23 +737,28 @@ def checkout( raise NotImplementedError checksum = checksum_info.get(self.PARAM_CHECKSUM) + skip = False if not checksum: logger.warning( "No checksum info found for '{}'. " "It won't be created.".format(str(path_info)) ) self.safe_remove(path_info, force=force) - return + skip = True - if not self.changed(path_info, checksum_info): + elif not self.changed(path_info, checksum_info): msg = "Data '{}' didn't change." logger.debug(msg.format(str(path_info))) - return + skip = True - if self.changed_cache(checksum): + elif self.changed_cache(checksum): msg = "Cache '{}' not found. File '{}' won't be created." logger.warning(msg.format(checksum, str(path_info))) self.safe_remove(path_info, force=force) + skip = True + + if skip: + progress_callback(str(path_info), self.get_files_number(checksum)) return msg = "Checking out '{}' with cache '{}'." @@ -787,6 +777,15 @@ def _checkout( path_info, checksum, force, progress_callback=progress_callback ) + def get_files_number(self, checksum): + if not checksum: + return 0 + + if self.is_dir_checksum(checksum): + return len(self.get_dir_cache(checksum)) + + return 1 + @staticmethod def unprotect(path_info): pass diff --git a/dvc/remote/http.py b/dvc/remote/http.py index f377f0ddc2..07c0756ef7 100644 --- a/dvc/remote/http.py +++ b/dvc/remote/http.py @@ -1,35 +1,19 @@ from __future__ import unicode_literals from dvc.scheme import Schemes - from dvc.utils.compat import open -import threading import requests import logging -from dvc.progress import progress +from dvc.progress import Tqdm from dvc.exceptions import DvcException from dvc.config import Config from dvc.remote.base import RemoteBASE - logger = logging.getLogger(__name__) -class ProgressBarCallback(object): - def __init__(self, name, total): - self.name = name - self.total = total - self.current = 0 - self.lock = threading.Lock() - - def __call__(self, byts): - with self.lock: - self.current += byts - progress.update_target(self.name, self.current, self.total) - - class RemoteHTTP(RemoteBASE): scheme = Schemes.HTTP REQUEST_TIMEOUT = 10 @@ -43,36 +27,36 @@ def __init__(self, repo, config): self.path_info = self.path_cls(url) if url else None def _download(self, from_info, to_file, name=None, no_progress_bar=False): - callback = None - if not no_progress_bar: - total = self._content_length(from_info.url) - if total: - callback = ProgressBarCallback(name, total) - request = self._request("GET", from_info.url, stream=True) - - with open(to_file, "wb") as fd: - transferred_bytes = 0 - - for chunk in request.iter_content(chunk_size=self.CHUNK_SIZE): - fd.write(chunk) - fd.flush() - transferred_bytes += len(chunk) - - if callback: - callback(transferred_bytes) + with Tqdm( + total=None if no_progress_bar else self._content_length(from_info), + leave=False, + bytes=True, + desc_truncate=from_info.url if name is None else name, + disable=no_progress_bar, + ) as pbar: + with open(to_file, "wb") as fd: + for chunk in request.iter_content(chunk_size=self.CHUNK_SIZE): + fd.write(chunk) + fd.flush() + pbar.update(len(chunk)) def exists(self, path_info): return bool(self._request("HEAD", path_info.url)) - def _content_length(self, url): - return self._request("HEAD", url).headers.get("Content-Length") + def _content_length(self, url_or_request): + headers = getattr( + url_or_request, + "headers", + self._request("HEAD", url_or_request).headers, + ) + res = headers.get("Content-Length") + return int(res) if res else None def get_file_checksum(self, path_info): url = path_info.url - etag = self._request("HEAD", url).headers.get("ETag") or self._request( - "HEAD", url - ).headers.get("Content-MD5") + headers = self._request("HEAD", url).headers + etag = headers.get("ETag") or headers.get("Content-MD5") if not etag: raise DvcException( diff --git a/dvc/remote/local/__init__.py b/dvc/remote/local/__init__.py index a201ba5bc4..8381e35113 100644 --- a/dvc/remote/local/__init__.py +++ b/dvc/remote/local/__init__.py @@ -35,8 +35,7 @@ ) from dvc.config import Config from dvc.exceptions import DvcException -from dvc.progress import progress -from concurrent.futures import ThreadPoolExecutor +from dvc.progress import Tqdm, TqdmThreadPoolExecutor from dvc.path_info import PathInfo @@ -255,7 +254,7 @@ def move(self, from_info, to_info): def cache_exists(self, checksums, jobs=None): return [ checksum - for checksum in progress(checksums) + for checksum in Tqdm(checksums, unit="md5") if not self.changed_cache_file(checksum) ] @@ -339,7 +338,8 @@ def status( return ret - def _fill_statuses(self, checksum_info_dir, local_exists, remote_exists): + @staticmethod + def _fill_statuses(checksum_info_dir, local_exists, remote_exists): # Using sets because they are way faster for lookups local = set(local_exists) remote = set(remote_exists) @@ -352,8 +352,8 @@ def _get_plans(self, download, remote, status_info, status): cache = [] path_infos = [] names = [] - for md5, info in progress( - status_info.items(), name="Analysing status" + for md5, info in Tqdm( + status_info.items(), desc="Analysing status", unit="file" ): if info["status"] == status: cache.append(self.checksum_to_path_info(md5)) @@ -412,7 +412,7 @@ def _process( return 0 if jobs > 1: - with ThreadPoolExecutor(max_workers=jobs) as executor: + with TqdmThreadPoolExecutor(max_workers=jobs) as executor: fails = sum(executor.map(func, *plans)) else: fails = sum(map(func, *plans)) @@ -443,7 +443,8 @@ def pull(self, checksum_infos, remote, jobs=None, show_checksums=False): download=True, ) - def _log_missing_caches(self, checksum_info_dict): + @staticmethod + def _log_missing_caches(checksum_info_dict): missing_caches = [ (md5, info) for md5, info in checksum_info_dict.items() @@ -451,10 +452,8 @@ def _log_missing_caches(self, checksum_info_dict): ] if missing_caches: missing_desc = "".join( - [ - "\nname: {}, md5: {}".format(info["name"], md5) - for md5, info in missing_caches - ] + "\nname: {}, md5: {}".format(info["name"], md5) + for md5, info in missing_caches ) msg = ( "Some of the cache files do not exist neither locally " @@ -486,8 +485,8 @@ def _unprotect_file(path): os.chmod(path, os.stat(path).st_mode | stat.S_IWRITE) def _unprotect_dir(self, path): - for path in walk_files(path, self.repo.dvcignore): - RemoteLOCAL._unprotect_file(path) + for fname in walk_files(path, self.repo.dvcignore): + RemoteLOCAL._unprotect_file(fname) def unprotect(self, path_info): path = path_info.fspath @@ -546,7 +545,7 @@ def _update_unpacked_dir(self, checksum): def _create_unpacked_dir(self, checksum, dir_info, unpacked_dir_info): self.makedirs(unpacked_dir_info) - for entry in progress(dir_info, name="Created unpacked dir"): + for entry in Tqdm(dir_info, desc="Creating unpacked dir", unit="file"): entry_cache_info = self.checksum_to_path_info( entry[self.PARAM_CHECKSUM] ) diff --git a/dvc/remote/oss.py b/dvc/remote/oss.py index f90fa6100b..57e900087f 100644 --- a/dvc/remote/oss.py +++ b/dvc/remote/oss.py @@ -13,7 +13,7 @@ from dvc.config import Config from dvc.remote.base import RemoteBASE -from dvc.remote.azure import Callback +from dvc.progress import Tqdm from dvc.path_info import CloudURLInfo @@ -107,18 +107,22 @@ def list_cache_paths(self): def _upload( self, from_file, to_info, name=None, no_progress_bar=False, **_kwargs ): - cb = None if no_progress_bar else Callback(name) - self.oss_service.put_object_from_file( - to_info.path, from_file, progress_callback=cb - ) + with Tqdm( + desc_truncate=name, disable=no_progress_bar, bytes=True + ) as pbar: + self.oss_service.put_object_from_file( + to_info.path, from_file, progress_callback=pbar.update_to + ) def _download( self, from_info, to_file, name=None, no_progress_bar=False, **_kwargs ): - cb = None if no_progress_bar else Callback(name) - self.oss_service.get_object_to_file( - from_info.path, to_file, progress_callback=cb - ) + with Tqdm( + desc_truncate=name, disable=no_progress_bar, bytes=True + ) as pbar: + self.oss_service.get_object_to_file( + from_info.path, to_file, progress_callback=pbar.update_to + ) def _generate_download_url(self, path_info, expires=3600): assert path_info.bucket == self.path_info.bucket diff --git a/dvc/remote/s3.py b/dvc/remote/s3.py index 3d79f3eeda..c1bf272db7 100644 --- a/dvc/remote/s3.py +++ b/dvc/remote/s3.py @@ -1,7 +1,6 @@ from __future__ import unicode_literals import os -import threading import logging from funcy import cached_property @@ -10,7 +9,7 @@ except ImportError: boto3 = None -from dvc.progress import progress +from dvc.progress import Tqdm from dvc.config import Config from dvc.remote.base import RemoteBASE from dvc.exceptions import DvcException, ETagMismatchError @@ -20,19 +19,6 @@ logger = logging.getLogger(__name__) -class Callback(object): - def __init__(self, name, total): - self.name = name - self.total = total - self.current = 0 - self.lock = threading.Lock() - - def __call__(self, byts): - with self.lock: - self.current += byts - progress.update_target(self.name, self.current, self.total) - - class RemoteS3(RemoteBASE): scheme = Schemes.S3 path_cls = CloudURLInfo @@ -217,27 +203,36 @@ def exists(self, path_info): def _upload(self, from_file, to_info, name=None, no_progress_bar=False): total = os.path.getsize(from_file) - cb = None if no_progress_bar else Callback(name, total) - self.s3.upload_file( - from_file, - to_info.bucket, - to_info.path, - Callback=cb, - ExtraArgs=self.extra_args, - ) + with Tqdm( + disable=no_progress_bar, + total=total, + bytes=True, + desc_truncate=name, + ) as pbar: + self.s3.upload_file( + from_file, + to_info.bucket, + to_info.path, + Callback=pbar.update, + ExtraArgs=self.extra_args, + ) def _download(self, from_info, to_file, name=None, no_progress_bar=False): if no_progress_bar: - cb = None + total = None else: total = self.s3.head_object( Bucket=from_info.bucket, Key=from_info.path )["ContentLength"] - cb = Callback(name, total) - - self.s3.download_file( - from_info.bucket, from_info.path, to_file, Callback=cb - ) + with Tqdm( + disable=no_progress_bar, + total=total, + bytes=True, + desc_truncate=name, + ) as pbar: + self.s3.download_file( + from_info.bucket, from_info.path, to_file, Callback=pbar.update + ) def _generate_download_url(self, path_info, expires=3600): params = {"Bucket": path_info.bucket, "Key": path_info.path} diff --git a/dvc/remote/ssh/__init__.py b/dvc/remote/ssh/__init__.py index fee5fd0928..086a2fd717 100644 --- a/dvc/remote/ssh/__init__.py +++ b/dvc/remote/ssh/__init__.py @@ -10,9 +10,6 @@ from concurrent.futures import ThreadPoolExecutor from contextlib import contextmanager, closing -from dvc.progress import ProgressCallback -from dvc.utils import to_chunks - try: import paramiko except ImportError: @@ -20,10 +17,12 @@ import dvc.prompt as prompt from dvc.config import Config +from dvc.utils import to_chunks from dvc.utils.compat import urlparse, StringIO from dvc.remote.base import RemoteBASE from dvc.scheme import Schemes from dvc.remote.pool import get_connection +from dvc.progress import Tqdm from .connection import SSHConnection @@ -96,7 +95,7 @@ def ssh_config_filename(): @staticmethod def _load_user_ssh_config(hostname): user_config_file = RemoteSSH.ssh_config_filename() - user_ssh_config = dict() + user_ssh_config = {} if hostname and os.path.exists(user_config_file): ssh_config = paramiko.SSHConfig() with open(user_config_file) as f: @@ -241,7 +240,7 @@ def _exists(chunk_and_channel): if exc.errno != errno.ENOENT: raise ret.append(False) - callback.update(path) + callback(path) return ret with self.ssh(path_infos[0]) as ssh: @@ -263,19 +262,20 @@ def cache_exists(self, checksums, jobs=None): faster than current approach (relying on exists(path_info)) applied in remote/base. """ - progress_callback = ProgressCallback(len(checksums)) + if not self.no_traverse: + return list(set(checksums) & set(self.all())) + + with Tqdm(total=len(checksums), unit="md5") as pbar: - def exists_with_progress(chunks): - return self.batch_exists(chunks, callback=progress_callback) + def exists_with_progress(chunks): + return self.batch_exists(chunks, callback=pbar.update_desc) - if self.no_traverse: with ThreadPoolExecutor(max_workers=jobs or self.JOBS) as executor: path_infos = [self.checksum_to_path_info(x) for x in checksums] chunks = to_chunks(path_infos, num_chunks=self.JOBS) results = executor.map(exists_with_progress, chunks) in_remote = itertools.chain.from_iterable(results) ret = list(itertools.compress(checksums, in_remote)) - progress_callback.finish("") return ret - return list(set(checksums) & set(self.all())) + pbar.update_desc("", 0) # clear path name description diff --git a/dvc/remote/ssh/connection.py b/dvc/remote/ssh/connection.py index 8362857b8c..fbb0c2b7bf 100644 --- a/dvc/remote/ssh/connection.py +++ b/dvc/remote/ssh/connection.py @@ -12,7 +12,7 @@ from dvc.utils import tmp_fname from dvc.utils.compat import ignore_file_not_found -from dvc.progress import progress +from dvc.progress import Tqdm from dvc.exceptions import DvcException from dvc.remote.base import RemoteCmdError @@ -29,21 +29,6 @@ def sizeof_fmt(num, suffix="B"): return "%.1f%s%s" % (num, "Y", suffix) -def percent_cb(name, complete, total): - """ Callback for updating target progress """ - logger.debug( - "{}: {} transferred out of {}".format( - name, sizeof_fmt(complete), sizeof_fmt(total) - ) - ) - progress.update_target(name, complete, total) - - -def create_cb(name): - """ Create callback function for multipart object """ - return lambda cur, tot: percent_cb(name, cur, tot) - - class SSHConnection: def __init__(self, host, *args, **kwargs): logger.debug( @@ -183,14 +168,12 @@ def remove(self, path): self._remove_file(path) def download(self, src, dest, no_progress_bar=False, progress_title=None): - if no_progress_bar: - self.sftp.get(src, dest) - else: - if not progress_title: - progress_title = os.path.basename(src) - - self.sftp.get(src, dest, callback=create_cb(progress_title)) - progress.finish_target(progress_title) + with Tqdm( + desc_truncate=progress_title or os.path.basename(src), + disable=no_progress_bar, + bytes=True, + ) as pbar: + self.sftp.get(src, dest, callback=pbar.update_to) def move(self, src, dst): self.makedirs(posixpath.dirname(dst)) @@ -199,15 +182,13 @@ def move(self, src, dst): def upload(self, src, dest, no_progress_bar=False, progress_title=None): self.makedirs(posixpath.dirname(dest)) tmp_file = tmp_fname(dest) + if not progress_title: + progress_title = posixpath.basename(dest) - if no_progress_bar: - self.sftp.put(src, tmp_file) - else: - if not progress_title: - progress_title = posixpath.basename(dest) - - self.sftp.put(src, tmp_file, callback=create_cb(progress_title)) - progress.finish_target(progress_title) + with Tqdm( + desc_truncate=progress_title, disable=no_progress_bar, bytes=True + ) as pbar: + self.sftp.put(src, tmp_file, callback=pbar.update_to) self.sftp.rename(tmp_file, dest) diff --git a/dvc/repo/checkout.py b/dvc/repo/checkout.py index 2d743e41d5..e335763a2d 100644 --- a/dvc/repo/checkout.py +++ b/dvc/repo/checkout.py @@ -3,7 +3,7 @@ import logging from dvc.exceptions import CheckoutErrorSuggestGit -from dvc.progress import ProgressCallback +from dvc.progress import Tqdm logger = logging.getLogger(__name__) @@ -23,13 +23,6 @@ def get_all_files_numbers(stages): return sum(stage.get_all_files_number() for stage in stages) -def get_progress_callback(stages): - total_files_num = get_all_files_numbers(stages) - if total_files_num == 0: - return None - return ProgressCallback(total_files_num) - - def checkout(self, target=None, with_deps=False, force=False, recursive=False): from dvc.stage import StageFileDoesNotExistError, StageFileBadNameError @@ -44,15 +37,18 @@ def checkout(self, target=None, with_deps=False, force=False, recursive=False): with self.state: _cleanup_unused_links(self, all_stages) - progress_callback = get_progress_callback(stages) - - for stage in stages: - if stage.locked: - logger.warning( - "DVC-file '{path}' is locked. Its dependencies are" - " not going to be checked out.".format(path=stage.relpath) - ) - - stage.checkout(force=force, progress_callback=progress_callback) - if progress_callback: - progress_callback.finish("Checkout finished!") + total = get_all_files_numbers(stages) + with Tqdm( + total=total, unit="file", desc="Checkout", disable=total == 0 + ) as pbar: + for stage in stages: + if stage.locked: + logger.warning( + "DVC-file '{path}' is locked. Its dependencies are" + " not going to be checked out.".format( + path=stage.relpath + ) + ) + + stage.checkout(force=force, progress_callback=pbar.update_desc) + pbar.update_desc("Checkout", 0) # clear path name description diff --git a/dvc/utils/__init__.py b/dvc/utils/__init__.py index c5d1fbbc9c..050f65366c 100644 --- a/dvc/utils/__init__.py +++ b/dvc/utils/__init__.py @@ -33,8 +33,8 @@ logger = logging.getLogger(__name__) -LOCAL_CHUNK_SIZE = 1024 * 1024 -LARGE_FILE_SIZE = 1024 * 1024 * 1024 +LOCAL_CHUNK_SIZE = 2 ** 20 # 1 MB +LARGE_FILE_SIZE = 2 ** 30 # 1 GB LARGE_DIR_SIZE = 100 @@ -44,7 +44,7 @@ def dos2unix(data): def file_md5(fname): """ get the (md5 hexdigest, md5 digest) of a file """ - from dvc.progress import progress + from dvc.progress import Tqdm from dvc.istextfile import istextfile if os.path.exists(fname): @@ -56,28 +56,28 @@ def file_md5(fname): bar = True msg = "Computing md5 for a large file {}. This is only done once." logger.info(msg.format(relpath(fname))) - name = relpath(fname) - total = 0 - - with open(fname, "rb") as fobj: - while True: - data = fobj.read(LOCAL_CHUNK_SIZE) - if not data: - break - - if bar: - total += len(data) - progress.update_target(name, total, size) - - if binary: - chunk = data - else: - chunk = dos2unix(data) - - hash_md5.update(chunk) - - if bar: - progress.finish_target(name) + name = relpath(fname) + + with Tqdm( + desc_truncate=name, + disable=not bar, + total=size, + bytes=True, + leave=False, + ) as pbar: + with open(fname, "rb") as fobj: + while True: + data = fobj.read(LOCAL_CHUNK_SIZE) + if not data: + break + + if binary: + chunk = data + else: + chunk = dos2unix(data) + + hash_md5.update(chunk) + pbar.update(len(data)) return (hash_md5.hexdigest(), hash_md5.digest()) @@ -119,10 +119,9 @@ def dict_md5(d, exclude=()): def copyfile(src, dest, no_progress_bar=False, name=None): """Copy file with progress bar""" from dvc.exceptions import DvcException - from dvc.progress import progress + from dvc.progress import Tqdm from dvc.system import System - copied = 0 name = name if name else os.path.basename(dest) total = os.stat(src).st_size @@ -132,18 +131,19 @@ def copyfile(src, dest, no_progress_bar=False, name=None): try: System.reflink(src, dest) except DvcException: - with open(src, "rb") as fsrc, open(dest, "wb+") as fdest: - while True: - buf = fsrc.read(LOCAL_CHUNK_SIZE) - if not buf: - break - fdest.write(buf) - copied += len(buf) - if not no_progress_bar: - progress.update_target(name, copied, total) - - if not no_progress_bar: - progress.finish_target(name) + with Tqdm( + desc_truncate=name, + disable=no_progress_bar, + total=total, + bytes=True, + ) as pbar: + with open(src, "rb") as fsrc, open(dest, "wb+") as fdest: + while True: + buf = fsrc.read(LOCAL_CHUNK_SIZE) + if not buf: + break + fdest.write(buf) + pbar.update(len(buf)) def makedirs(path, exist_ok=False, mode=None): diff --git a/setup.py b/setup.py index 9e98ace7ac..e98ef6be8c 100644 --- a/setup.py +++ b/setup.py @@ -61,6 +61,7 @@ def run(self): "funcy>=1.12", "pathspec>=0.5.9", "shortuuid>=0.5.0", + "tqdm>=4.34.0", "win-unicode-console>=0.5; sys_platform == 'win32'", ] diff --git a/tests/func/test_checkout.py b/tests/func/test_checkout.py index 1aa7a83702..d489a71b9a 100644 --- a/tests/func/test_checkout.py +++ b/tests/func/test_checkout.py @@ -1,14 +1,10 @@ import os -import sys -import re - import shutil import filecmp import collections import logging from dvc.main import main -from dvc import progress from dvc.repo import Repo as DvcRepo from dvc.system import System from dvc.utils import walk_files, relpath @@ -403,108 +399,6 @@ def test(self): self.assertIsNone(exc.cause.cause) -class TestCheckoutShouldHaveSelfClearingProgressBar(TestDvc): - def setUp(self): - super(TestCheckoutShouldHaveSelfClearingProgressBar, self).setUp() - self._prepare_repo() - - def test(self): - with self._caplog.at_level(logging.INFO, logger="dvc"), patch.object( - sys, "stdout" - ) as stdout_mock: - self.stdout_mock = logger.handlers[0].stream = stdout_mock - - ret = main(["checkout"]) - self.assertEqual(0, ret) - - stdout_calls = self.stdout_mock.method_calls - write_calls = self.filter_out_non_write_calls(stdout_calls) - write_calls = self.filter_out_empty_write_calls(write_calls) - self.write_args = [w_c[1][0] for w_c in write_calls] - - pattern = re.compile(".*\\[.{30}\\].*%.*") - progress_bars = [ - arg - for arg in self.write_args - if pattern.match(arg) and "unpacked" not in arg - ] - - update_bars = progress_bars[:-1] - finish_bar = progress_bars[-1] - - self.assertEqual(4, len(update_bars)) - assert re.search(".*\\[#{7} {23}\\] 25%.*", progress_bars[0]) - assert re.search(".*\\[#{15} {15}\\] 50%.*", progress_bars[1]) - assert re.search(".*\\[#{22} {8}\\] 75%.*", progress_bars[2]) - assert re.search(".*\\[#{30}\\] 100%.*", progress_bars[3]) - - self.assertCaretReturnFollowsEach(update_bars) - self.assertNewLineFollows(finish_bar) - - self.assertAnyEndsWith(update_bars, self.FOO) - self.assertAnyEndsWith(update_bars, self.BAR) - self.assertAnyEndsWith(update_bars, self.DATA) - self.assertAnyEndsWith(update_bars, self.DATA_SUB) - - self.assertTrue(finish_bar.endswith("Checkout finished!")) - - def filter_out_empty_write_calls(self, calls): - def is_not_empty_write(call): - assert call[0] == "write" - return call[1][0] != "" - - return list(filter(is_not_empty_write, calls)) - - def filter_out_non_write_calls(self, calls): - def is_write_call(call): - return call[0] == "write" - - return list(filter(is_write_call, calls)) - - def _prepare_repo(self): - storage = self.mkdtemp() - - ret = main(["remote", "add", "-d", "myremote", storage]) - self.assertEqual(0, ret) - - ret = main(["add", self.DATA_DIR]) - self.assertEqual(0, ret) - - ret = main(["add", self.FOO]) - self.assertEqual(0, ret) - - ret = main(["add", self.BAR]) - self.assertEqual(0, ret) - - ret = main(["push"]) - self.assertEqual(0, ret) - - shutil.rmtree(self.DATA_DIR) - os.unlink(self.FOO) - os.unlink(self.BAR) - - def assertCaretReturnFollowsEach(self, update_bars): - for update_bar in update_bars: - - self.assertIn(update_bar, self.write_args) - - for index, arg in enumerate(self.write_args): - if arg == update_bar: - self.assertEqual( - progress.CLEARLINE_PATTERN, self.write_args[index + 1] - ) - - def assertNewLineFollows(self, finish_bar): - self.assertIn(finish_bar, self.write_args) - - for index, arg in enumerate(self.write_args): - if arg == finish_bar: - self.assertEqual("\n", self.write_args[index + 1]) - - def assertAnyEndsWith(self, update_bars, name): - self.assertTrue(any(ub for ub in update_bars if ub.endswith(name))) - - class TestCheckoutTargetRecursiveShouldNotRemoveOtherUsedFiles(TestDvc): def test(self): ret = main(["add", self.DATA_DIR, self.FOO, self.BAR]) diff --git a/tests/func/test_remote.py b/tests/func/test_remote.py index dadae44ff3..b938f97f42 100644 --- a/tests/func/test_remote.py +++ b/tests/func/test_remote.py @@ -162,15 +162,16 @@ def test(self): def test_large_dir_progress(repo_dir, dvc_repo): from dvc.utils import LARGE_DIR_SIZE - from dvc.progress import progress + from dvc.progress import Tqdm # Create a "large dir" for i in range(LARGE_DIR_SIZE + 1): repo_dir.create(os.path.join("gen", "{}.txt".format(i)), str(i)) - with patch.object(progress, "update_target") as update_target: + with patch.object(Tqdm, "update") as update: + assert not update.called dvc_repo.add("gen") - assert update_target.called + assert update.called def test_dir_checksum_should_be_key_order_agnostic(dvc_repo): diff --git a/tests/unit/output/test_local.py b/tests/unit/output/test_local.py index 8149e83ffc..6e2af11323 100644 --- a/tests/unit/output/test_local.py +++ b/tests/unit/output/test_local.py @@ -2,6 +2,7 @@ from dvc.stage import Stage from dvc.output import OutputLOCAL +from dvc.remote.local import RemoteLOCAL from tests.basic_env import TestDvc @@ -32,8 +33,12 @@ def test_return_0_on_no_cache(self): self.assertEqual(0, o.get_files_number()) @patch.object(OutputLOCAL, "checksum", "12345678.dir") - @patch.object(OutputLOCAL, "dir_cache", [{"md5": "asdf"}, {"md5": "qwe"}]) - def test_return_mutiple_for_dir(self): + @patch.object( + RemoteLOCAL, + "get_dir_cache", + return_value=[{"md5": "asdf"}, {"md5": "qwe"}], + ) + def test_return_mutiple_for_dir(self, mock_get_dir_cache): o = self._get_output() self.assertEqual(2, o.get_files_number()) diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py index 12272c92f0..57c876b3d3 100644 --- a/tests/unit/test_logger.py +++ b/tests/unit/test_logger.py @@ -151,31 +151,31 @@ def test_nested_exceptions(self, caplog): assert expected == formatter.format(caplog.records[0]) def test_progress_awareness(self, mocker, capsys, caplog): - from dvc.progress import progress + from dvc.progress import Tqdm with mocker.patch("sys.stdout.isatty", return_value=True): - progress.set_n_total(100) - progress.update_target("progress", 1, 10) - - # logging an invisible message should not break - # the progress bar output - with caplog.at_level(logging.INFO, logger="dvc"): - debug_record = logging.LogRecord( - name="dvc", - level=logging.DEBUG, - pathname=__name__, - lineno=1, - msg="debug", - args=(), - exc_info=None, - ) - - formatter.format(debug_record) - captured = capsys.readouterr() - assert "\n" not in captured.out - - # just when the message is actually visible - with caplog.at_level(logging.INFO, logger="dvc"): - logger.info("some info") - captured = capsys.readouterr() - assert "\n" in captured.out + with Tqdm(total=100, desc="progress") as pbar: + pbar.update() + + # logging an invisible message should not break + # the progress bar output + with caplog.at_level(logging.INFO, logger="dvc"): + debug_record = logging.LogRecord( + name="dvc", + level=logging.DEBUG, + pathname=__name__, + lineno=1, + msg="debug", + args=(), + exc_info=None, + ) + + formatter.format(debug_record) + captured = capsys.readouterr() + assert captured.out == "" + + # when the message is actually visible + with caplog.at_level(logging.INFO, logger="dvc"): + logger.info("some info") + captured = capsys.readouterr() + assert captured.out == "" diff --git a/tests/unit/test_progress.py b/tests/unit/test_progress.py index e617c1775b..8149300c03 100644 --- a/tests/unit/test_progress.py +++ b/tests/unit/test_progress.py @@ -1,19 +1,17 @@ import logging -import mock -from dvc.progress import progress, ProgressCallback +from dvc.progress import Tqdm def test_quiet(caplog, capsys): with caplog.at_level(logging.CRITICAL, logger="dvc"): - progress._print("something") - assert capsys.readouterr().out == "" - - -class TestProgressCallback: - @mock.patch("dvc.progress.progress") - def test_should_init_reset_progress(self, progress_mock): - total_files_num = 1 - - ProgressCallback(total_files_num) - - assert [mock.call.reset()] == progress_mock.method_calls + for _ in Tqdm(range(10)): + pass + out_err = capsys.readouterr() + assert out_err.out == "" + assert out_err.err == "" + with caplog.at_level(logging.INFO, logger="dvc"): + for _ in Tqdm(range(10)): + pass + out_err = capsys.readouterr() + assert out_err.out == "" + assert "0/10" in out_err.err