diff --git a/bugbear.py b/bugbear.py index 4b3f981..c512fd3 100644 --- a/bugbear.py +++ b/bugbear.py @@ -592,7 +592,7 @@ def check_for_b023(self, loop_node): # noqa: C901 # check for filter&reduce if ( isinstance(node.func, ast.Name) - and node.func.id in ("filter", "reduce") + and node.func.id in ("filter", "reduce", "map") ) or ( isinstance(node.func, ast.Attribute) and node.func.attr == "reduce" @@ -610,11 +610,14 @@ def check_for_b023(self, loop_node): # noqa: C901 ): safe_functions.append(keyword.value) + # mark `return lambda: x` as safe + # does not (currently) check inner lambdas in a returned expression + # e.g. `return (lambda: x, ) if isinstance(node, ast.Return): if isinstance(node.value, FUNCTION_NODES): safe_functions.append(node.value) - # TODO: ast.walk(node) and mark all child nodes safe? + # find unsafe functions if isinstance(node, FUNCTION_NODES) and node not in safe_functions: argnames = { arg.arg for arg in ast.walk(node.args) if isinstance(arg, ast.arg) diff --git a/tests/b023.py b/tests/b023.py index a7554a1..61d3c10 100644 --- a/tests/b023.py +++ b/tests/b023.py @@ -121,6 +121,14 @@ def myfunc(x): max(range(3), key=lambda y: x * y) sorted(range(3), key=lambda y: x * y) + any(map(lambda y: x < y, range(3))) + all(map(lambda y: x < y, range(3))) + set(map(lambda y: x < y, range(3))) + list(map(lambda y: x < y, range(3))) + tuple(map(lambda y: x < y, range(3))) + sorted(map(lambda y: x < y, range(3))) + frozenset(map(lambda y: x < y, range(3))) + any(filter(lambda y: x < y, range(3))) all(filter(lambda y: x < y, range(3))) set(filter(lambda y: x < y, range(3))) @@ -156,7 +164,7 @@ def iter_f(names): return lambda: name if exists(name) else None if foo(name): - return [lambda: name] # false alarm, should be fixed? + return [lambda: name] # known false alarm if False: return [lambda: i for i in range(3)] # error diff --git a/tests/test_bugbear.py b/tests/test_bugbear.py index d2db0e0..7773221 100644 --- a/tests/test_bugbear.py +++ b/tests/test_bugbear.py @@ -356,8 +356,8 @@ def test_b023(self): B023(115, 36, vars=("x",)), B023(116, 37, vars=("x",)), B023(117, 36, vars=("x",)), - B023(159, 28, vars=("name",)), # false alarm? - B023(162, 28, vars=("i",)), + B023(167, 28, vars=("name",)), # known false alarm + B023(170, 28, vars=("i",)), ) self.assertEqual(errors, expected)