Skip to content

Commit

Permalink
Merge pull request #646 from more-itertools/rename-batched
Browse files Browse the repository at this point in the history
Rename `batched` to `constrained_batches`, add `batched` recipe
  • Loading branch information
bbayles committed Oct 8, 2022
2 parents 9a1168a + 5829849 commit d63ed69
Show file tree
Hide file tree
Showing 7 changed files with 58 additions and 12 deletions.
3 changes: 2 additions & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ These tools yield groups of items from a source iterable.
.. autofunction:: ichunked
.. autofunction:: chunked_even
.. autofunction:: sliced
.. autofunction:: batched(iterable, max_size, max_count=None, get_len=len, strict=True)
.. autofunction:: constrained_batches(iterable, max_size, max_count=None, get_len=len, strict=True)
.. autofunction:: distribute
.. autofunction:: divide
.. autofunction:: split_at
Expand All @@ -32,6 +32,7 @@ These tools yield groups of items from a source iterable.

**Itertools recipes**

.. autofunction:: batched
.. autofunction:: grouper
.. autofunction:: partition

Expand Down
10 changes: 6 additions & 4 deletions more_itertools/more.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@
'all_unique',
'always_iterable',
'always_reversible',
'batched',
'bucket',
'callback_iter',
'chunked',
Expand All @@ -56,6 +55,7 @@
'collate',
'combination_index',
'consecutive_groups',
'constrained_batches',
'consumer',
'count_cycle',
'countable',
Expand Down Expand Up @@ -4336,19 +4336,21 @@ def minmax(iterable_or_value, *others, key=None, default=_marker):
return lo, hi


def batched(iterable, max_size, max_count=None, get_len=len, strict=True):
def constrained_batches(
iterable, max_size, max_count=None, get_len=len, strict=True
):
"""Yield batches of items from *iterable* with a combined size limited by
*max_size*.
>>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
>>> list(batched(iterable, 10))
>>> list(constrained_batches(iterable, 10))
[(b'12345', b'123'), (b'12345678', b'1', b'1'), (b'12', b'1')]
If a *max_count* is supplied, the number of items per batch is also
limited:
>>> iterable = [b'12345', b'123', b'12345678', b'1', b'1', b'12', b'1']
>>> list(batched(iterable, 10, max_count = 2))
>>> list(constrained_batches(iterable, 10, max_count = 2))
[(b'12345', b'123'), (b'12345678', b'1'), (b'1', b'12'), (b'1',)]
If a *get_len* function is supplied, use that instead of :func:`len` to
Expand Down
2 changes: 1 addition & 1 deletion more_itertools/more.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -666,7 +666,7 @@ def longest_common_prefix(
iterables: Iterable[Iterable[_T]],
) -> Iterator[_T]: ...
def iequals(*iterables: Iterable[object]) -> bool: ...
def batched(
def constrained_batches(
iterable: Iterable[object],
max_size: int,
max_count: Optional[int] = ...,
Expand Down
18 changes: 18 additions & 0 deletions more_itertools/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

__all__ = [
'all_equal',
'batched',
'before_and_after',
'consume',
'convolve',
Expand Down Expand Up @@ -826,3 +827,20 @@ def sieve(n):
data[p + p : n : p] = bytearray(len(range(p + p, n, p)))

return compress(count(), data)


def batched(iterable, n):
"""Batch data into lists of length *n*. The last batch may be shorter.
>>> list(batched('ABCDEFG', 3))
[['A', 'B', 'C'], ['D', 'E', 'F'], ['G']]
This recipe is from the ``itertools`` docs. This library also provides
:func:`chunked`, which has a different implementation.
"""
it = iter(iterable)
while True:
batch = list(islice(it, n))
if not batch:
break
yield batch
4 changes: 4 additions & 0 deletions more_itertools/recipes.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,7 @@ def sliding_window(
def subslices(iterable: Iterable[_T]) -> Iterator[List[_T]]: ...
def polynomial_from_roots(roots: Sequence[int]) -> List[int]: ...
def sieve(n: int) -> Iterator[int]: ...
def batched(
iterable: Iterable[_T],
n: int,
) -> Iterator[List[_T]]: ...
16 changes: 10 additions & 6 deletions tests/test_more.py
Original file line number Diff line number Diff line change
Expand Up @@ -5130,7 +5130,7 @@ def test_not_identical_but_equal(self):
self.assertTrue([1, True], [1.0, complex(1, 0)])


class BatchedTests(TestCase):
class ConstrainedBatchesTests(TestCase):
def test_basic(self):
zen = [
'Beautiful is better than ugly',
Expand Down Expand Up @@ -5185,24 +5185,24 @@ def test_basic(self):
),
):
with self.subTest(size=size):
actual = list(mi.batched(iter(zen), size))
actual = list(mi.constrained_batches(iter(zen), size))
self.assertEqual(actual, expected)

def test_max_count(self):
iterable = ['1', '1', '12345678', '12345', '12345']
max_size = 10
max_count = 2
actual = list(mi.batched(iterable, max_size, max_count))
actual = list(mi.constrained_batches(iterable, max_size, max_count))
expected = [('1', '1'), ('12345678',), ('12345', '12345')]
self.assertEqual(actual, expected)

def test_strict(self):
iterable = ['1', '123456789', '1']
size = 8
with self.assertRaises(ValueError):
list(mi.batched(iterable, size))
list(mi.constrained_batches(iterable, size))

actual = list(mi.batched(iterable, size, strict=False))
actual = list(mi.constrained_batches(iterable, size, strict=False))
expected = [('1',), ('123456789',), ('1',)]
self.assertEqual(actual, expected)

Expand All @@ -5218,6 +5218,10 @@ def total_size(self):
iterable = [record_3, record_5, record_10, record_2]

self.assertEqual(
list(mi.batched(iterable, 10, get_len=lambda x: x.total_size())),
list(
mi.constrained_batches(
iterable, 10, get_len=lambda x: x.total_size()
)
),
[(record_3, record_5), (record_10,), (record_2,)],
)
17 changes: 17 additions & 0 deletions tests/test_recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -903,3 +903,20 @@ def test_small_numbers(self):
for n in (0, 1, 2):
with self.subTest(n=n):
self.assertEqual(list(mi.sieve(n)), [])


class BatchedTests(TestCase):
def test_basic(self):
iterable = range(1, 5 + 1)
for n, expected in (
(0, []),
(1, [[1], [2], [3], [4], [5]]),
(2, [[1, 2], [3, 4], [5]]),
(3, [[1, 2, 3], [4, 5]]),
(4, [[1, 2, 3, 4], [5]]),
(5, [[1, 2, 3, 4, 5]]),
(6, [[1, 2, 3, 4, 5]]),
):
with self.subTest(n=n):
actual = list(mi.batched(iterable, n))
self.assertEqual(actual, expected)

0 comments on commit d63ed69

Please sign in to comment.