Skip to content

Commit

Permalink
Remove astor and reproduce the original assertion expression (#5512)
Browse files Browse the repository at this point in the history
Remove astor and reproduce the original assertion expression
  • Loading branch information
nicoddemus committed Jun 28, 2019
2 parents 3c9b46f + 7ee2444 commit 73d918d
Show file tree
Hide file tree
Showing 4 changed files with 197 additions and 55 deletions.
1 change: 0 additions & 1 deletion changelog/3457.trivial.rst

This file was deleted.

1 change: 0 additions & 1 deletion setup.py
Expand Up @@ -13,7 +13,6 @@
"pluggy>=0.12,<1.0",
"importlib-metadata>=0.12",
"wcwidth",
"astor",
]


Expand Down
72 changes: 66 additions & 6 deletions src/_pytest/assertion/rewrite.py
@@ -1,16 +1,18 @@
"""Rewrite assertion AST to produce nice error messages"""
import ast
import errno
import functools
import importlib.machinery
import importlib.util
import io
import itertools
import marshal
import os
import struct
import sys
import tokenize
import types

import astor
import atomicwrites

from _pytest._io.saferepr import saferepr
Expand Down Expand Up @@ -285,7 +287,7 @@ def _rewrite_test(fn, config):
with open(fn, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn, config)
rewrite_asserts(tree, source, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co

Expand Down Expand Up @@ -327,9 +329,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co


def rewrite_asserts(mod, module_path=None, config=None):
def rewrite_asserts(mod, source, module_path=None, config=None):
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path, config).run(mod)
AssertionRewriter(module_path, config, source).run(mod)


def _saferepr(obj):
Expand Down Expand Up @@ -457,6 +459,59 @@ def _fix(node, lineno, col_offset):
return node


def _get_assertion_exprs(src: bytes): # -> Dict[int, str]
"""Returns a mapping from {lineno: "assertion test expression"}"""
ret = {}

depth = 0
lines = []
assert_lineno = None
seen_lines = set()

def _write_and_reset() -> None:
nonlocal depth, lines, assert_lineno, seen_lines
ret[assert_lineno] = "".join(lines).rstrip().rstrip("\\")
depth = 0
lines = []
assert_lineno = None
seen_lines = set()

tokens = tokenize.tokenize(io.BytesIO(src).readline)
for tp, src, (lineno, offset), _, line in tokens:
if tp == tokenize.NAME and src == "assert":
assert_lineno = lineno
elif assert_lineno is not None:
# keep track of depth for the assert-message `,` lookup
if tp == tokenize.OP and src in "([{":
depth += 1
elif tp == tokenize.OP and src in ")]}":
depth -= 1

if not lines:
lines.append(line[offset:])
seen_lines.add(lineno)
# a non-nested comma separates the expression from the message
elif depth == 0 and tp == tokenize.OP and src == ",":
# one line assert with message
if lineno in seen_lines and len(lines) == 1:
offset_in_trimmed = offset + len(lines[-1]) - len(line)
lines[-1] = lines[-1][:offset_in_trimmed]
# multi-line assert with message
elif lineno in seen_lines:
lines[-1] = lines[-1][:offset]
# multi line assert with escapd newline before message
else:
lines.append(line[:offset])
_write_and_reset()
elif tp in {tokenize.NEWLINE, tokenize.ENDMARKER}:
_write_and_reset()
elif lines and lineno not in seen_lines:
lines.append(line)
seen_lines.add(lineno)

return ret


class AssertionRewriter(ast.NodeVisitor):
"""Assertion rewriting implementation.
Expand Down Expand Up @@ -511,7 +566,7 @@ class AssertionRewriter(ast.NodeVisitor):
"""

def __init__(self, module_path, config):
def __init__(self, module_path, config, source):
super().__init__()
self.module_path = module_path
self.config = config
Expand All @@ -521,6 +576,11 @@ def __init__(self, module_path, config):
)
else:
self.enable_assertion_pass_hook = False
self.source = source

@functools.lru_cache(maxsize=1)
def _assert_expr_to_lineno(self):
return _get_assertion_exprs(self.source)

def run(self, mod):
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -738,7 +798,7 @@ def visit_Assert(self, assert_):

# Passed
fmt_pass = self.helper("_format_explanation", msg)
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
orig = self._assert_expr_to_lineno()[assert_.lineno]
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
Expand Down
178 changes: 131 additions & 47 deletions testing/test_assertrewrite.py
Expand Up @@ -13,6 +13,7 @@
import _pytest._code
import pytest
from _pytest.assertion import util
from _pytest.assertion.rewrite import _get_assertion_exprs
from _pytest.assertion.rewrite import AssertionRewritingHook
from _pytest.assertion.rewrite import PYTEST_TAG
from _pytest.assertion.rewrite import rewrite_asserts
Expand All @@ -31,7 +32,7 @@ def teardown_module(mod):

def rewrite(src):
tree = ast.parse(src)
rewrite_asserts(tree)
rewrite_asserts(tree, src.encode())
return tree


