diff --git a/bugbear.py b/bugbear.py index 7422ce4..4b3f981 100644 --- a/bugbear.py +++ b/bugbear.py @@ -280,7 +280,7 @@ def visit_ExceptHandler(self, node): names = [_to_name_str(e) for e in node.type.elts] as_ = " as " + node.name if node.name is not None else "" if len(names) == 0: - vs = ("`except (){}:`".format(as_),) + vs = (f"`except (){as_}:`",) self.errors.append(B001(node.lineno, node.col_offset, vars=vs)) elif len(names) == 1: self.errors.append(B013(node.lineno, node.col_offset, vars=names)) @@ -568,7 +568,7 @@ def check_for_b020(self, node): n = targets.names[name][0] self.errors.append(B020(n.lineno, n.col_offset, vars=(name,))) - def check_for_b023(self, loop_node): + def check_for_b023(self, loop_node): # noqa: C901 """Check that functions (including lambdas) do not use loop variables. https://docs.python-guide.org/writing/gotchas/#late-binding-closures from @@ -584,9 +584,38 @@ def check_for_b023(self, loop_node): # implement this "backwards": first we find all the candidate variable # uses, and then if there are any we check for assignment of those names # inside the loop body. + safe_functions = [] suspicious_variables = [] for node in ast.walk(loop_node): - if isinstance(node, FUNCTION_NODES): + # check if function is immediately consumed to avoid false alarm + if isinstance(node, ast.Call): + # check for filter&reduce + if ( + isinstance(node.func, ast.Name) + and node.func.id in ("filter", "reduce") + ) or ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "reduce" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "functools" + ): + for arg in node.args: + if isinstance(arg, FUNCTION_NODES): + safe_functions.append(arg) + + # check for key= + for keyword in node.keywords: + if keyword.arg == "key" and isinstance( + keyword.value, FUNCTION_NODES + ): + safe_functions.append(keyword.value) + + if isinstance(node, ast.Return): + if isinstance(node.value, FUNCTION_NODES): + safe_functions.append(node.value) + # TODO: ast.walk(node) and mark all child nodes safe? + + if isinstance(node, FUNCTION_NODES) and node not in safe_functions: argnames = { arg.arg for arg in ast.walk(node.args) if isinstance(arg, ast.arg) } @@ -594,16 +623,19 @@ def check_for_b023(self, loop_node): body_nodes = ast.walk(node.body) else: body_nodes = itertools.chain.from_iterable(map(ast.walk, node.body)) + errors = [] for name in body_nodes: - if ( - isinstance(name, ast.Name) - and name.id not in argnames - and isinstance(name.ctx, ast.Load) - ): - err = B023(name.lineno, name.col_offset, vars=(name.id,)) - if err not in self._b023_seen: - self._b023_seen.add(err) # dedupe across nested loops - suspicious_variables.append(err) + if isinstance(name, ast.Name) and name.id not in argnames: + if isinstance(name.ctx, ast.Load): + errors.append( + B023(name.lineno, name.col_offset, vars=(name.id,)) + ) + elif isinstance(name.ctx, ast.Store): + argnames.add(name.id) + for err in errors: + if err.vars[0] not in argnames and err not in self._b023_seen: + self._b023_seen.add(err) # dedupe across nested loops + suspicious_variables.append(err) if suspicious_variables: reassigned_in_loop = set(self._get_assigned_names(loop_node)) @@ -912,7 +944,7 @@ def check_for_b025(self, node): uniques.add(name) seen.extend(uniques) # sort to have a deterministic output - duplicates = sorted(set(x for x in seen if seen.count(x) > 1)) + duplicates = sorted({x for x in seen if seen.count(x) > 1}) for duplicate in duplicates: self.errors.append(B025(node.lineno, node.col_offset, vars=(duplicate,))) diff --git a/tests/b023.py b/tests/b023.py index 542e46e..a7554a1 100644 --- a/tests/b023.py +++ b/tests/b023.py @@ -2,10 +2,10 @@ Should emit: B023 - on lines 12, 13, 16, 28, 29, 30, 31, 40, 42, 50, 51, 52, 53, 61, 68. """ +from functools import reduce functions = [] z = 0 - for x in range(3): y = x + 1 # Subject to late-binding problems @@ -25,10 +25,10 @@ def f_ok_1(x): def check_inside_functions_too(): - ls = [lambda: x for x in range(2)] - st = {lambda: x for x in range(2)} - gn = (lambda: x for x in range(2)) - dt = {x: lambda: x for x in range(2)} + ls = [lambda: x for x in range(2)] # error + st = {lambda: x for x in range(2)} # error + gn = (lambda: x for x in range(2)) # error + dt = {x: lambda: x for x in range(2)} # error async def pointless_async_iterable(): @@ -37,9 +37,9 @@ async def pointless_async_iterable(): async def container_for_problems(): async for x in pointless_async_iterable(): - functions.append(lambda: x) + functions.append(lambda: x) # error - [lambda: x async for x in pointless_async_iterable()] + [lambda: x async for x in pointless_async_iterable()] # error a = 10 @@ -47,10 +47,10 @@ async def container_for_problems(): while True: a = a_ = a - 1 b += 1 - functions.append(lambda: a) - functions.append(lambda: a_) - functions.append(lambda: b) - functions.append(lambda: c) # not a name error because of late binding! + functions.append(lambda: a) # error + functions.append(lambda: a_) # error + functions.append(lambda: b) # error + functions.append(lambda: c) # error, but not a name error due to late binding c: bool = a > 3 if not c: break @@ -58,7 +58,7 @@ async def container_for_problems(): # Nested loops should not duplicate reports for j in range(2): for k in range(3): - lambda: j * k + lambda: j * k # error for j, k, l in [(1, 2, 3)]: @@ -76,3 +76,87 @@ def f(): def explicit_capture(captured=var): return captured + + +# `query` is defined in the function, so also defining it in the loop should be OK. +for name in ["a", "b"]: + query = name + + def myfunc(x): + query = x + query_post = x + _ = query + _ = query_post + + query_post = name # in case iteration order matters + + +# Bug here because two dict comprehensions reference `name`, one of which is inside +# the lambda. This should be totally fine, of course. +_ = { + k: v + for k, v in reduce( + lambda data, event: merge_mappings( + [data, {name: f(caches, data, event) for name, f in xx}] + ), + events, + {name: getattr(group, name) for name in yy}, + ).items() + if k in backfill_fields +} + + +# OK to define lambdas if they're immediately consumed, typically as the `key=` +# argument or in a consumed `filter()` (even if a comprehension is better style) +for x in range(2): + # It's not a complete get-out-of-linting-free construct - these should fail: + min([None, lambda: x], key=repr) + sorted([None, lambda: x], key=repr) + any(filter(bool, [None, lambda: x])) + list(filter(bool, [None, lambda: x])) + all(reduce(bool, [None, lambda: x])) + + # But all these ones should be OK: + min(range(3), key=lambda y: x * y) + max(range(3), key=lambda y: x * y) + sorted(range(3), key=lambda y: x * y) + + any(filter(lambda y: x < y, range(3))) + all(filter(lambda y: x < y, range(3))) + set(filter(lambda y: x < y, range(3))) + list(filter(lambda y: x < y, range(3))) + tuple(filter(lambda y: x < y, range(3))) + sorted(filter(lambda y: x < y, range(3))) + frozenset(filter(lambda y: x < y, range(3))) + + any(reduce(lambda y: x | y, range(3))) + all(reduce(lambda y: x | y, range(3))) + set(reduce(lambda y: x | y, range(3))) + list(reduce(lambda y: x | y, range(3))) + tuple(reduce(lambda y: x | y, range(3))) + sorted(reduce(lambda y: x | y, range(3))) + frozenset(reduce(lambda y: x | y, range(3))) + + import functools + + any(functools.reduce(lambda y: x | y, range(3))) + all(functools.reduce(lambda y: x | y, range(3))) + set(functools.reduce(lambda y: x | y, range(3))) + list(functools.reduce(lambda y: x | y, range(3))) + tuple(functools.reduce(lambda y: x | y, range(3))) + sorted(functools.reduce(lambda y: x | y, range(3))) + frozenset(functools.reduce(lambda y: x | y, range(3))) + +# OK because the lambda which references a loop variable is defined in a `return` +# statement, and after we return the loop variable can't be redefined. +# In principle we could do something fancy with `break`, but it's not worth it. +def iter_f(names): + for name in names: + if exists(name): + return lambda: name if exists(name) else None + + if foo(name): + return [lambda: name] # false alarm, should be fixed? + + if False: + return [lambda: i for i in range(3)] # error diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index 69719ab..d2db0e0 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -351,6 +351,13 @@ def test_b023(self): B023(61, 16, vars=("j",)), B023(61, 20, vars=("k",)), B023(68, 9, vars=("l",)), + B023(113, 23, vars=("x",)), + B023(114, 26, vars=("x",)), + B023(115, 36, vars=("x",)), + B023(116, 37, vars=("x",)), + B023(117, 36, vars=("x",)), + B023(159, 28, vars=("name",)), # false alarm? + B023(162, 28, vars=("i",)), ) self.assertEqual(errors, expected)