From 69df59096281d0006b0861db79957e6d306c5301 Mon Sep 17 00:00:00 2001 From: Zac Hatfield-Dodds Date: Thu, 23 Jun 2022 00:45:49 -0700 Subject: [PATCH] Implement late-binding loop check --- README.rst | 3 ++ bugbear.py | 91 +++++++++++++++++++++++++++++++++++++++++++ tests/b023.py | 56 ++++++++++++++++++++++++++ tests/test_bugbear.py | 22 +++++++++++ 4 files changed, 172 insertions(+) create mode 100644 tests/b023.py diff --git a/README.rst b/README.rst index bde5850..600e96c 100644 --- a/README.rst +++ b/README.rst @@ -150,6 +150,9 @@ No exceptions will be suppressed and therefore this context manager is redundant N.B. this rule currently does not flag `suppress` calls to avoid potential false positives due to similarly named user-defined functions. +**B023**: Functions defined inside a loop must not use variables redefined in +the loop, because `late-binding closures are a classic gotcha +`__. Opinionated warnings ~~~~~~~~~~~~~~~~~~~~ diff --git a/bugbear.py b/bugbear.py index f1dcf9e..8e7ae0c 100644 --- a/bugbear.py +++ b/bugbear.py @@ -26,6 +26,7 @@ ast.DictComp, ast.GeneratorExp, ) +FUNCTION_NODES = (ast.AsyncFunctionDef, ast.FunctionDef, ast.Lambda) Context = namedtuple("Context", ["node", "stack"]) @@ -198,6 +199,21 @@ def _to_name_str(node): return _to_name_str(node.value) +def names_from_assignments(node): + for name in ast.walk(node): + if isinstance(name, ast.Name): + yield name.id + elif isinstance(name, ast.arg): + yield name.arg + + +def children_in_scope(node): + yield node + if not isinstance(node, FUNCTION_NODES): + for child in ast.iter_child_nodes(node): + yield from children_in_scope(child) + + def _typesafe_issubclass(cls, class_or_tuple): try: return issubclass(cls, class_or_tuple) @@ -348,6 +364,31 @@ def visit_Assign(self, node): def visit_For(self, node): self.check_for_b007(node) self.check_for_b020(node) + self.check_for_b023(node) + self.generic_visit(node) + + def visit_AsyncFor(self, node): + self.check_for_b023(node) + self.generic_visit(node) + + def visit_While(self, node): + self.check_for_b023(node) + self.generic_visit(node) + + def visit_ListComp(self, node): + self.check_for_b023(node) + self.generic_visit(node) + + def visit_SetComp(self, node): + self.check_for_b023(node) + self.generic_visit(node) + + def visit_DictComp(self, node): + self.check_for_b023(node) + self.generic_visit(node) + + def visit_GeneratorExp(self, node): + self.check_for_b023(node) self.generic_visit(node) def visit_Assert(self, node): @@ -520,6 +561,47 @@ def check_for_b020(self, node): n = targets.names[name][0] self.errors.append(B020(n.lineno, n.col_offset, vars=(name,))) + def check_for_b023(self, loop_node): + """Check that functions (including lambdas) do not use loop variables. + + https://docs.python-guide.org/writing/gotchas/#late-binding-closures from + functions - usually but not always lambdas - defined inside a loop are a + classic source of bugs. + + For each use of a variable inside a function defined inside a loop, we + emit a warning if that variable is reassigned on each loop iteration + (outside the function). This includes but is not limited to explicit + loop variables like the `x` in `for x in range(3):`. + """ + # Because most loops don't contain functions, it's most efficient to + # implement this "backwards": first we find all the candidate variable + # uses, and then if there are any we check for assignment of those names + # inside the loop body. + suspicious_variables = [] + for node in ast.walk(loop_node): + if isinstance(node, FUNCTION_NODES): + argnames = set(names_from_assignments(node.args)) + for name in ast.walk(node): + if isinstance(name, ast.Name) and name.id not in argnames: + err = B023(name.lineno, name.col_offset, vars=(name.id,)) + suspicious_variables.append(err) + + if suspicious_variables: + reassigned_in_loop = set(self._get_assigned_names(loop_node)) + + for err in sorted(suspicious_variables): + if reassigned_in_loop.issuperset(err.vars): + self.errors.append(err) + + def _get_assigned_names(self, loop_node): + loop_targets = (ast.For, ast.AsyncFor, ast.comprehension) + for node in children_in_scope(loop_node): + if isinstance(node, (ast.Assign)): + for child in node.targets: + yield from names_from_assignments(child) + if isinstance(node, loop_targets + (ast.AnnAssign, ast.AugAssign)): + yield from names_from_assignments(node.target) + def check_for_b904(self, node): """Checks `raise` without `from` inside an `except` clause. @@ -1041,6 +1123,15 @@ def visit_Lambda(self, node): ) ) +B023 = Error( + message=( + "B023 Function definition does not bind loop variable {!r}. " + "This means that invoking functions defined inside the loop will " + "always use the value of this variable from the last iteration. See " + "https://docs.python-guide.org/writing/gotchas/#late-binding-closures" + ) +) + # Warnings disabled by default. B901 = Error( message=( diff --git a/tests/b023.py b/tests/b023.py new file mode 100644 index 0000000..fe78e4e --- /dev/null +++ b/tests/b023.py @@ -0,0 +1,56 @@ +""" +Should emit: +B023 - on lines 12, 13, 16, 28, 29, 30, 31, 40, 42, 50, 51, 52, 53. +""" + +functions = [] +z = 0 + +for x in range(3): + y = x + 1 + # Subject to late-binding problems + functions.append(lambda: x) + functions.append(lambda: y) # not just the loop var + + def f_bad_1(): + return x + + # Actually OK + functions.append(lambda x: x * 2) + functions.append(lambda x=x: x) + functions.append(lambda: z) # OK because not assigned in the loop + + def f_ok_1(x): + return x * 2 + + +def check_inside_functions_too(): + ls = [lambda: x for x in range(2)] + st = {lambda: x for x in range(2)} + gn = (lambda: x for x in range(2)) + dt = {x: lambda: x for x in range(2)} + + +async def pointless_async_iterable(): + yield 1 + + +async def container_for_problems(): + async for x in pointless_async_iterable(): + functions.append(lambda: x) + + [lambda: x async for x in pointless_async_iterable()] + + +a = 10 +b = 0 +while True: + a = a_ = a - 1 + b += 1 + functions.append(lambda: a) + functions.append(lambda: a_) + functions.append(lambda: b) + functions.append(lambda: c) # not a name error because of late binding! + c: bool = a > 3 + if not c: + break diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 12469e6..31686c6 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -33,6 +33,7 @@ B020, B021, B022, + B023, B901, B902, B903, @@ -325,6 +326,27 @@ def test_b022(self): errors = list(bbc.run()) self.assertEqual(errors, self.errors(B022(8, 0))) + def test_b023(self): + filename = Path(__file__).absolute().parent / "b023.py" + bbc = BugBearChecker(filename=str(filename)) + errors = list(bbc.run()) + expected = self.errors( + B023(12, 29, vars=("x",)), + B023(13, 29, vars=("y",)), + B023(16, 15, vars=("x",)), + B023(28, 18, vars=("x",)), + B023(29, 18, vars=("x",)), + B023(30, 18, vars=("x",)), + B023(31, 21, vars=("x",)), + B023(40, 33, vars=("x",)), + B023(42, 13, vars=("x",)), + B023(50, 29, vars=("a",)), + B023(51, 29, vars=("a_",)), + B023(52, 29, vars=("b",)), + B023(53, 29, vars=("c",)), + ) + self.assertEqual(errors, expected) + def test_b901(self): filename = Path(__file__).absolute().parent / "b901.py" bbc = BugBearChecker(filename=str(filename))