From 9e488c385ffc589d83d924ddb6f78acfe9984c85 Mon Sep 17 00:00:00 2001 From: Bruno Oliveira Date: Sun, 2 Jun 2019 11:02:22 -0300 Subject: [PATCH] Fix all() unroll for non-generators/non-list comprehensions Fix #5358 --- changelog/5358.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 14 ++++++++++---- testing/test_assertrewrite.py | 15 +++++++++++++-- 3 files changed, 24 insertions(+), 6 deletions(-) create mode 100644 changelog/5358.bugfix.rst diff --git a/changelog/5358.bugfix.rst b/changelog/5358.bugfix.rst new file mode 100644 index 00000000000..181da1e0ec2 --- /dev/null +++ b/changelog/5358.bugfix.rst @@ -0,0 +1 @@ +Fix assertion rewriting of ``all()`` calls to deal with non-generators. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 5e2c5397bb8..790760e640e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -949,11 +949,19 @@ def visit_BinOp(self, binop): res = self.assign(ast.BinOp(left_expr, binop.op, right_expr)) return res, explanation + def _is_any_call_with_generator_or_list_comprehension(self, call): + """Return True if the Call node is an 'any' call with a generator or list comprehension""" + return ( + isinstance(call.func, ast.Name) + and call.func.id == "all" + and isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)) + ) + def visit_Call_35(self, call): """ visit `ast.Call` nodes on Python3.5 and after """ - if isinstance(call.func, ast.Name) and call.func.id == "all": + if self._is_any_call_with_generator_or_list_comprehension(call): return self._visit_all(call) new_func, func_expl = self.visit(call.func) arg_expls = [] @@ -980,8 +988,6 @@ def visit_Call_35(self, call): def _visit_all(self, call): """Special rewrite for the builtin all function, see #5062""" - if not isinstance(call.args[0], (ast.GeneratorExp, ast.ListComp)): - return gen_exp = call.args[0] assertion_module = ast.Module( body=[ast.Assert(test=gen_exp.elt, lineno=1, msg="", col_offset=1)] @@ -1009,7 +1015,7 @@ def visit_Call_legacy(self, call): """ visit `ast.Call nodes on 3.4 and below` """ - if isinstance(call.func, ast.Name) and call.func.id == "all": + if self._is_any_call_with_generator_or_list_comprehension(call): return self._visit_all(call) new_func, func_expl = self.visit(call.func) arg_expls = [] diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 19d050f8769..ce8663345e7 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -677,7 +677,7 @@ def __repr__(self): assert "UnicodeDecodeError" not in msg assert "UnicodeEncodeError" not in msg - def test_unroll_generator(self, testdir): + def test_unroll_all_generator(self, testdir): testdir.makepyfile( """ def check_even(num): @@ -692,7 +692,7 @@ def test_generator(): result = testdir.runpytest() result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) - def test_unroll_list_comprehension(self, testdir): + def test_unroll_all_list_comprehension(self, testdir): testdir.makepyfile( """ def check_even(num): @@ -707,6 +707,17 @@ def test_list_comprehension(): result = testdir.runpytest() result.stdout.fnmatch_lines(["*assert False*", "*where False = check_even(1)*"]) + def test_unroll_all_object(self, testdir): + """all() for non generators/non list-comprehensions (#5358)""" + testdir.makepyfile( + """ + def test(): + assert all((1, 0)) + """ + ) + result = testdir.runpytest() + result.stdout.fnmatch_lines(["*assert False*", "*where False = all((1, 0))*"]) + def test_for_loop(self, testdir): testdir.makepyfile( """