/
ast.py
68 lines (57 loc) · 1.94 KB
/
ast.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import ast
from ast import *
from rope.base import fscommands
def parse(source, filename="<string>"):
    """Parse Python *source* into an ``ast.Module``.

    :param source: module text as ``str`` or already-encoded ``bytes``.
        ``str`` input is first converted with
        ``fscommands.unicode_to_file_data``; the result is then handled as
        bytes.
    :param filename: name reported in any ``SyntaxError`` raised.
    :returns: the tree produced by ``ast.parse``.
    :raises SyntaxError: for invalid syntax, and also when ``ast.parse``
        rejects the input with ``TypeError``/``ValueError`` (reported at
        line 1 with the original message).
    """
    # NOTE: the raw string should be given to `compile` function
    if isinstance(source, str):
        source = fscommands.unicode_to_file_data(source)
    # Normalise CRLF and lone CR line endings to LF.
    if b"\r" in source:
        source = source.replace(b"\r\n", b"\n").replace(b"\r", b"\n")
    # Ensure the final line is newline-terminated.
    if not source.endswith(b"\n"):
        source += b"\n"
    try:
        # Pass the caller's filename through so genuine SyntaxErrors point at
        # the right file, consistent with the converted error below.
        return ast.parse(source, filename=filename)
    except (TypeError, ValueError) as e:
        error = SyntaxError()
        error.lineno = 1
        error.filename = filename
        error.msg = str(e)
        # Keep the underlying exception attached for debugging.
        raise error from e
def walk(node, walker) -> None:
    """Walk the syntax tree rooted at *node*, dispatching to *walker*.

    If *walker* defines a method named ``_<NodeClassName>`` (for example
    ``_FunctionDef``), it is called with *node*. Its return value is
    ignored and the children of *node* are NOT visited automatically.
    Otherwise the walk recurses into each node from
    ``get_child_nodes(node)``.
    """
    handler = getattr(walker, f"_{node.__class__.__name__}", None)
    if handler is None:
        for child_node in get_child_nodes(node):
            walk(child_node, walker)
    else:
        handler(node)
def get_child_nodes(node):
    """Return the direct AST children of *node* as a list.

    For an ``ast.Module`` the module's own ``body`` list is returned (the
    same list object, not a copy). Only the top-level statements are
    reported for a module.

    For any other node, this delegates to ``ast.iter_child_nodes``:
    - each field in ``node._fields`` is inspected in order;
    - AST-valued fields are included directly;
    - AST entries inside list-valued fields are included in list order;
    - fields missing from the node (e.g. a hand-built or mutated node) are
      skipped instead of raising ``AttributeError``.

    Returns an empty list if ``node._fields`` is ``None``.
    """
    if isinstance(node, ast.Module):
        # Only statements are reported; other Module fields are skipped.
        return node.body
    if node._fields is None:
        return []
    return list(ast.iter_child_nodes(node))
def call_for_nodes(node, callback, recursive=False):
    """Invoke *callback* on *node* and, optionally, on its descendants.

    *callback* is always called once with *node*. When *recursive* is true
    and the callback's result is falsy, the same procedure is applied to
    every node from ``get_child_nodes(node)``. A truthy result from the
    callback means the child nodes are skipped.
    """
    skip_children = callback(node)
    if not recursive or skip_children:
        return
    for child_node in get_child_nodes(node):
        call_for_nodes(child_node, callback, recursive)
def get_children(node):
    """Return the values of *node*'s ``_fields`` in declaration order.

    Every field value is included as-is (nodes, lists, plain values, or
    ``None``), except fields named ``lineno`` or ``col_offset``. For stdlib
    AST nodes those normally live in ``_attributes`` rather than
    ``_fields``. Returns an empty list when ``node._fields`` is ``None``.
    """
    field_names = node._fields
    if field_names is None:
        return []
    excluded = ("lineno", "col_offset")
    return [
        getattr(node, field_name)
        for field_name in field_names
        if field_name not in excluded
    ]