Skip to content

Commit

Permalink
Support comprehensions inside functions when use strict_undefined flag.
Browse files Browse the repository at this point in the history
Fixes: #320

Now the test code works as expected if strict_undefined is set to true:

```python
from mako.template import Template

text = """
<%
    mydict = { 'foo': 1 }

    ## Uncomment the following line to workaround the error
    ##k = None
    def getkeys(x):
        return [ k for k in x.keys() ]
%>

${ ','.join( getkeys(mydict) ) }
"""

tmpl = Template(text=text, strict_undefined=True)
out = tmpl.render()
print(out)
```

output:
```

foo

```

Closes: #386
Pull-request: #386
Pull-request-sha: cc6a3e0

Change-Id: I0591873a83837f8f35b0963c0536df1e2675012f
  • Loading branch information
cocolato authored and sqla-tester committed Feb 6, 2024
1 parent 2815589 commit e2606d5
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 0 deletions.
9 changes: 9 additions & 0 deletions doc/build/unreleased/320.rst
@@ -0,0 +1,9 @@
.. change::
:tags: bug, parser
:tickets: 320

Fixed unexpected syntax error in strict_undefined mode that occurred
when using comprehensions within a function in a Mako Python code block.
Now, the local variable in comprehensions won't be added to the checklist
when using strict_undefined mode.
Pull request courtesy Hai Zhu.
20 changes: 20 additions & 0 deletions mako/pyparser.py
Expand Up @@ -90,6 +90,26 @@ def visit_FunctionDef(self, node):
self._add_declared(node.name)
self._visit_function(node, False)

def visit_ListComp(self, node):
if self.in_function:
if not isinstance(node.elt, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

visit_SetComp = visit_GeneratorExp = visit_ListComp

def visit_DictComp(self, node):
if self.in_function:
if not isinstance(node.key, _ast.Name):
self.visit(node.elt)
for comp in node.generators:
self.visit(comp.iter)
else:
self.generic_visit(node)

def _expand_tuples(self, args):
for arg in args:
if isinstance(arg, _ast.Tuple):
Expand Down
36 changes: 36 additions & 0 deletions test/test_ast.py
Expand Up @@ -222,6 +222,42 @@ def test_locate_identifiers_17(self):
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.undeclared_identifiers, {"x", "y", "Foo", "Bar"})

def test_locate_identifiers_18(self):
code = """
def func():
return [i for i in range(10)]
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_19(self):
code = """
def func():
return (i for i in range(10))
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_20(self):
code = """
def func():
return {i for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_locate_identifiers_21(self):
code = """
def func():
return {i: i**2 for i in range(10)}
"""
parsed = ast.PythonCode(code, **exception_kwargs)
eq_(parsed.declared_identifiers, {"func"})
eq_(parsed.undeclared_identifiers, {"range"})

def test_no_global_imports(self):
code = """
from foo import *
Expand Down
59 changes: 59 additions & 0 deletions test/test_template.py
Expand Up @@ -1717,3 +1717,62 @@ def test_inline_percent(self):
"% foo",
"bar %% baz",
]

def test_listcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return [ k for k in x.keys() ]
%>
${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_setcomp_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return { k for k in x.keys() }
%>
${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_generator_in_func_strict(self):
t = Template(
"""
<%
mydict = { 'foo': 1 }
def getkeys(x):
return ( k for k in x.keys())
%>
${ ','.join( getkeys(mydict) ) }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["foo"]

def test_dictcomp_in_func_strict(self):
t = Template(
"""
<%
def square():
return {i: i**2 for i in range(10)}
%>
${ square()[3] }
""",
strict_undefined=True,
)
assert result_raw_lines(t.render()) == ["9"]

0 comments on commit e2606d5

Please sign in to comment.