Skip to content

Commit

Permalink
support parse comp in function
Browse files Browse the repository at this point in the history
  • Loading branch information
cocolato committed Feb 1, 2024
1 parent 2815589 commit 7a012f5
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 0 deletions.
20 changes: 20 additions & 0 deletions mako/pyparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,26 @@ def visit_Lambda(self, node, *args):
def visit_FunctionDef(self, node):
self._add_declared(node.name)
self._visit_function(node, False)

def visit_ListComp(self, node):
if self.in_function:
if not isinstance(node.elt, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

visit_SetComp = visit_GeneratorExp = visit_ListComp

def visit_DictComp(self, node):
if self.in_function:
if not isinstance(node.key, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

def _expand_tuples(self, args):
for arg in args:
Expand Down
36 changes: 36 additions & 0 deletions test/test_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,42 @@ def test_locate_identifiers_17(self):
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.undeclared_identifiers, {"x", "y", "Foo", "Bar"})

def test_locate_identifiers_18(self):
code = """
def func():
return [i for i in range(10)]
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_19(self):
code = """
def func():
return (i for i in range(10))
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_20(self):
code = """
def func():
return {i for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_21(self):
code = """
def func():
return {i: i**2 for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_no_global_imports(self):
code = """
Expand Down
58 changes: 58 additions & 0 deletions test/test_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -1717,3 +1717,61 @@ def test_inline_percent(self):
"% foo",
"bar %% baz",
]


def test_lsitcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return [ k for k in x.keys() ]
%>
${ ','.join( getkeys(mydict) ) }
"""
, strict_undefined=True)
assert result_raw_lines(t.render()) == ["foo"]

def test_setcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return { k for k in x.keys() }
%>
${ ','.join( getkeys(mydict) ) }
"""
, strict_undefined=True)
assert result_raw_lines(t.render()) == ["foo"]


def test_generator_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return ( k for k in x.keys())
%>
${ ','.join( getkeys(mydict) ) }
"""
, strict_undefined=True)
assert result_raw_lines(t.render()) == ["foo"]


def test_dictcomp_in_func_strict(self):
t = Template(
"""
<%
def square():
return {i: i**2 for i in range(10)}
%>
${ square()[3] }
"""
, strict_undefined=True)
assert result_raw_lines(t.render()) == ["9"]

0 comments on commit 7a012f5

Please sign in to comment.