Skip to content

Commit

Permalink
Features assertion pass hook (#3479)
Browse files Browse the repository at this point in the history
Features assertion pass hook
  • Loading branch information
nicoddemus committed Jun 27, 2019
2 parents 790806e + 2ea2221 commit 37fb50a
Show file tree
Hide file tree
Showing 11 changed files with 302 additions and 42 deletions.
4 changes: 4 additions & 0 deletions changelog/3457.feature.rst
@@ -0,0 +1,4 @@
New `pytest_assertion_pass <https://docs.pytest.org/en/latest/reference.html#_pytest.hookspec.pytest_assertion_pass>`__
hook, called with context information when an assertion *passes*.

This hook is still **experimental** so use it with caution.
1 change: 1 addition & 0 deletions changelog/3457.trivial.rst
@@ -0,0 +1 @@
pytest now also depends on the `astor <https://pypi.org/project/astor/>`__ package.
7 changes: 3 additions & 4 deletions doc/en/reference.rst
Expand Up @@ -665,15 +665,14 @@ Session related reporting hooks:
.. autofunction:: pytest_fixture_post_finalizer
.. autofunction:: pytest_warning_captured

And here is the central hook for reporting about
test execution:
Central hook for reporting about test execution:

.. autofunction:: pytest_runtest_logreport

You can also use this hook to customize assertion representation for some
types:
Assertion related hooks:

.. autofunction:: pytest_assertrepr_compare
.. autofunction:: pytest_assertion_pass


Debugging/Interaction hooks
Expand Down
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -13,6 +13,7 @@
"pluggy>=0.12,<1.0",
"importlib-metadata>=0.12",
"wcwidth",
"astor",
]


Expand Down
19 changes: 18 additions & 1 deletion src/_pytest/assertion/__init__.py
Expand Up @@ -23,6 +23,13 @@ def pytest_addoption(parser):
test modules on import to provide assert
expression information.""",
)
parser.addini(
"enable_assertion_pass_hook",
type="bool",
default=False,
help="Enables the pytest_assertion_pass hook."
"Make sure to delete any previously generated pyc cache files.",
)


def register_assert_rewrite(*names):
Expand Down Expand Up @@ -92,7 +99,7 @@ def pytest_collection(session):


def pytest_runtest_setup(item):
"""Setup the pytest_assertrepr_compare hook
"""Setup the pytest_assertrepr_compare and pytest_assertion_pass hooks
The newinterpret and rewrite modules will use util._reprcompare if
it exists to use custom reporting via the
Expand Down Expand Up @@ -129,9 +136,19 @@ def callbinrepr(op, left, right):

util._reprcompare = callbinrepr

if item.ihook.pytest_assertion_pass.get_hookimpls():

def call_assertion_pass_hook(lineno, expl, orig):
item.ihook.pytest_assertion_pass(
item=item, lineno=lineno, orig=orig, expl=expl
)

util._assertion_pass = call_assertion_pass_hook


def pytest_runtest_teardown(item):
util._reprcompare = None
util._assertion_pass = None


def pytest_sessionfinish(session):
Expand Down
152 changes: 116 additions & 36 deletions src/_pytest/assertion/rewrite.py
Expand Up @@ -10,6 +10,7 @@
import sys
import types

import astor
import atomicwrites

from _pytest._io.saferepr import saferepr
Expand Down Expand Up @@ -134,7 +135,7 @@ def exec_module(self, module):
co = _read_pyc(fn, pyc, state.trace)
if co is None:
state.trace("rewriting {!r}".format(fn))
source_stat, co = _rewrite_test(fn)
source_stat, co = _rewrite_test(fn, self.config)
if write:
self._writing_pyc = True
try:
Expand Down Expand Up @@ -278,13 +279,13 @@ def _write_pyc(state, co, source_stat, pyc):
return True


def _rewrite_test(fn):
def _rewrite_test(fn, config):
"""read and rewrite *fn* and return the code object."""
stat = os.stat(fn)
with open(fn, "rb") as f:
source = f.read()
tree = ast.parse(source, filename=fn)
rewrite_asserts(tree, fn)
rewrite_asserts(tree, fn, config)
co = compile(tree, fn, "exec", dont_inherit=True)
return stat, co

Expand Down Expand Up @@ -326,9 +327,9 @@ def _read_pyc(source, pyc, trace=lambda x: None):
return co


def rewrite_asserts(mod, module_path=None):
def rewrite_asserts(mod, module_path=None, config=None):
"""Rewrite the assert statements in mod."""
AssertionRewriter(module_path).run(mod)
AssertionRewriter(module_path, config).run(mod)


def _saferepr(obj):
Expand Down Expand Up @@ -401,6 +402,17 @@ def _call_reprcompare(ops, results, expls, each_obj):
return expl


def _call_assertion_pass(lineno, orig, expl):
if util._assertion_pass is not None:
util._assertion_pass(lineno=lineno, orig=orig, expl=expl)


def _check_if_assertion_pass_impl():
"""Checks if any plugins implement the pytest_assertion_pass hook
in order not to generate explanation unecessarily (might be expensive)"""
return True if util._assertion_pass else False


unary_map = {ast.Not: "not %s", ast.Invert: "~%s", ast.USub: "-%s", ast.UAdd: "+%s"}

binop_map = {
Expand Down Expand Up @@ -473,7 +485,8 @@ class AssertionRewriter(ast.NodeVisitor):
original assert statement: it rewrites the test of an assertion
to provide intermediate values and replace it with an if statement
which raises an assertion error with a detailed explanation in
case the expression is false.
case the expression is false and calls pytest_assertion_pass hook
if expression is true.
For this .visit_Assert() uses the visitor pattern to visit all the
AST nodes of the ast.Assert.test field, each visit call returning
Expand All @@ -491,9 +504,10 @@ class AssertionRewriter(ast.NodeVisitor):
by statements. Variables are created using .variable() and
have the form of "@py_assert0".
:on_failure: The AST statements which will be executed if the
assertion test fails. This is the code which will construct
the failure message and raises the AssertionError.
:expl_stmts: The AST statements which will be executed to get
data from the assertion. This is the code which will construct
the detailed assertion message that is used in the AssertionError
or for the pytest_assertion_pass hook.
:explanation_specifiers: A dict filled by .explanation_param()
with %-formatting placeholders and their corresponding
Expand All @@ -509,9 +523,16 @@ class AssertionRewriter(ast.NodeVisitor):
"""

def __init__(self, module_path):
def __init__(self, module_path, config):
super().__init__()
self.module_path = module_path
self.config = config
if config is not None:
self.enable_assertion_pass_hook = config.getini(
"enable_assertion_pass_hook"
)
else:
self.enable_assertion_pass_hook = False

def run(self, mod):
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -642,7 +663,7 @@ def pop_format_context(self, expl_expr):
The expl_expr should be an ast.Str instance constructed from
the %-placeholders created by .explanation_param(). This will
add the required code to format said string to .on_failure and
add the required code to format said string to .expl_stmts and
return the ast.Name instance of the formatted string.
"""
Expand All @@ -653,7 +674,9 @@ def pop_format_context(self, expl_expr):
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
self.on_failure.append(ast.Assign([ast.Name(name, ast.Store())], form))
if self.enable_assertion_pass_hook:
self.format_variables.append(name)
self.expl_stmts.append(ast.Assign([ast.Name(name, ast.Store())], form))
return ast.Name(name, ast.Load())

def generic_visit(self, node):
Expand Down Expand Up @@ -687,8 +710,12 @@ def visit_Assert(self, assert_):
self.statements = []
self.variables = []
self.variable_counter = itertools.count()

if self.enable_assertion_pass_hook:
self.format_variables = []

self.stack = []
self.on_failure = []
self.expl_stmts = []
self.push_format_context()
# Rewrite assert into a bunch of statements.
top_condition, explanation = self.visit(assert_.test)
Expand All @@ -699,24 +726,77 @@ def visit_Assert(self, assert_):
top_condition, module_path=self.module_path, lineno=assert_.lineno
)
)
# Create failure message.
body = self.on_failure
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)

body.append(raise_)

if self.enable_assertion_pass_hook: # Experimental pytest_assertion_pass hook
negation = ast.UnaryOp(ast.Not(), top_condition)
msg = self.pop_format_context(ast.Str(explanation))

# Failed
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
gluestr = "\n>assert "
else:
assertmsg = ast.Str("")
gluestr = "assert "
err_explanation = ast.BinOp(ast.Str(gluestr), ast.Add(), msg)
err_msg = ast.BinOp(assertmsg, ast.Add(), err_explanation)
err_name = ast.Name("AssertionError", ast.Load())
fmt = self.helper("_format_explanation", err_msg)
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)
statements_fail = []
statements_fail.extend(self.expl_stmts)
statements_fail.append(raise_)

# Passed
fmt_pass = self.helper("_format_explanation", msg)
orig = astor.to_source(assert_.test).rstrip("\n").lstrip("(").rstrip(")")
hook_call_pass = ast.Expr(
self.helper(
"_call_assertion_pass",
ast.Num(assert_.lineno),
ast.Str(orig),
fmt_pass,
)
)
# If any hooks implement assert_pass hook
hook_impl_test = ast.If(
self.helper("_check_if_assertion_pass_impl"),
self.expl_stmts + [hook_call_pass],
[],
)
statements_pass = [hook_impl_test]

# Test for assertion condition
main_test = ast.If(negation, statements_fail, statements_pass)
self.statements.append(main_test)
if self.format_variables:
variables = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, _NameConstant(None))
self.statements.append(clear_format)

else: # Original assertion rewriting
# Create failure message.
body = self.expl_stmts
negation = ast.UnaryOp(ast.Not(), top_condition)
self.statements.append(ast.If(negation, body, []))
if assert_.msg:
assertmsg = self.helper("_format_assertmsg", assert_.msg)
explanation = "\n>assert " + explanation
else:
assertmsg = ast.Str("")
explanation = "assert " + explanation
template = ast.BinOp(assertmsg, ast.Add(), ast.Str(explanation))
msg = self.pop_format_context(template)
fmt = self.helper("_format_explanation", msg)
err_name = ast.Name("AssertionError", ast.Load())
exc = ast.Call(err_name, [fmt], [])
raise_ = ast.Raise(exc, None)

body.append(raise_)

# Clear temporary variables by setting them to None.
if self.variables:
variables = [ast.Name(name, ast.Store()) for name in self.variables]
Expand Down Expand Up @@ -770,22 +850,22 @@ def visit_BoolOp(self, boolop):
app = ast.Attribute(expl_list, "append", ast.Load())
is_or = int(isinstance(boolop.op, ast.Or))
body = save = self.statements
fail_save = self.on_failure
fail_save = self.expl_stmts
levels = len(boolop.values) - 1
self.push_format_context()
# Process each operand, short-circuiting if needed.
for i, v in enumerate(boolop.values):
if i:
fail_inner = []
# cond is set in a prior loop iteration below
self.on_failure.append(ast.If(cond, fail_inner, [])) # noqa
self.on_failure = fail_inner
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa
self.expl_stmts = fail_inner
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Str(expl))
call = ast.Call(app, [expl_format], [])
self.on_failure.append(ast.Expr(call))
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond = res
if is_or:
Expand All @@ -794,7 +874,7 @@ def visit_BoolOp(self, boolop):
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
self.statements = save
self.on_failure = fail_save
self.expl_stmts = fail_save
expl_template = self.helper("_format_boolop", expl_list, ast.Num(is_or))
expl = self.pop_format_context(expl_template)
return ast.Name(res_var, ast.Load()), self.explanation_param(expl)
Expand Down
4 changes: 4 additions & 0 deletions src/_pytest/assertion/util.py
Expand Up @@ -12,6 +12,10 @@
# DebugInterpreter.
_reprcompare = None

# Works similarly as _reprcompare attribute. Is populated with the hook call
# when pytest_runtest_setup is called.
_assertion_pass = None


def format_explanation(explanation):
"""This formats an explanation
Expand Down

0 comments on commit 37fb50a

Please sign in to comment.