Skip to content

Commit

Permalink
Add rule for list comprehensions passed to any()/all() (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
sbrugman committed Apr 13, 2023
1 parent b6c2b95 commit 867247f
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 3 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ Changelog

* Add rule C418 to check for calls passing a dict literal or dict comprehension to ``dict()``.

* Add rule C419 to check for calls passing a list comprehension to ``any()``/``all()``.

3.11.1 (2023-03-21)
-------------------

Expand Down
11 changes: 11 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -216,3 +216,14 @@ For example:

* Rewrite ``dict({})`` as ``{}``
* Rewrite ``dict({"a": 1})`` as ``{"a": 1}``

C419 Unnecessary list comprehension in ``<any/all>``\() prevents short-circuiting - rewrite as a generator.
-----------------------------------------------------------------------------------------------------------

Using a list comprehension inside a call to ``any()``/``all()`` prevents short-circuiting when a ``True`` / ``False`` value is found.
The whole list will be constructed before calling ``any()``/``all()``, potentially wasting work.part-way.
Rewrite to use a generator expression, which can stop part way.
For example:

* Rewrite ``all([condition(x) for x in iterable])`` as ``all(condition(x) for x in iterable)``
* Rewrite ``any([condition(x) for x in iterable])`` as ``any(condition(x) for x in iterable)``
16 changes: 13 additions & 3 deletions src/flake8_comprehensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def __init__(self, tree: ast.AST) -> None:
"C418 Unnecessary {type} passed to dict() - "
+ "remove the outer call to dict()."
),
"C419": (
"C419 Unnecessary list comprehension passed to {func}() prevents "
+ "short-circuiting - rewrite as a generator."
),
}

def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]:
Expand Down Expand Up @@ -93,13 +97,19 @@ def run(self) -> Generator[tuple[int, int, str, type[Any]], None, None]:
elif (
num_positional_args == 1
and isinstance(node.args[0], ast.ListComp)
and node.func.id in ("list", "set")
and node.func.id in ("list", "set", "any", "all")
):
msg_key = {"list": "C411", "set": "C403"}[node.func.id]
msg_key = {
"list": "C411",
"set": "C403",
"any": "C419",
"all": "C419",
}[node.func.id]
msg = self.messages[msg_key].format(func=node.func.id)
yield (
node.lineno,
node.col_offset,
self.messages[msg_key],
msg,
type(self),
)

Expand Down
38 changes: 38 additions & 0 deletions tests/test_flake8_comprehensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,3 +938,41 @@ def test_C418_fail(code, failures, flake8_path):
(flake8_path / "example.py").write_text(dedent(code))
result = flake8_path.run_flake8()
assert result.out_lines == failures


@pytest.mark.parametrize(
"code",
[
"any(num == 3 for num in range(5))",
"all(num == 3 for num in range(5))",
],
)
def test_C419_pass(code, flake8_path):
(flake8_path / "example.py").write_text(dedent(code))
result = flake8_path.run_flake8()
assert result.out_lines == []


@pytest.mark.parametrize(
"code,failures",
[
(
"any([num == 3 for num in range(5)])",
[
"./example.py:1:1: C419 Unnecessary list comprehension passed "
+ "to any() prevents short-circuiting - rewrite as a generator."
],
),
(
"all([num == 3 for num in range(5)])",
[
"./example.py:1:1: C419 Unnecessary list comprehension passed "
+ "to all() prevents short-circuiting - rewrite as a generator."
],
),
],
)
def test_C419_fail(code, failures, flake8_path):
(flake8_path / "example.py").write_text(dedent(code))
result = flake8_path.run_flake8()
assert result.out_lines == failures

0 comments on commit 867247f

Please sign in to comment.