diff --git a/changelog/3457.trivial.rst b/changelog/3457.trivial.rst deleted file mode 100644 index f1888763440..00000000000 --- a/changelog/3457.trivial.rst +++ /dev/null @@ -1 +0,0 @@ -pytest now also depends on the `astor `__ package. diff --git a/setup.py b/setup.py index 7d953281611..4c87c6429bb 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,6 @@ "pluggy>=0.12,<1.0", "importlib-metadata>=0.12", "wcwidth", - "astor", ] diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 2a82e9c977c..8b2c1e14610 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -1,16 +1,18 @@ """Rewrite assertion AST to produce nice error messages""" import ast import errno +import functools import importlib.machinery import importlib.util +import io import itertools import marshal import os import struct import sys +import tokenize import types -import astor import atomicwrites from _pytest._io.saferepr import saferepr @@ -285,7 +287,7 @@ def _rewrite_test(fn, config): with open(fn, "rb") as f: source = f.read() tree = ast.parse(source, filename=fn) - rewrite_asserts(tree, fn, config) + rewrite_asserts(tree, source, fn, config) co = compile(tree, fn, "exec", dont_inherit=True) return stat, co @@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None): return co -def rewrite_asserts(mod, module_path=None, config=None): +def rewrite_asserts(mod, source, module_path=None, config=None): """Rewrite the assert statements in mod.""" - AssertionRewriter(module_path, config).run(mod) + AssertionRewriter(module_path, config, source).run(mod) def _saferepr(obj): @@ -457,6 +459,59 @@ def _fix(node, lineno, col_offset): return node +def _get_assertion_exprs(src: bytes): # -> Dict[int, str] + """Returns a mapping from {lineno: "assertion test expression"}""" + ret = {} + + depth = 0 + lines = [] + assert_lineno = None + seen_lines = set() + + def _write_and_reset() -> None: + nonlocal depth, lines, assert_lineno, seen_lines + ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\") + depth = 0 + lines = [] + assert_lineno = None + seen_lines = set() + + tokens = tokenize.tokenize(io.BytesIO(src).readline) + for tp, src, (lineno, offset), _, line in tokens: + if tp == tokenize.NAME and src == "assert": + assert_lineno = lineno + elif assert_lineno is not None: + # keep track of depth for the assert-message `,` lookup + if tp == tokenize.OP and src in "([{": + depth += 1 + elif tp == tokenize.OP and src in ")]}": + depth -= 1 + + if not lines: + lines.append(line[offset:]) + seen_lines.add(lineno) + # a non-nested comma separates the expression from the message + elif depth == 0 and tp == tokenize.OP and src == ",": + # one line assert with message + if lineno in seen_lines and len(lines) == 1: + offset_in_trimmed = offset + len(lines[-1]) - len(line) + lines[-1] = lines[-1][:offset_in_trimmed] + # multi-line assert with message + elif lineno in seen_lines: + lines[-1] = lines[-1][:offset] + # multi line assert with escapd newline before message + else: + lines.append(line[:offset]) + _write_and_reset() + elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}: + _write_and_reset() + elif lines and lineno not in seen_lines: + lines.append(line) + seen_lines.add(lineno) + + return ret + + class AssertionRewriter(ast.NodeVisitor): """Assertion rewriting implementation. @@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor): """ - def __init__(self, module_path, config): + def __init__(self, module_path, config, source): super().__init__() self.module_path = module_path self.config = config @@ -521,6 +576,11 @@ def __init__(self, module_path, config): ) else: self.enable_assertion_pass_hook = False + self.source = source + + @functools.lru_cache(maxsize=1) + def _assert_expr_to_lineno(self): + return _get_assertion_exprs(self.source) def run(self, mod): """Find all assert statements in *mod* and rewrite them.""" @@ -738,7 +798,7 @@ def visit_Assert(self, assert_): # Passed fmt_pass = self.helper("_format_explanation", msg) - orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")") + orig = self._assert_expr_to_lineno()[assert_.lineno] hook_call_pass = ast.Expr( self.helper( "_call_assertion_pass", diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 8d1c7a5f000..b8242b37d1c 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -13,6 +13,7 @@ import _pytest._code import pytest from _pytest.assertion import util +from _pytest.assertion.rewrite import _get_assertion_exprs from _pytest.assertion.rewrite import AssertionRewritingHook from _pytest.assertion.rewrite import PYTEST_TAG from _pytest.assertion.rewrite import rewrite_asserts @@ -31,7 +32,7 @@ def teardown_module(mod): def rewrite(src): tree = ast.parse(src) - rewrite_asserts(tree) + rewrite_asserts(tree, src.encode()) return tree @@ -1292,10 +1293,10 @@ def test_pattern_contains_subdirectories(self, testdir, hook): """ p = testdir.makepyfile( **{ - "tests/file.py": """ - def test_simple_failure(): - assert 1 + 1 == 3 - """ + "tests/file.py": """\ + def test_simple_failure(): + assert 1 + 1 == 3 + """ } ) testdir.syspathinsert(p.dirpath()) @@ -1315,19 +1316,19 @@ def test_cwd_changed(self, testdir, monkeypatch): testdir.makepyfile( **{ - "test_setup_nonexisting_cwd.py": """ - import os - import shutil - import tempfile - - d = tempfile.mkdtemp() - os.chdir(d) - shutil.rmtree(d) - """, - "test_test.py": """ - def test(): - pass - """, + "test_setup_nonexisting_cwd.py": """\ + import os + import shutil + import tempfile + + d = tempfile.mkdtemp() + os.chdir(d) + shutil.rmtree(d) + """, + "test_test.py": """\ + def test(): + pass + """, } ) result = testdir.runpytest() @@ -1339,23 +1340,22 @@ def test_option_default(self, testdir): config = testdir.parseconfig() assert config.getini("enable_assertion_pass_hook") is False - def test_hook_call(self, testdir): + @pytest.fixture + def flag_on(self, testdir): + testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n") + + @pytest.fixture + def hook_on(self, testdir): testdir.makeconftest( - """ + """\ def pytest_assertion_pass(item, lineno, orig, expl): raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) """ ) - testdir.makeini( - """ - [pytest] - enable_assertion_pass_hook = True - """ - ) - + def test_hook_call(self, testdir, flag_on, hook_on): testdir.makepyfile( - """ + """\ def test_simple(): a=1 b=2 @@ -1371,10 +1371,21 @@ def test_fails(): ) result = testdir.runpytest() result.stdout.fnmatch_lines( - "*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*" + "*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*" + ) + + def test_hook_call_with_parens(self, testdir, flag_on, hook_on): + testdir.makepyfile( + """\ + def f(): return 1 + def test(): + assert f() + """ ) + result = testdir.runpytest() + result.stdout.fnmatch_lines("*Assertion Passed: f() 1") - def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch): + def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on): """Assertion pass should not be called (and hence formatting should not occur) if there is no hook declared for pytest_assertion_pass""" @@ -1385,15 +1396,8 @@ def raise_on_assertionpass(*_, **__): _pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass ) - testdir.makeini( - """ - [pytest] - enable_assertion_pass_hook = True - """ - ) - testdir.makepyfile( - """ + """\ def test_simple(): a=1 b=2 @@ -1418,21 +1422,14 @@ def raise_on_assertionpass(*_, **__): ) testdir.makeconftest( - """ + """\ def pytest_assertion_pass(item, lineno, orig, expl): raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno)) """ ) - testdir.makeini( - """ - [pytest] - enable_assertion_pass_hook = False - """ - ) - testdir.makepyfile( - """ + """\ def test_simple(): a=1 b=2 @@ -1444,3 +1441,90 @@ def test_simple(): ) result = testdir.runpytest() result.assert_outcomes(passed=1) + + +@pytest.mark.parametrize( + ("src", "expected"), + ( + # fmt: off + pytest.param(b"", {}, id="trivial"), + pytest.param( + b"def x(): assert 1\n", + {1: "1"}, + id="assert statement not on own line", + ), + pytest.param( + b"def x():\n" + b" assert 1\n" + b" assert 1+2\n", + {2: "1", 3: "1+2"}, + id="multiple assertions", + ), + pytest.param( + # changes in encoding cause the byte offsets to be different + "# -*- coding: latin1\n" + "def ÀÀÀÀÀ(): assert 1\n".encode("latin1"), + {2: "1"}, + id="latin1 encoded on first line\n", + ), + pytest.param( + # using the default utf-8 encoding + "def ÀÀÀÀÀ(): assert 1\n".encode(), + {1: "1"}, + id="utf-8 encoded on first line", + ), + pytest.param( + b"def x():\n" + b" assert (\n" + b" 1 + 2 # comment\n" + b" )\n", + {2: "(\n 1 + 2 # comment\n )"}, + id="multi-line assertion", + ), + pytest.param( + b"def x():\n" + b" assert y == [\n" + b" 1, 2, 3\n" + b" ]\n", + {2: "y == [\n 1, 2, 3\n ]"}, + id="multi line assert with list continuation", + ), + pytest.param( + b"def x():\n" + b" assert 1 + \\\n" + b" 2\n", + {2: "1 + \\\n 2"}, + id="backslash continuation", + ), + pytest.param( + b"def x():\n" + b" assert x, y\n", + {2: "x"}, + id="assertion with message", + ), + pytest.param( + b"def x():\n" + b" assert (\n" + b" f(1, 2, 3)\n" + b" ), 'f did not work!'\n", + {2: "(\n f(1, 2, 3)\n )"}, + id="assertion with message, test spanning multiple lines", + ), + pytest.param( + b"def x():\n" + b" assert \\\n" + b" x\\\n" + b" , 'failure message'\n", + {2: "x"}, + id="escaped newlines plus message", + ), + pytest.param( + b"def x(): assert 5", + {1: "5"}, + id="no newline at end of file", + ), + # fmt: on + ), +) +def test_get_assertion_exprs(src, expected): + assert _get_assertion_exprs(src) == expected