Skip to content

Commit

Permalink
Merge pull request #1148 from grst/master
Browse files Browse the repository at this point in the history
  • Loading branch information
casperdcl committed Apr 5, 2021
2 parents ca11f27 + c633fc5 commit 9fb0f23
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
34 changes: 22 additions & 12 deletions tests/tests_contrib.py
Expand Up @@ -3,6 +3,9 @@
"""
import sys

import pytest

from tqdm import tqdm
from tqdm.contrib import tenumerate, tmap, tzip

from .tests_tqdm import StringIO, closing, importorskip
Expand All @@ -13,49 +16,56 @@ def incr(x):
return x + 1


def test_enumerate():
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_enumerate(tqdm_kwargs):
"""Test contrib.tenumerate"""
with closing(StringIO()) as our_file:
a = range(9)
assert list(tenumerate(a, file=our_file)) == list(enumerate(a))
assert list(tenumerate(a, 42, file=our_file)) == list(enumerate(a, 42))
assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a))
assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list(
enumerate(a, 42)
)
with closing(StringIO()) as our_file:
_ = list(tenumerate((i for i in a), file=our_file))
_ = list(tenumerate((i for i in a), file=our_file, **tqdm_kwargs))
assert "100%" not in our_file.getvalue()
with closing(StringIO()) as our_file:
_ = list(tenumerate((i for i in a), file=our_file, total=len(a)))
_ = list(tenumerate((i for i in a), file=our_file, total=len(a), **tqdm_kwargs))
assert "100%" in our_file.getvalue()


def test_enumerate_numpy():
"""Test contrib.tenumerate(numpy.ndarray)"""
np = importorskip('numpy')
np = importorskip("numpy")
with closing(StringIO()) as our_file:
a = np.random.random((42, 7))
assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a))


def test_zip():
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_zip(tqdm_kwargs):
"""Test contrib.tzip"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
if sys.version_info[:1] < (3,):
assert tzip(a, b, file=our_file) == zip(a, b)
assert tzip(a, b, file=our_file, **tqdm_kwargs) == zip(a, b)
else:
gen = tzip(a, b, file=our_file)
gen = tzip(a, b, file=our_file, **tqdm_kwargs)
assert gen != list(zip(a, b))
assert list(gen) == list(zip(a, b))


def test_map():
@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}])
def test_map(tqdm_kwargs):
"""Test contrib.tmap"""
with closing(StringIO()) as our_file:
a = range(9)
b = [i + 1 for i in a]
if sys.version_info[:1] < (3,):
assert tmap(lambda x: x + 1, a, file=our_file) == map(incr, a)
assert tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) == map(
incr, a
)
else:
gen = tmap(lambda x: x + 1, a, file=our_file)
gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs)
assert gen != b
assert list(gen) == b
3 changes: 2 additions & 1 deletion tqdm/contrib/__init__.py
Expand Up @@ -16,6 +16,7 @@

class DummyTqdmFile(ObjectWrapper):
"""Dummy file-like that will write to tqdm"""

def __init__(self, wrapped):
super(DummyTqdmFile, self).__init__(wrapped)
self._buf = []
Expand Down Expand Up @@ -80,7 +81,7 @@ def tzip(iter1, *iter2plus, **tqdm_kwargs):
"""
kwargs = tqdm_kwargs.copy()
tqdm_class = kwargs.pop("tqdm_class", tqdm_auto)
for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus):
for i in zip(tqdm_class(iter1, **kwargs), *iter2plus):
yield i


Expand Down

0 comments on commit 9fb0f23

Please sign in to comment.