diff --git a/tests/tests_contrib.py b/tests/tests_contrib.py index cf684eeae..3a8396251 100644 --- a/tests/tests_contrib.py +++ b/tests/tests_contrib.py @@ -3,6 +3,9 @@ """ import sys +import pytest + +from tqdm import tqdm from tqdm.contrib import tenumerate, tmap, tzip from .tests_tqdm import StringIO, closing, importorskip @@ -13,49 +16,56 @@ def incr(x): return x + 1 -def test_enumerate(): +@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}]) +def test_enumerate(tqdm_kwargs): """Test contrib.tenumerate""" with closing(StringIO()) as our_file: a = range(9) - assert list(tenumerate(a, file=our_file)) == list(enumerate(a)) - assert list(tenumerate(a, 42, file=our_file)) == list(enumerate(a, 42)) + assert list(tenumerate(a, file=our_file, **tqdm_kwargs)) == list(enumerate(a)) + assert list(tenumerate(a, 42, file=our_file, **tqdm_kwargs)) == list( + enumerate(a, 42) + ) with closing(StringIO()) as our_file: - _ = list(tenumerate((i for i in a), file=our_file)) + _ = list(tenumerate((i for i in a), file=our_file, **tqdm_kwargs)) assert "100%" not in our_file.getvalue() with closing(StringIO()) as our_file: - _ = list(tenumerate((i for i in a), file=our_file, total=len(a))) + _ = list(tenumerate((i for i in a), file=our_file, total=len(a), **tqdm_kwargs)) assert "100%" in our_file.getvalue() def test_enumerate_numpy(): """Test contrib.tenumerate(numpy.ndarray)""" - np = importorskip('numpy') + np = importorskip("numpy") with closing(StringIO()) as our_file: a = np.random.random((42, 7)) assert list(tenumerate(a, file=our_file)) == list(np.ndenumerate(a)) -def test_zip(): +@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}]) +def test_zip(tqdm_kwargs): """Test contrib.tzip""" with closing(StringIO()) as our_file: a = range(9) b = [i + 1 for i in a] if sys.version_info[:1] < (3,): - assert tzip(a, b, file=our_file) == zip(a, b) + assert tzip(a, b, file=our_file, **tqdm_kwargs) == zip(a, b) else: - gen = tzip(a, b, file=our_file) + gen = tzip(a, b, file=our_file, **tqdm_kwargs) assert gen != list(zip(a, b)) assert list(gen) == list(zip(a, b)) -def test_map(): +@pytest.mark.parametrize("tqdm_kwargs", [{}, {"tqdm_class": tqdm}]) +def test_map(tqdm_kwargs): """Test contrib.tmap""" with closing(StringIO()) as our_file: a = range(9) b = [i + 1 for i in a] if sys.version_info[:1] < (3,): - assert tmap(lambda x: x + 1, a, file=our_file) == map(incr, a) + assert tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) == map( + incr, a + ) else: - gen = tmap(lambda x: x + 1, a, file=our_file) + gen = tmap(lambda x: x + 1, a, file=our_file, **tqdm_kwargs) assert gen != b assert list(gen) == b diff --git a/tqdm/contrib/__init__.py b/tqdm/contrib/__init__.py index 8a6d2474c..ac1969890 100644 --- a/tqdm/contrib/__init__.py +++ b/tqdm/contrib/__init__.py @@ -16,6 +16,7 @@ class DummyTqdmFile(ObjectWrapper): """Dummy file-like that will write to tqdm""" + def __init__(self, wrapped): super(DummyTqdmFile, self).__init__(wrapped) self._buf = [] @@ -80,7 +81,7 @@ def tzip(iter1, *iter2plus, **tqdm_kwargs): """ kwargs = tqdm_kwargs.copy() tqdm_class = kwargs.pop("tqdm_class", tqdm_auto) - for i in zip(tqdm_class(iter1, **tqdm_kwargs), *iter2plus): + for i in zip(tqdm_class(iter1, **kwargs), *iter2plus): yield i