diff --git a/more_itertools/recipes.py b/more_itertools/recipes.py index 76c6815d..3fb49b53 100644 --- a/more_itertools/recipes.py +++ b/more_itertools/recipes.py @@ -744,11 +744,12 @@ def true_iterator(): transition.append(elem) return - def remainder_iterator(): - yield from transition - yield from it + # Note: this is different from itertools recipes to allow nesting + # before_and_after remainders into before_and_after again. See tests + # for an example. + remainder_iterator = chain(transition, it) - return true_iterator(), remainder_iterator() + return true_iterator(), remainder_iterator def triplewise(iterable): diff --git a/tests/test_recipes.py b/tests/test_recipes.py index 6c3d9779..44e35d10 100644 --- a/tests/test_recipes.py +++ b/tests/test_recipes.py @@ -1,6 +1,7 @@ import warnings from doctest import DocTestSuite +from functools import reduce from itertools import combinations, count, permutations from math import factorial from unittest import TestCase @@ -783,6 +784,39 @@ def test_some_true(self): self.assertEqual(list(before), [1, True]) self.assertEqual(list(after), [0, False]) + @staticmethod + def _group_events(events): + events = iter(events) + + while True: + try: + operation = next(events) + except StopIteration: + break + assert operation in ["SUM", "MULTIPLY"] + + # Here, the remainder `events` is passed into `before_and_after` + # again, which would be problematic if the remainder is a + # generator function (as in Python 3.10 itertools recipes), since + # that creates recursion. `itertools.chain` solves this problem. + numbers, events = mi.before_and_after( + lambda e: isinstance(e, int), events + ) + + yield (operation, numbers) + + def test_nested_remainder(self): + events = ["SUM", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 1000 + events += ["MULTIPLY", 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] * 1000 + + for operation, numbers in self._group_events(events): + if operation == "SUM": + res = sum(numbers) + self.assertEqual(res, 55) + elif operation == "MULTIPLY": + res = reduce(lambda a, b: a * b, numbers) + self.assertEqual(res, 3628800) + class TriplewiseTests(TestCase): def test_basic(self):