Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] force jinja to preserve native types #23943

Closed
wants to merge 11 commits into from
219 changes: 205 additions & 14 deletions lib/ansible/template/__init__.py
Expand Up @@ -36,6 +36,9 @@
from sha import sha as sha1

from jinja2 import Environment
from jinja2 import escape
from jinja2 import nodes
from jinja2.compiler import CodeGenerator
from jinja2.loaders import FileSystemLoader
from jinja2.exceptions import TemplateSyntaxError, UndefinedError
from jinja2.utils import concat as j2_concat
Expand All @@ -45,12 +48,14 @@
from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable
from ansible.module_utils.six import string_types, text_type
from ansible.module_utils._text import to_native, to_text, to_bytes
from ansible.parsing.yaml.objects import AnsibleUnicode
from ansible.plugins import filter_loader, lookup_loader, test_loader
from ansible.template.safe_eval import safe_eval
from ansible.template.template import AnsibleJ2Template
from ansible.template.vars import AnsibleJ2Vars
from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var


try:
from __main__ import display
except ImportError:
Expand Down Expand Up @@ -152,10 +157,186 @@ def _count_newlines_from_end(in_str):
while in_str[j] == '\n':
j -= 1
return i - 1 - j
except TypeError:
return 0
except IndexError:
# Uncommon cases: zero length string and string containing only newlines
return i


def ansible_j2_concat(invals):
'''
Jinja's native concat is just a list join. The list comes from result of
templating all the nodes. If the return type is expected to be
native, the list is artificial and we should return just the first
element.
'''
invals = [x for x in invals]
if isinstance(invals, list):
if len(invals) == 1:
# break out the single value
invals = invals[0]
elif len(invals) > 1:
# cast to unicode and join
invals = u''.join([u'%s' % x for x in invals])
return invals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could make this more efficient. what types can be given here? Since the original jinja2 code is u"".join(invals) it seems like the minimum assumptions we can make are that invals is an iterable and that all the elements can convert to a text type without unicode exceptions. So maybe something like this:

# We can omit this block if invals is never a string type
if isinstance(invals, six.text_type):
    return invals
elif isinstance(invals, six.binary_type):
    return to_text(invals, errors='surrogate_or_strict')

try:
    if len(invals) == 1:
        # Break out the single value
        return invals[0]
except TypeError:
    # Iterable (by contract) but not a Sequence
    pass

return u''.join(invals)
  • This keeps us from instantiating two lists.
  • This should be faster even than the jinja2 concat if strings (not inside of a container) are allowed as invals.
  • If we don't have to deal with strings, then we can get rid of that block and reduce the number of conditionals.

I can adapt that further if we need to make less assumptions (or if we can make more of them).

Copy link
Contributor

@abadger abadger Apr 27, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, jctanner says that this is always a generator. So this is a better implementation:

from itertools import chain
start = []
for val in invals:
    start.append(val)
    if len(start) > 1:
        break
else:
    if start:
        return start[0]
    return []      
return u''.join(chain(start, invals))

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Profiling and jctanner's version is faster. So the only things I can see to do differently:

  • remove "if isinstance(invals, list):". after the list comprehension, invals should always be a list.
  • u'%s' % x for x in invals note that this isn't 100% safe (for instance b'café' will raise a unicode error) but I assume you took it from some part of jinja2 code so we aren't making a new assumption then.



class AnsibleCodeGenerator(CodeGenerator):
'''
A custom gnerator, which avoids injecting to_string() calls around
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

generator, not gnerator

the interal code jinja uses to render templates.
'''

def visit_Output(self, node, frame):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Camel case and underscores? :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sometimes you have to allow for all kinds ...

# if we have a known extends statement, we don't output anything
# if we are in a require_output_check section

if self.has_known_extends and frame.require_output_check:
return

allow_constant_finalize = True
if self.environment.finalize:
func = self.environment.finalize
if getattr(func, 'contextfunction', False) or \
getattr(func, 'evalcontextfunction', False):
allow_constant_finalize = False
elif getattr(func, 'environmentfunction', False):
finalize = lambda x: text_type(
self.environment.finalize(self.environment, x))
else:
finalize = lambda x: text_type(self.environment.finalize(x))
else:
finalize = text_type

# if we are inside a frame that requires output checking, we do so
outdent_later = False
if frame.require_output_check:
self.writeline('if parent_template is None:')
self.indent()
outdent_later = True

# try to evaluate as many chunks as possible into a static
# string at compile time.
body = []
for child in node.nodes:
try:
if not allow_constant_finalize:
raise nodes.Impossible()
const = child.as_const(frame.eval_ctx)
except nodes.Impossible:
body.append(child)
continue
# the frame can't be volatile here, becaus otherwise the
# as_const() function would raise an Impossible exception
# at that point.
try:
if frame.eval_ctx.autoescape:
if hasattr(const, '__html__'):
const = const.__html__()
else:
const = escape(const)
const = finalize(const)
except Exception:
# if something goes wrong here we evaluate the node
# at runtime for easier debugging
body.append(child)
continue
if body and isinstance(body[-1], list):
body[-1].append(const)
else:
body.append([const])

