Skip to content

Commit

Permalink
Revisit AssertionRewriter.run (#435)
Browse files Browse the repository at this point in the history
NOTE: fixing/reverting of `doc = getattr(mod, "docstring", None)` was
missed in 3153740 (pytest-dev#4723).

It was added initially in
pytest-dev#2870, but never made it into a
final Python release.  Basically reverts / revisits based on changes introduced in 734c435.

Adjusts a7dfc6f (#243).
  • Loading branch information
blueyed committed Sep 25, 2020
1 parent 865df4c commit 2e07887
Show file tree
Hide file tree
Showing 6 changed files with 114 additions and 70 deletions.
4 changes: 1 addition & 3 deletions src/_pytest/_code/code.py
Expand Up @@ -244,14 +244,12 @@ def getsource(self, astcache=None) -> Optional["Source"]:
astnode = astcache.get(key, None)
start = self.getfirstlinesource()
try:
astnode, ast_start, end = getstatementrange_ast(
astnode, _, end = getstatementrange_ast(
self.lineno, source, astnode=astnode
)
except SyntaxError:
end = self.lineno + 1
else:
if ast_start - 1 < start:
start = ast_start
if key is not None:
astcache[key] = astnode
return source[start:end]
Expand Down
1 change: 1 addition & 0 deletions src/_pytest/_code/source.py
Expand Up @@ -363,6 +363,7 @@ def get_statement_startend2(lineno: int, node: ast.AST) -> Tuple[int, Optional[i
values.append(val[0].lineno - 1 - 1)
values.sort()
insert_index = bisect_right(values, lineno)
assert insert_index > 0, (insert_index, values, lineno)
start = values[insert_index - 1]
if insert_index >= len(values):
end = None
Expand Down
54 changes: 25 additions & 29 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -618,44 +618,30 @@ def _assert_expr_to_lineno(self):
def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
if not mod.body:
# Nothing to do.
return
# Insert some special imports at the top of the module but after any
# docstrings and __future__ imports.
aliases = [
ast.alias("builtins", "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
]
doc = getattr(mod, "docstring", None)
expect_docstring = doc is None
if doc is not None and self.is_rewrite_disabled(doc):
return

# Find position for imports (after docstrings and __future__ imports).
iter_body = iter(mod.body)
pos = 0
doc = ast.get_docstring(mod, clean=False)
if doc:
if self.is_rewrite_disabled(doc):
return
skipped_doc = next(iter_body)
assert isinstance(skipped_doc, ast.Expr)
pos += 1

lineno = 1
for item in mod.body:
for item in iter_body:
if (
expect_docstring
and isinstance(item, ast.Expr)
and isinstance(item.value, ast.Str)
):
doc = item.value.s
if self.is_rewrite_disabled(doc):
return
expect_docstring = False
elif (
not isinstance(item, ast.ImportFrom)
or item.level > 0
or item.module != "__future__"
):
lineno = item.lineno
break
pos += 1
else:
lineno = item.lineno
imports = [
ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases
]
mod.body[pos:pos] = imports
pos += 1

# Collect asserts.
nodes = [mod] # type: List[ast.AST]
while nodes:
Expand All @@ -680,8 +666,18 @@ def run(self, mod: ast.Module) -> None:
):
nodes.append(field)

imports = ast.Import(
names=[
ast.alias("builtins", "@py_builtins"),
ast.alias("_pytest.assertion.rewrite", "@pytest_ar"),
],
lineno=lineno,
col_offset=0,
)
mod.body.insert(pos, imports)

@staticmethod
def is_rewrite_disabled(docstring):
def is_rewrite_disabled(docstring: str) -> bool:
return "PYTEST_DONT_REWRITE" in docstring

def variable(self):
Expand Down
10 changes: 7 additions & 3 deletions testing/code/test_code.py
Expand Up @@ -7,6 +7,10 @@
from _pytest._code import ExceptionInfo
from _pytest._code import Frame
from _pytest._code.code import ReprFuncArgs
from _pytest.compat import TYPE_CHECKING

if TYPE_CHECKING:
from _pytest.pytester import Testdir


def test_ne() -> None:
Expand Down Expand Up @@ -231,15 +235,15 @@ def test_not_raise_exception_with_mixed_encoding(self, tw_mock) -> None:
)


def test_nameerror_with_decorator(testdir):
# TODO: unittest with Code (additionally)?!
def test_nameerror_with_decorator(testdir: "Testdir") -> None:
"""Ref: https://github.com/pytest-dev/pytest/issues/4984"""
source = """
@nameerror_deco1
def test():
pass
"""
p1 = testdir.makepyfile(source)
result = testdir.runpytest(str(p1), "-rf")
result = testdir.runpytest(p1)
result.stdout.fnmatch_lines(
[
"*_ ERROR collecting test_nameerror_with_decorator.py _*",
Expand Down
20 changes: 20 additions & 0 deletions testing/code/test_source.py
Expand Up @@ -15,6 +15,7 @@
import pytest
from _pytest._code import getfslineno
from _pytest._code import Source
from _pytest._code.source import get_statement_startend2


def test_source_str_function() -> None:
Expand Down Expand Up @@ -798,3 +799,22 @@ def __init__(self, *args):
# fmt: on
values = [i for i in x.source.lines if i.strip()]
assert len(values) == 4


def test_deco_statements() -> None:
"""Ref: https://github.com/pytest-dev/pytest/issues/4984"""
code = "\n".join(
[
"@deco",
"def test(): pass",
"",
"last_line = 1",
]
)
astnode = ast.parse(code, "source", "exec")
assert get_statement_startend2(0, astnode) == (0, 1)

assert getstatement(0, code).lines == ["@deco"]
assert getstatement(1, code).lines == ["def test(): pass"]
assert getstatement(2, code).lines == ["def test(): pass"]
assert getstatement(3, code).lines == ["last_line = 1"]
95 changes: 60 additions & 35 deletions testing/test_assertrewrite.py
Expand Up @@ -36,9 +36,10 @@ def teardown_module(mod):
del mod._old_reprcompare


def rewrite(src):
def rewrite(src: str) -> ast.Module:
tree = ast.parse(src)
rewrite_asserts(tree, src.encode())
compile(tree, "test-compile", "exec")
return tree


Expand Down Expand Up @@ -67,47 +68,71 @@ def getmsg(f, extra_ns=None, must_pass=False):


class TestAssertionRewrite:
def test_place_initial_imports(self):
s = """'Doc string'\nother = stuff"""
m = rewrite(s)
def test_place_initial_imports(self) -> None:
m = rewrite("'Doc string'\nother = stuff")
assert isinstance(m.body[0], ast.Expr)
for imp in m.body[1:3]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Assign)
s = """from __future__ import division\nother_stuff"""
m = rewrite(s)
imp = m.body[1]
assert isinstance(imp, ast.Import)
assert imp.lineno == 1
assert imp.col_offset == 0
assert isinstance(m.body[2], ast.Assign)
assert len(m.body) == 3

m = rewrite("'''Multiline\nDoc string'''\nother = stuff")
assert isinstance(m.body[0], ast.Expr)
imp = m.body[1]
assert isinstance(imp, ast.Import)
assert imp.lineno == 1
assert imp.col_offset == 0
assert isinstance(m.body[2], ast.Assign)
assert len(m.body) == 3

m = rewrite("from __future__ import division\nother_stuff")
assert isinstance(m.body[0], ast.ImportFrom)
for imp in m.body[1:3]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)
s = """'doc string'\nfrom __future__ import division"""
m = rewrite(s)
imp = m.body[1]
assert isinstance(imp, ast.Import)
assert imp.lineno == 1
assert imp.col_offset == 0
assert isinstance(m.body[2], ast.Expr)
assert len(m.body) == 3

m = rewrite("'doc string'\nfrom __future__ import division")
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[1], ast.ImportFrom)
for imp in m.body[2:4]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
s = """'doc string'\nfrom __future__ import division\nother"""
m = rewrite(s)
imp = m.body[2]
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert len(m.body) == 3

m = rewrite("'doc string'\nfrom __future__ import division\nother")
assert isinstance(m.body[0], ast.Expr)
assert isinstance(m.body[1], ast.ImportFrom)
for imp in m.body[2:4]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 3
assert imp.col_offset == 0
assert isinstance(m.body[4], ast.Expr)
s = """from . import relative\nother_stuff"""
m = rewrite(s)
for imp in m.body[:2]:
assert isinstance(imp, ast.Import)
assert imp.lineno == 1
assert imp.col_offset == 0
imp = m.body[2]
assert isinstance(imp, ast.Import)
assert imp.lineno == 2
assert imp.col_offset == 0
assert isinstance(m.body[3], ast.Expr)
assert len(m.body) == 4

m = rewrite("from . import relative\nother_stuff")
imp = m.body[0]
assert isinstance(imp, ast.Import)
assert imp.lineno == 1
assert imp.col_offset == 0
assert isinstance(m.body[1], ast.ImportFrom)
assert isinstance(m.body[2], ast.Expr)
assert len(m.body) == 3

def test_place_initial_imports_decorated_func(self) -> None:
"""Ref: https://github.com/pytest-dev/pytest/issues/4984"""
m = rewrite("@deco\ndef func(): pass")
imp = m.body[0]
assert isinstance(imp, ast.Import)
assert imp.lineno == 1
assert imp.col_offset == 0
assert isinstance(m.body[1], ast.FunctionDef)
assert len(m.body) == 2

def test_dont_rewrite(self):
s = """'PYTEST_DONT_REWRITE'\nassert 14"""
Expand Down

0 comments on commit 2e07887

Please sign in to comment.