Skip to content

Commit

Permalink
Improve types for AST handling
Browse files Browse the repository at this point in the history
Adapting to changes in python/typeshed#11880. This mostly adds
more precise types for individual pieces of AST.
  • Loading branch information
JelleZijlstra committed May 18, 2024
1 parent e87f60b commit 8a672ef
Showing 1 changed file with 22 additions and 16 deletions.
38 changes: 22 additions & 16 deletions src/werkzeug/routing/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,11 +294,16 @@ def get_rules(self, map: Map) -> t.Iterator[Rule]:
)


def _prefix_names(src: str) -> ast.stmt:
_ASTT = t.TypeVar("_ASTT", bound=ast.AST)


def _prefix_names(src: str, expected_type: type[_ASTT]) -> _ASTT:
"""ast parse and prefix names with `.` to avoid collision with user vars"""
tree = ast.parse(src).body[0]
tree: ast.AST = ast.parse(src).body[0]
if isinstance(tree, ast.Expr):
tree = tree.value # type: ignore
tree = tree.value
if not isinstance(tree, expected_type):
raise TypeError(f"AST node is of type {type(tree).__name__}, not {expected_type.__name__}")
for node in ast.walk(tree):
if isinstance(node, ast.Name):
node.id = f".{node.id}"
Expand All @@ -313,8 +318,8 @@ def _prefix_names(src: str) -> ast.stmt:
else:
q = params = ""
"""
_IF_KWARGS_URL_ENCODE_AST = _prefix_names(_IF_KWARGS_URL_ENCODE_CODE)
_URL_ENCODE_AST_NAMES = (_prefix_names("q"), _prefix_names("params"))
_IF_KWARGS_URL_ENCODE_AST = _prefix_names(_IF_KWARGS_URL_ENCODE_CODE, ast.If)
_URL_ENCODE_AST_NAMES = (_prefix_names("q", ast.Name), _prefix_names("params", ast.Name))


class Rule(RuleFactory):
Expand Down Expand Up @@ -751,13 +756,13 @@ def _compile_builder(
else:
opl.append((True, data))

def _convert(elem: str) -> ast.stmt:
ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem))
ret.args = [ast.Name(str(elem), ast.Load())] # type: ignore # str for py2
def _convert(elem: str) -> ast.Call:
ret = _prefix_names(_CALL_CONVERTER_CODE_FMT.format(elem=elem), ast.Call)
ret.args = [ast.Name(elem, ast.Load())]
return ret

def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]:
parts = [
def _parts(ops: list[tuple[bool, str]]) -> list[ast.expr]:
parts: list[ast.expr] = [
_convert(elem) if is_dynamic else ast.Constant(elem)
for is_dynamic, elem in ops
]
Expand All @@ -773,13 +778,14 @@ def _parts(ops: list[tuple[bool, str]]) -> list[ast.AST]:

dom_parts = _parts(dom_ops)
url_parts = _parts(url_ops)
body: list[ast.stmt]
if not append_unknown:
body = []
else:
body = [_IF_KWARGS_URL_ENCODE_AST]
url_parts.extend(_URL_ENCODE_AST_NAMES)

def _join(parts: list[ast.AST]) -> ast.AST:
def _join(parts: list[ast.expr]) -> ast.expr:
if len(parts) == 1: # shortcut
return parts[0]
return ast.JoinedStr(parts)
Expand All @@ -795,7 +801,7 @@ def _join(parts: list[ast.AST]) -> ast.AST:
]
kargs = [str(k) for k in defaults]

func_ast: ast.FunctionDef = _prefix_names("def _(): pass") # type: ignore
func_ast = _prefix_names("def _(): pass", ast.FunctionDef)
func_ast.name = f"<builder:{self.rule!r}>"
func_ast.args.args.append(ast.arg(".self", None))
for arg in pargs + kargs:
Expand All @@ -815,13 +821,13 @@ def _join(parts: list[ast.AST]) -> ast.AST:
# bad line numbers cause an assert to fail in debug builds
for node in ast.walk(module):
if "lineno" in node._attributes:
node.lineno = 1
node.lineno = 1 # type: ignore[attr-defined]
if "end_lineno" in node._attributes:
node.end_lineno = node.lineno
node.end_lineno = node.lineno # type: ignore[attr-defined]
if "col_offset" in node._attributes:
node.col_offset = 0
node.col_offset = 0 # type: ignore[attr-defined]
if "end_col_offset" in node._attributes:
node.end_col_offset = node.col_offset
node.end_col_offset = node.col_offset # type: ignore[attr-defined]

code = compile(module, "<werkzeug routing>", "exec")
return self._get_func_code(code, func_ast.name)
Expand Down

0 comments on commit 8a672ef

Please sign in to comment.