Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding an asyncio.gather() replacement for tqdm #1136

Merged
merged 5 commits into from Mar 5, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
14 changes: 14 additions & 0 deletions tests/py37_asyncio.py
Expand Up @@ -10,6 +10,7 @@
tqdm = partial(tqdm_asyncio, miniters=0, mininterval=0)
trange = partial(tarange, miniters=0, mininterval=0)
as_completed = partial(tqdm_asyncio.as_completed, miniters=0, mininterval=0)
gather = partial(tqdm_asyncio.gather, miniters=0, mininterval=0)


def count(start=0, step=1):
Expand Down Expand Up @@ -112,3 +113,16 @@ async def test_as_completed(capsys, tol):
except AssertionError:
if retry == 2:
raise


async def double(i):
return i * 2


@mark.asyncio
async def test_gather(capsys):
"""Test asyncio gather"""
res = await gather(list(map(double, range(30))))
_, err = capsys.readouterr()
assert '30/30' in err
assert res == list(range(0, 30 * 2, 2))
15 changes: 14 additions & 1 deletion tqdm/asyncio.py
Expand Up @@ -17,7 +17,7 @@

class tqdm_asyncio(std_tqdm):
"""
Asynchronous-friendly version of tqdm (Python 3.5+).
Asynchronous-friendly version of tqdm (Python 3.6+).
"""
def __init__(self, iterable=None, *args, **kwargs):
super(tqdm_asyncio, self).__init__(iterable, *args, **kwargs)
Expand Down Expand Up @@ -63,6 +63,19 @@ def as_completed(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs)
yield from cls(asyncio.as_completed(fs, loop=loop, timeout=timeout),
total=total, **tqdm_kwargs)

@classmethod
async def gather(cls, fs, *, loop=None, timeout=None, total=None, **tqdm_kwargs):
"""
Wrapper for `asyncio.gather`.
"""
async def wrap_awaitable(i, f):
return i, await f

ifs = [wrap_awaitable(i, f) for i, f in enumerate(fs)]
res = [await f for f in cls.as_completed(ifs, loop=loop, timeout=timeout,
total=total, **tqdm_kwargs)]
return [i for _, i in sorted(res)]


def tarange(*args, **kwargs):
"""
Expand Down
6 changes: 3 additions & 3 deletions tqdm/auto.py
Expand Up @@ -4,7 +4,7 @@
Method resolution order:

- `tqdm.autonotebook` without import warnings
- `tqdm.asyncio` on Python3.5+
- `tqdm.asyncio` on Python3.6+
- `tqdm.std` base class

Usage:
Expand All @@ -22,10 +22,10 @@
from .autonotebook import tqdm as notebook_tqdm
from .autonotebook import trange as notebook_trange

if sys.version_info[:2] < (3, 5):
if sys.version_info[:2] < (3, 6):
tqdm = notebook_tqdm
trange = notebook_trange
else: # Python3.5+
else: # Python3.6+
from .asyncio import tqdm as asyncio_tqdm
from .std import tqdm as std_tqdm

Expand Down