diff --git a/lib/ansible/template/__init__.py b/lib/ansible/template/__init__.py index 0c6ff1c9fe538f..82d41c076de511 100644 --- a/lib/ansible/template/__init__.py +++ b/lib/ansible/template/__init__.py @@ -36,6 +36,9 @@ from sha import sha as sha1 from jinja2 import Environment +from jinja2 import escape +from jinja2 import nodes +from jinja2.compiler import CodeGenerator from jinja2.loaders import FileSystemLoader from jinja2.exceptions import TemplateSyntaxError, UndefinedError from jinja2.utils import concat as j2_concat @@ -45,12 +48,14 @@ from ansible.errors import AnsibleError, AnsibleFilterError, AnsibleUndefinedVariable from ansible.module_utils.six import string_types, text_type from ansible.module_utils._text import to_native, to_text, to_bytes +from ansible.parsing.yaml.objects import AnsibleUnicode from ansible.plugins import filter_loader, lookup_loader, test_loader from ansible.template.safe_eval import safe_eval from ansible.template.template import AnsibleJ2Template from ansible.template.vars import AnsibleJ2Vars from ansible.vars.unsafe_proxy import UnsafeProxy, wrap_var + try: from __main__ import display except ImportError: @@ -152,10 +157,186 @@ def _count_newlines_from_end(in_str): while in_str[j] == '\n': j -= 1 return i - 1 - j + except TypeError: + return 0 except IndexError: # Uncommon cases: zero length string and string containing only newlines return i + +def ansible_j2_concat(invals): + ''' + Jinja's native concat is just a list join. The list comes from result of + templating all the nodes. If the return type is expected to be + native, the list is artificial and we should return just the first + element. + ''' + invals = [x for x in invals] + if isinstance(invals, list): + if len(invals) == 1: + # break out the single value + invals = invals[0] + elif len(invals) > 1: + # cast to unicode and join + invals = u''.join([u'%s' % x for x in invals]) + return invals + + +class AnsibleCodeGenerator(CodeGenerator): + ''' + A custom gnerator, which avoids injecting to_string() calls around + the interal code jinja uses to render templates. + ''' + + def visit_Output(self, node, frame): + # if we have a known extends statement, we don't output anything + # if we are in a require_output_check section + + if self.has_known_extends and frame.require_output_check: + return + + allow_constant_finalize = True + if self.environment.finalize: + func = self.environment.finalize + if getattr(func, 'contextfunction', False) or \ + getattr(func, 'evalcontextfunction', False): + allow_constant_finalize = False + elif getattr(func, 'environmentfunction', False): + finalize = lambda x: text_type( + self.environment.finalize(self.environment, x)) + else: + finalize = lambda x: text_type(self.environment.finalize(x)) + else: + finalize = text_type + + # if we are inside a frame that requires output checking, we do so + outdent_later = False + if frame.require_output_check: + self.writeline('if parent_template is None:') + self.indent() + outdent_later = True + + # try to evaluate as many chunks as possible into a static + # string at compile time. + body = [] + for child in node.nodes: + try: + if not allow_constant_finalize: + raise nodes.Impossible() + const = child.as_const(frame.eval_ctx) + except nodes.Impossible: + body.append(child) + continue + # the frame can't be volatile here, becaus otherwise the + # as_const() function would raise an Impossible exception + # at that point. + try: + if frame.eval_ctx.autoescape: + if hasattr(const, '__html__'): + const = const.__html__() + else: + const = escape(const) + const = finalize(const) + except Exception: + # if something goes wrong here we evaluate the node + # at runtime for easier debugging + body.append(child) + continue + if body and isinstance(body[-1], list): + body[-1].append(const) + else: + body.append([const]) + + # if we have less than 3 nodes or a buffer we yield or extend/append + if len(body) < 3 or frame.buffer is not None: + if frame.buffer is not None: + # for one item we append, for more we extend + if len(body) == 1: + self.writeline('%s.append(' % frame.buffer) + else: + self.writeline('%s.extend((' % frame.buffer) + self.indent() + for item in body: + if isinstance(item, list): + val = repr(j2_concat(item)) + if frame.buffer is None: + self.writeline('yield ' + val) + else: + self.writeline(val + ',') + else: + if frame.buffer is None: + self.writeline('yield ', item) + else: + self.newline(item) + close = 0 + if frame.eval_ctx.volatile: + self.write('(escape if context.eval_ctx.autoescape' + ' else to_string)(') + close += 1 + elif frame.eval_ctx.autoescape: + self.write('escape(') + close += 1 + if self.environment.finalize is not None: + self.write('environment.finalize(') + if getattr(self.environment.finalize, + "contextfunction", False): + self.write('context, ') + close += 1 + self.visit(item, frame) + if close > 0: + self.write(')' * close) + if frame.buffer is not None: + self.write(',') + + if frame.buffer is not None: + # close the open parentheses + self.outdent() + self.writeline(len(body) == 1 and ')' or '))') + + # otherwise we create a format string as this is faster in that case + else: + format = [] + arguments = [] + for item in body: + if isinstance(item, list): + format.append(j2_concat(item).replace('%', '%%')) + else: + format.append('%s') + arguments.append(item) + self.writeline('yield ') + self.write(repr(j2_concat(format)) + ' % (') + self.indent() + for argument in arguments: + self.newline(argument) + close = 0 + if frame.eval_ctx.volatile: + self.write('(escape if context.eval_ctx.autoescape else' + ' to_string)(') + close += 1 + elif frame.eval_ctx.autoescape: + self.write('escape(') + close += 1 + if self.environment.finalize is not None: + self.write('environment.finalize(') + if getattr(self.environment.finalize, + 'contextfunction', False): + self.write('context, ') + elif getattr(self.environment.finalize, + 'evalcontextfunction', False): + self.write('context.eval_ctx, ') + elif getattr(self.environment.finalize, + 'environmentfunction', False): + self.write('environment, ') + close += 1 + self.visit(argument, frame) + self.write(')' * close + ', ') + self.outdent() + self.writeline(')') + + if outdent_later: + self.outdent() + + class AnsibleContext(Context): ''' A custom context, which intercepts resolve() calls and sets a flag @@ -211,6 +392,7 @@ class AnsibleEnvironment(Environment): ''' context_class = AnsibleContext template_class = AnsibleJ2Template + code_generator_class = AnsibleCodeGenerator class Templar: ''' @@ -446,17 +628,18 @@ def template(self, variable, convert_bare=False, preserve_trailing_newlines=True unsafe = hasattr(result, '__UNSAFE__') if convert_data and not self._no_type_regex.match(variable): - # if this looks like a dictionary or list, convert it to such using the safe_eval method - if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \ - result.startswith("[") or result in ("True", "False"): - eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True) - if eval_results[1] is None: - result = eval_results[0] - if unsafe: - result = wrap_var(result) - else: - # FIXME: if the safe_eval raised an error, should we do something with it? - pass + if isinstance(result, (string_types, text_type)): + # if this looks like a dictionary or list, convert it to such using the safe_eval method + if (result.startswith("{") and not result.startswith(self.environment.variable_start_string)) or \ + result.startswith("[") or result in ("True", "False"): + eval_results = safe_eval(result, locals=self._available_variables, include_exceptions=True) + if eval_results[1] is None: + result = eval_results[0] + if unsafe: + result = wrap_var(result) + else: + # FIXME: if the safe_eval raised an error, should we do something with it? + pass # we only cache in the case where we have a single variable # name, to make sure we're not putting things which may otherwise @@ -508,6 +691,8 @@ def is_template(self, data): return True except: return False + if isinstance(new, StrictUndefined): + return True return (new != data) elif isinstance(data, (list, tuple)): for v in data: @@ -660,7 +845,7 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes= rf = t.root_render_func(new_context) try: - res = j2_concat(rf) + res = ansible_j2_concat(rf) if new_context.unsafe: res = wrap_var(res) except TypeError as te: @@ -672,7 +857,7 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes= display.debug("failing because of a type error, template data is: %s" % to_native(data)) raise AnsibleError("Unexpected templating type error occurred on (%s): %s" % (to_native(data),to_native(te))) - if preserve_trailing_newlines: + if preserve_trailing_newlines and isinstance(res, (string_types, text_type)): # The low level calls above do not preserve the newline # characters at the end of the input data, so we use the # calculate the difference in newlines and append them @@ -686,8 +871,14 @@ def do_template(self, data, preserve_trailing_newlines=True, escape_backslashes= # initializing self.environment and remove a single trailing # newline here if preserve_newlines is False. res_newlines = _count_newlines_from_end(res) - if data_newlines > res_newlines: + if data_newlines is not None and data_newlines > res_newlines: res += self.environment.newline_sequence * (data_newlines - res_newlines) + + # backwards compatibility for the debug module + if isinstance(res, StrictUndefined): + errmsg = '%s is undefined' % data + raise AnsibleUndefinedVariable(errmsg) + return res except (UndefinedError, AnsibleUndefinedVariable) as e: if fail_on_undefined: diff --git a/test/units/template/test_templar.py b/test/units/template/test_templar.py index cf49e75a022155..29c6ebd7288df5 100644 --- a/test/units/template/test_templar.py +++ b/test/units/template/test_templar.py @@ -319,7 +319,7 @@ def test_lookup_list(self): def test_lookup_jinja_undefined(self): self.assertRaisesRegexp(AnsibleUndefinedVariable, - "'an_undefined_jinja_var' is undefined", + "{{ an_undefined_jinja_var }} is undefined", self.templar._lookup, 'list', '{{ an_undefined_jinja_var }}') @@ -353,7 +353,7 @@ def test_lookup_jinja_list_wantlist(self): def test_lookup_jinja_list_wantlist_undefined(self): self.assertRaisesRegexp(AnsibleUndefinedVariable, - "'some_undefined_var' is undefined", + "{{ some_undefined_var }} is undefined", self.templar._lookup, 'list', '{{ some_undefined_var }}',