# if we have less than 3 nodes or a buffer we yield or extend/append
if len(body) < 3 or frame.buffer is not None:
if frame.buffer is not None:
# for one item we append, for more we extend
if len(body) == 1:
self.writeline('%s.append(' % frame.buffer)
else:
self.writeline('%s.extend((' % frame.buffer)
self.indent()
for item in body:
if isinstance(item, list):
val = repr(j2_concat(item))
if frame.buffer is None:
self.writeline('yield ' + val)
else:
self.writeline(val + ',')
else:
if frame.buffer is None:
self.writeline('yield ', item)
else:
self.newline(item)
close = 0
if frame.eval_ctx.volatile:
self.write('(escape if context.eval_ctx.autoescape'
' else to_string)(')
close += 1
elif frame.eval_ctx.autoescape:
self.write('escape(')
close += 1
if self.environment.finalize is not None:
self.write('environment.finalize(')
if getattr(self.environment.finalize,
"contextfunction", False):
self.write('context, ')
close += 1
self.visit(item, frame)
if close > 0:
self.write(')' * close)
if frame.buffer is not None:
self.write(',')

if frame.buffer is not None:
# close the open parentheses
self.outdent()
self.writeline(len(body) == 1 and ')' or '))')

# otherwise we create a format string as this is faster in that case
else:
format = []
arguments = []
for item in body:
if isinstance(item, list):
format.append(j2_concat(item).replace('%', '%%'))
else:
format.append('%s')
arguments.append(item)
self.writeline('yield ')
self.write(repr(j2_concat(format)) + ' % (')
self.indent()
for argument in arguments:
self.newline(argument)
close = 0
if frame.eval_ctx.volatile:
self.write('(escape if context.eval_ctx.autoescape else'
' to_string)(')
close += 1
elif frame.eval_ctx.autoescape:
self.write('escape(')
close += 1
if self.environment.finalize is not None:
self.write('environment.finalize(')
if getattr(self.environment.finalize,
'contextfunction', False):
self.write('context, ')
elif getattr(self.environment.finalize,
'evalcontextfunction', False):
self.write('context.eval_ctx, ')
elif getattr(self.environment.finalize,
'environmentfunction', False):
self.write('environment, ')
close += 1
self.visit(argument, frame)
self.write(')' * close + ', ')
self.outdent()
self.writeline(')')

if outdent_later:
self.outdent()


class AnsibleContext(Context):
'''
A custom context, which intercepts resolve() calls and sets a flag
Expand Down Expand Up @@ -211,6 +392,7 @@ class AnsibleEnvironment(Environment):
'''
context_class = AnsibleContext
template_class = AnsibleJ2Template
code_generator_class = AnsibleCodeGenerator

class Templar:
'''
Expand Down Expand Up @@ -446,17 +628,18 @@ def template(self, variable, convert_bare=False, preserve_trailing_newlines=True

unsafe = hasattr(result, '__UNSAFE__')
if convert_data and not self._no_type_regex.match(variable):
# if this looks like a dictionary or list, convert it to such using the safe_eval method
if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \
result.startswith("[") or result in ("True", "False"):
eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True)
if eval_results[1] is None:
result = eval_results[0]
if unsafe:
result = wrap_var(result)
else:
# FIXME: if the safe_eval raised an error, should we do something with it?
pass
if isinstance(result, (string_types, text_type)):
# if this looks like a dictionary or list, convert it to such using the safe_eval method
if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \
result.startswith("[") or result in ("True", "False"):
eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True)
if eval_results[1] is None:
result = eval_results[0]
if unsafe:
result = wrap_var(result)
else:
# FIXME: if the safe_eval raised an error, should we do something with it?
pass

# we only cache in the case where we have a single variable
# name, to make sure we're not putting things which may otherwise
Expand Down Expand Up @@ -508,6 +691,8 @@ def is_template(self, data):
return True
except:
return False
if isinstance(new, StrictUndefined):
return True
return (new != data)
elif isinstance(data, (list, tuple)):
for v in data:
Expand Down Expand Up @@ -660,7 +845,7 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=
rf = t.root_render_func(new_context)

try:
res = j2_concat(rf)
res = ansible_j2_concat(rf)
if new_context.unsafe:
res = wrap_var(res)
except TypeError as te:
Expand All @@ -672,7 +857,7 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=
display.debug("failing because of a type error, template data is: %s" % to_native(data))
raise AnsibleError("Unexpected templating type error occurred on (%s): %s" % (to_native(data),to_native(te)))

if preserve_trailing_newlines:
if preserve_trailing_newlines and isinstance(res, (string_types, text_type)):
# The low level calls above do not preserve the newline
# characters at the end of the input data, so we use the
# calculate the difference in newlines and append them
Expand All @@ -686,8 +871,14 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes=
# initializing self.environment and remove a single trailing
# newline here if preserve_newlines is False.
res_newlines = _count_newlines_from_end(res)
if data_newlines > res_newlines:
if data_newlines is not None and data_newlines > res_newlines:
res += self.environment.newline_sequence * (data_newlines - res_newlines)

# backwards compatibility for the debug module
if isinstance(res, StrictUndefined):
errmsg = '%s is undefined' % data
raise AnsibleUndefinedVariable(errmsg)

return res
except (UndefinedError, AnsibleUndefinedVariable) as e:
if fail_on_undefined:
Expand Down
4 changes: 2 additions & 2 deletions test/units/template/test_templar.py
Expand Up @@ -319,7 +319,7 @@ def test_lookup_list(self):

def test_lookup_jinja_undefined(self):
self.assertRaisesRegexp(AnsibleUndefinedVariable,
"'an_undefined_jinja_var' is undefined",
"{{ an_undefined_jinja_var }} is undefined",
self.templar._lookup,
'list', '{{ an_undefined_jinja_var }}')

Expand Down Expand Up @@ -353,7 +353,7 @@ def test_lookup_jinja_list_wantlist(self):

def test_lookup_jinja_list_wantlist_undefined(self):
self.assertRaisesRegexp(AnsibleUndefinedVariable,
"'some_undefined_var' is undefined",
"{{ some_undefined_var }} is undefined",
self.templar._lookup,
'list',
'{{ some_undefined_var }}',
Expand Down