Expand Down Expand Up @@ -1292,10 +1293,10 @@ def test_pattern_contains_subdirectories(self, testdir, hook):
"""
p = testdir.makepyfile(
**{
"tests/file.py": """
def test_simple_failure():
assert 1 + 1 == 3
"""
"tests/file.py": """\
def test_simple_failure():
assert 1 + 1 == 3
"""
}
)
testdir.syspathinsert(p.dirpath())
Expand All @@ -1315,19 +1316,19 @@ def test_cwd_changed(self, testdir, monkeypatch):

testdir.makepyfile(
**{
"test_setup_nonexisting_cwd.py": """
import os
import shutil
import tempfile
d = tempfile.mkdtemp()
os.chdir(d)
shutil.rmtree(d)
""",
"test_test.py": """
def test():
pass
""",
"test_setup_nonexisting_cwd.py": """\
import os
import shutil
import tempfile
d = tempfile.mkdtemp()
os.chdir(d)
shutil.rmtree(d)
""",
"test_test.py": """\
def test():
pass
""",
}
)
result = testdir.runpytest()
Expand All @@ -1339,23 +1340,22 @@ def test_option_default(self, testdir):
config = testdir.parseconfig()
assert config.getini("enable_assertion_pass_hook") is False

def test_hook_call(self, testdir):
@pytest.fixture
def flag_on(self, testdir):
testdir.makeini("[pytest]\nenable_assertion_pass_hook = True\n")

@pytest.fixture
def hook_on(self, testdir):
testdir.makeconftest(
"""
"""\
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)

testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)

def test_hook_call(self, testdir, flag_on, hook_on):
testdir.makepyfile(
"""
"""\
def test_simple():
a=1
b=2
Expand All @@ -1371,10 +1371,21 @@ def test_fails():
)
result = testdir.runpytest()
result.stdout.fnmatch_lines(
"*Assertion Passed: a + b == c + d (1 + 2) == (3 + 0) at line 7*"
"*Assertion Passed: a+b == c+d (1 + 2) == (3 + 0) at line 7*"
)

def test_hook_call_with_parens(self, testdir, flag_on, hook_on):
testdir.makepyfile(
"""\
def f(): return 1
def test():
assert f()
"""
)
result = testdir.runpytest()
result.stdout.fnmatch_lines("*Assertion Passed: f() 1")

def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch):
def test_hook_not_called_without_hookimpl(self, testdir, monkeypatch, flag_on):
"""Assertion pass should not be called (and hence formatting should
not occur) if there is no hook declared for pytest_assertion_pass"""

Expand All @@ -1385,15 +1396,8 @@ def raise_on_assertionpass(*_, **__):
_pytest.assertion.rewrite, "_call_assertion_pass", raise_on_assertionpass
)

testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = True
"""
)

testdir.makepyfile(
"""
"""\
def test_simple():
a=1
b=2
Expand All @@ -1418,21 +1422,14 @@ def raise_on_assertionpass(*_, **__):
)

testdir.makeconftest(
"""
"""\
def pytest_assertion_pass(item, lineno, orig, expl):
raise Exception("Assertion Passed: {} {} at line {}".format(orig, expl, lineno))
"""
)

testdir.makeini(
"""
[pytest]
enable_assertion_pass_hook = False
"""
)

testdir.makepyfile(
"""
"""\
def test_simple():
a=1
b=2
Expand All @@ -1444,3 +1441,90 @@ def test_simple():
)
result = testdir.runpytest()
result.assert_outcomes(passed=1)


@pytest.mark.parametrize(
("src", "expected"),
(
# fmt: off
pytest.param(b"", {}, id="trivial"),
pytest.param(
b"def x(): assert 1\n",
{1: "1"},
id="assert statement not on own line",
),
pytest.param(
b"def x():\n"
b" assert 1\n"
b" assert 1+2\n",
{2: "1", 3: "1+2"},
id="multiple assertions",
),
pytest.param(
# changes in encoding cause the byte offsets to be different
"# -*- coding: latin1\n"
"def ÀÀÀÀÀ(): assert 1\n".encode("latin1"),
{2: "1"},
id="latin1 encoded on first line\n",
),
pytest.param(
# using the default utf-8 encoding
"def ÀÀÀÀÀ(): assert 1\n".encode(),
{1: "1"},
id="utf-8 encoded on first line",
),
pytest.param(
b"def x():\n"
b" assert (\n"
b" 1 + 2 # comment\n"
b" )\n",
{2: "(\n 1 + 2 # comment\n )"},
id="multi-line assertion",
),
pytest.param(
b"def x():\n"
b" assert y == [\n"
b" 1, 2, 3\n"
b" ]\n",
{2: "y == [\n 1, 2, 3\n ]"},
id="multi line assert with list continuation",
),
pytest.param(
b"def x():\n"
b" assert 1 + \\\n"
b" 2\n",
{2: "1 + \\\n 2"},
id="backslash continuation",
),
pytest.param(
b"def x():\n"
b" assert x, y\n",
{2: "x"},
id="assertion with message",
),
pytest.param(
b"def x():\n"
b" assert (\n"
b" f(1, 2, 3)\n"
b" ), 'f did not work!'\n",
{2: "(\n f(1, 2, 3)\n )"},
id="assertion with message, test spanning multiple lines",
),
pytest.param(
b"def x():\n"
b" assert \\\n"
b" x\\\n"
b" , 'failure message'\n",
{2: "x"},
id="escaped newlines plus message",
),
pytest.param(
b"def x(): assert 5",
{1: "5"},
id="no newline at end of file",
),
# fmt: on
),
)
def test_get_assertion_exprs(src, expected):
assert _get_assertion_exprs(src) == expected

0 comments on commit 73d918d

Please sign in to comment.