diff --git a/docs/api.rst b/docs/api.rst index d035b1c5..ee1ffb8a 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -174,6 +174,7 @@ These tools yield certain items from an iterable. .. autofunction:: last(iterable[, default]) .. autofunction:: one(iterable, too_short=ValueError, too_long=ValueError) .. autofunction:: only(iterable, default=None, too_long=ValueError) +.. autofunction:: strictly_n(iterable, too_short=func, too_long=func) .. autofunction:: strip .. autofunction:: lstrip .. autofunction:: rstrip diff --git a/more_itertools/more.py b/more_itertools/more.py index 685c1560..de5ed21c 100755 --- a/more_itertools/more.py +++ b/more_itertools/more.py @@ -112,6 +112,7 @@ 'spy', 'stagger', 'strip', + 'strictly_n', 'substrings', 'substrings_indexes', 'time_limited', @@ -591,41 +592,78 @@ def raise_(exception, *args): raise exception(*args) -def nth_exactly( +def strictly_n( iterable, - n=1, - default=None, - too_short=lambda expected, given: raise_( + n, + too_short=lambda items: raise_( ValueError, - 'Too few items in iterable (expected {expected}, but got {given}).', + f'Too few items in iterable (got {len(items)}).', ), - too_long=lambda expected, nth_value, after_value: raise_( + too_long=lambda items: raise_( ValueError, - 'Too many items in iterable ' - '(expected exactly {expected} items in iterable, ' - 'but got {after_value} after {nth_value} and perhaps more.', + f'Too many items in iterable (got at least {len(items)})', ), ): - """A more generalized version of `one`""" - it = iter(iterable) + """Validate that *iterable* has exactly *n* items and return them if + it does. If it has fewer than *n* items, call function *too_short* + with those items. If it has more than *n* items, call function + *too_long* with the first ``n + 1`` items. + + >>> iterable = ['a', 'b', 'c', 'd'] + >>> n = 4 + >>> strictly_n(iterable, n) + ['a', 'b', 'c', 'd'] - counter = count() - consume(zip(it, counter), n - 1) + By default, *too_short* and *too_long* are functions that raise + ``ValueError``. + + >>> strictly_n(['a', 'b'], 3) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too few items in iterable (got 2)' + + >>> strictly_n(['a', 'b', 'c'], 2) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError: too many items in iterable (got at least 3)' + You can supply your own functions for *too_short* and *too_long*. + The *too_short* function should accept a list of the items that were + in *iterable*. The *too_long* function should accept a list of the first + ``n + 1`` items that were in *iterable*. + + >>> def too_short(items): + ... average = sum(items) / len(items) + ... return items + [average] + >>> strictly_n([5.0, 5.0, 20.0], 4, too_short=too_short) + [5.0, 5.0, 20.0, 10.0] + + >>> def too_long(items): + ... average = sum(items) / len(items) + ... return [x + average for x in items[:-1]] + >>> strictly_n([5.0, 5.0, 20.0], 2, too_long=too_long) + [15.0, 15.0] + + """ + it = iter(iterable) + + items = take(n - 1, it) try: nth_value = next(it) except StopIteration: - too_short(expected=n, given=next(counter)) - return default + return too_short(items) + else: + items.append(nth_value) try: after_value = next(it) except StopIteration: pass else: - too_long(expected=n, nth_value=nth_value, after_value=after_value) + items.append(after_value) + return too_long(items) - return nth_value + return items def distinct_permutations(iterable, r=None): diff --git a/more_itertools/more.pyi b/more_itertools/more.pyi index 0975e2d2..97b96e9e 100644 --- a/more_itertools/more.pyi +++ b/more_itertools/more.pyi @@ -85,13 +85,12 @@ def one( too_long: Optional[_Raisable] = ..., ) -> _T: ... def raise_(exception: _Raisable, *args: Any) -> None: ... -def nth_exactly( +def strictly_n( iterable: Iterable[_T], n: int, - default: _U, - too_short: _GenFn, - too_long: _GenFn, -) -> Union[_T, _U]: ... + too_short: Optional[_Raisable] = ..., + too_long: Optional[_Raisable] = ..., +) -> List[_T]: ... def distinct_permutations( iterable: Iterable[_T], r: Optional[int] = ... ) -> Iterator[Tuple[_T, ...]]: ... diff --git a/tests/test_more.py b/tests/test_more.py index e59bf16c..4206dafc 100644 --- a/tests/test_more.py +++ b/tests/test_more.py @@ -4719,3 +4719,40 @@ def test_key(self): actual = list(mi.unique_in_window(iterable, n, key=key)) expected = [0, 3, 6, 9] self.assertEqual(actual, expected) + + +class StrictlyNTests(TestCase): + def test_basic(self): + iterable = ['a', 'b', 'c', 'd'] + n = 4 + actual = mi.strictly_n(iter(iterable), n) + expected = iterable + self.assertEqual(actual, expected) + + def test_too_short_default(self): + iterable = ['a', 'b', 'c', 'd'] + n = 5 + with self.assertRaises(ValueError): + mi.strictly_n(iter(iterable), n) + + def test_too_long_default(self): + iterable = ['a', 'b', 'c', 'd'] + n = 3 + with self.assertRaises(ValueError): + mi.strictly_n(iter(iterable), n) + + def test_too_short_custom(self): + iterable = ['a', 'b', 'c', 'd'] + n = 6 + too_short = lambda items: items + (['?'] * (n - len(items))) + actual = mi.strictly_n(iter(iterable), n, too_short=too_short) + expected = ['a', 'b', 'c', 'd', '?', '?'] + self.assertEqual(actual, expected) + + def test_too_long_custom(self): + iterable = ['a', 'b', 'c', 'd'] + n = 2 + too_long = lambda items: items[-n:] + actual = mi.strictly_n(iter(iterable), n, too_long=too_long) + expected = ['b', 'c'] + self.assertEqual(actual, expected)