From 8be16280423eafbc3c604a81419f40a8c1681fba Mon Sep 17 00:00:00 2001 From: hauntsaninja <> Date: Sun, 11 Apr 2021 15:38:32 -0700 Subject: [PATCH 1/3] Fix assertion rewriting on Python 3.10 Fixes https://github.com/pytest-dev/pytest/issues/8539 This seems to have been the result of https://bugs.python.org/issue43798 --- AUTHORS | 1 + changelog/8539.bugfix.rst | 1 + src/_pytest/assertion/rewrite.py | 18 ++++++++++++++---- 3 files changed, 16 insertions(+), 4 deletions(-) create mode 100644 changelog/8539.bugfix.rst diff --git a/AUTHORS b/AUTHORS index 46f283452c3..9b3ec153306 100644 --- a/AUTHORS +++ b/AUTHORS @@ -277,6 +277,7 @@ Sankt Petersbug Segev Finer Serhii Mozghovyi Seth Junot +Shantanu Jain Shubham Adep Simon Gomizelj Simon Kerr diff --git a/changelog/8539.bugfix.rst b/changelog/8539.bugfix.rst new file mode 100644 index 00000000000..a2098610e29 --- /dev/null +++ b/changelog/8539.bugfix.rst @@ -0,0 +1 @@ +Fixed assertion rewriting on Python 3.10. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 6a3222f333d..537ded257e9 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -684,12 +684,9 @@ def run(self, mod: ast.Module) -> None: if not mod.body: # Nothing to do. return + # Insert some special imports at the top of the module but after any # docstrings and __future__ imports. - aliases = [ - ast.alias("builtins", "@py_builtins"), - ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), - ] doc = getattr(mod, "docstring", None) expect_docstring = doc is None if doc is not None and self.is_rewrite_disabled(doc): @@ -721,6 +718,19 @@ def run(self, mod: ast.Module) -> None: lineno = item.decorator_list[0].lineno else: lineno = item.lineno + if sys.version_info >= (3, 10): + aliases = [ + ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), + ast.alias( + "_pytest.assertion.rewrite", "@pytest_ar", + lineno=lineno, col_offset=0 + ), + ] + else: + aliases = [ + ast.alias("builtins", "@py_builtins"), + ast.alias("_pytest.assertion.rewrite", "@pytest_ar"), + ] imports = [ ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases ] From da66f004133e6c6ed5081854146c0ceffd3f144e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 11 Apr 2021 22:44:28 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/_pytest/assertion/rewrite.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 537ded257e9..f661fe9475e 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -722,8 +722,10 @@ def run(self, mod: ast.Module) -> None: aliases = [ ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), ast.alias( - "_pytest.assertion.rewrite", "@pytest_ar", - lineno=lineno, col_offset=0 + "_pytest.assertion.rewrite", + "@pytest_ar", + lineno=lineno, + col_offset=0, ), ] else: From e3dc34ee41b3703808b22357a071302c48fc6fe6 Mon Sep 17 00:00:00 2001 From: hauntsaninja <> Date: Mon, 12 Apr 2021 11:33:40 -0700 Subject: [PATCH 3/3] fixup comments --- src/_pytest/assertion/rewrite.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index f661fe9475e..33e2ef6cc49 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -685,8 +685,8 @@ def run(self, mod: ast.Module) -> None: # Nothing to do. return - # Insert some special imports at the top of the module but after any - # docstrings and __future__ imports. + # We'll insert some special imports at the top of the module, but after any + # docstrings and __future__ imports, so first figure out where that is. doc = getattr(mod, "docstring", None) expect_docstring = doc is None if doc is not None and self.is_rewrite_disabled(doc): @@ -718,6 +718,7 @@ def run(self, mod: ast.Module) -> None: lineno = item.decorator_list[0].lineno else: lineno = item.lineno + # Now actually insert the special imports. if sys.version_info >= (3, 10): aliases = [ ast.alias("builtins", "@py_builtins", lineno=lineno, col_offset=0), @@ -737,6 +738,7 @@ def run(self, mod: ast.Module) -> None: ast.Import([alias], lineno=lineno, col_offset=0) for alias in aliases ] mod.body[pos:pos] = imports + # Collect asserts. nodes: List[ast.AST] = [mod] while nodes: