Skip to content

Commit

Permalink
Fix new typing issues in AST code (#12337)
Browse files Browse the repository at this point in the history
python/typeshed#11880 adds more precise types for AST nodes. I'm submitting some changes to adapt pytest to these changes.
  • Loading branch information
JelleZijlstra committed May 18, 2024
1 parent 635fbe2 commit ee9ea70
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ def pop_format_context(self, expl_expr: ast.expr) -> ast.Name:
current = self.stack.pop()
if self.stack:
self.explanation_specifiers = self.stack[-1]
keys = [ast.Constant(key) for key in current.keys()]
keys: List[Optional[ast.expr]] = [ast.Constant(key) for key in current.keys()]
format_dict = ast.Dict(keys, list(current.values()))
form = ast.BinOp(expl_expr, ast.Mod(), format_dict)
name = "@py_format" + str(next(self.variable_counter))
Expand Down Expand Up @@ -926,13 +926,13 @@ def visit_Assert(self, assert_: ast.Assert) -> List[ast.stmt]:
[*self.expl_stmts, hook_call_pass],
[],
)
statements_pass = [hook_impl_test]
statements_pass: List[ast.stmt] = [hook_impl_test]

# Test for assertion condition
main_test = ast.If(negation, statements_fail, statements_pass)
self.statements.append(main_test)
if self.format_variables:
variables = [
variables: List[ast.expr] = [
ast.Name(name, ast.Store()) for name in self.format_variables
]
clear_format = ast.Assign(variables, ast.Constant(None))
Expand Down Expand Up @@ -1114,11 +1114,11 @@ def visit_Compare(self, comp: ast.Compare) -> Tuple[ast.expr, str]:
if isinstance(comp.left, (ast.Compare, ast.BoolOp)):
left_expl = f"({left_expl})"
res_variables = [self.variable() for i in range(len(comp.ops))]
load_names = [ast.Name(v, ast.Load()) for v in res_variables]
load_names: List[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
it = zip(range(len(comp.ops)), comp.ops, comp.comparators)
expls = []
syms = []
expls: List[ast.expr] = []
syms: List[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
if (
Expand Down
1 change: 1 addition & 0 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def test_location_is_set(self) -> None:
if isinstance(node, ast.Import):
continue
for n in [node, *ast.iter_child_nodes(node)]:
assert isinstance(n, (ast.stmt, ast.expr))
assert n.lineno == 3
assert n.col_offset == 0
assert n.end_lineno == 6
Expand Down

0 comments on commit ee9ea70

Please sign in to comment.