diff --git a/flow/project.py b/flow/project.py index 1ec2bc22e..f84b330b9 100644 --- a/flow/project.py +++ b/flow/project.py @@ -37,16 +37,6 @@ from jinja2 import TemplateNotFound as Jinja2TemplateNotFound from signac.contrib.filterparse import parse_filter_arg -try: - # If ipywidgets is installed, use "auto" tqdm to improve notebook support. - # Otherwise, use only text-based progress bars. This workaround can be - # removed after https://github.com/tqdm/tqdm/pull/1218. - import ipywidgets # noqa: F401 -except ImportError: - from tqdm import tqdm -else: - from tqdm.auto import tqdm - from .aggregates import ( _AggregatesCursor, _AggregateStore, @@ -82,6 +72,7 @@ add_cwd_to_environment_pythonpath, roundrobin, switch_to_directory, + tqdm, ) from .util.translate import abbreviate, shorten diff --git a/flow/render_status.py b/flow/render_status.py index fa4e582b8..657808d0f 100644 --- a/flow/render_status.py +++ b/flow/render_status.py @@ -2,18 +2,9 @@ # All rights reserved. # This software is licensed under the BSD 3-Clause License. """Status rendering logic.""" -try: - # If ipywidgets is installed, use "auto" tqdm to improve notebook support. - # Otherwise, use only text-based progress bars. This workaround can be - # removed after https://github.com/tqdm/tqdm/pull/1218. - import ipywidgets # noqa: F401 -except ImportError: - from tqdm import tqdm -else: - from tqdm.auto import tqdm - from .scheduling.base import JobStatus from .util import mistune +from .util.misc import tqdm def _render_status( diff --git a/flow/util/misc.py b/flow/util/misc.py index 0190d4c83..12b295e7f 100644 --- a/flow/util/misc.py +++ b/flow/util/misc.py @@ -14,6 +14,16 @@ from tqdm.contrib import tmap from tqdm.contrib.concurrent import process_map, thread_map +try: + # If ipywidgets is installed, use "auto" tqdm to improve notebook support. + # Otherwise, use only text-based progress bars. This workaround can be + # removed after https://github.com/tqdm/tqdm/pull/1218. + import ipywidgets # noqa: F401 +except ImportError: + from tqdm import tqdm +else: + from tqdm.auto import tqdm + def _positive_int(value): """Parse a command line argument as a positive integer. @@ -348,7 +358,10 @@ def _get_parallel_executor(parallelization="none"): """ if parallelization == "thread": - parallel_executor = thread_map + + def parallel_executor(func, iterable, **kwargs): + return thread_map(func, iterable, tqdm_class=tqdm, **kwargs) + elif parallelization == "process": def parallel_executor(func, iterable, **kwargs): @@ -365,6 +378,7 @@ def parallel_executor(func, iterable, **kwargs): # regardless of whether it is a local function. partial(_run_cloudpickled_func, cloudpickle.dumps(func)), map(cloudpickle.dumps, iterable), + tqdm_class=tqdm, **kwargs, ) @@ -374,6 +388,6 @@ def parallel_executor(func, iterable, **kwargs): if "chunksize" in kwargs: # Chunk size only applies to thread/process parallel executors del kwargs["chunksize"] - return list(tmap(func, iterable, **kwargs)) + return list(tmap(func, iterable, tqdm_class=tqdm, **kwargs)) return parallel_executor