Skip to content

Commit

Permalink
When rendering template, unmap task in context (#26702)
Browse files Browse the repository at this point in the history
(cherry picked from commit 5560a46)
  • Loading branch information
uranusjr authored and jedcunningham committed Sep 27, 2022
1 parent 131d8be commit e98b0e8
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 27 deletions.
8 changes: 4 additions & 4 deletions airflow/models/abstractoperator.py
Expand Up @@ -373,15 +373,15 @@ def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> BaseOperator | None:
) -> None:
"""Template all attributes listed in template_fields.
If the operator is mapped, this should return the unmapped, fully
rendered, and map-expanded operator. The mapped operator should not be
modified.
modified. However, ``context`` will be modified in-place to reference
the unmapped operator for template rendering.
If the operator is not mapped, this should modify the operator in-place
and return either *None* (for backwards compatibility) or *self*.
If the operator is not mapped, this should modify the operator in-place.
"""
raise NotImplementedError()

Expand Down
3 changes: 1 addition & 2 deletions airflow/models/baseoperator.py
Expand Up @@ -1179,7 +1179,7 @@ def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> BaseOperator | None:
) -> None:
"""Template all attributes listed in template_fields.
This mutates the attributes in-place and is irreversible.
Expand All @@ -1190,7 +1190,6 @@ def render_template_fields(
if not jinja_env:
jinja_env = self.get_template_env()
self._do_render_template_fields(self, self.template_fields, context, jinja_env, set())
return self

@provide_session
def clear(
Expand Down
7 changes: 4 additions & 3 deletions airflow/models/mappedoperator.py
Expand Up @@ -58,7 +58,7 @@
from airflow.ti_deps.deps.base_ti_dep import BaseTIDep
from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded
from airflow.typing_compat import Literal
from airflow.utils.context import Context
from airflow.utils.context import Context, context_update_for_unmapped
from airflow.utils.helpers import is_container
from airflow.utils.operator_resources import Resources
from airflow.utils.state import State, TaskInstanceState
Expand Down Expand Up @@ -748,7 +748,7 @@ def render_template_fields(
self,
context: Context,
jinja_env: jinja2.Environment | None = None,
) -> BaseOperator | None:
) -> None:
if not jinja_env:
jinja_env = self.get_template_env()

Expand All @@ -761,6 +761,8 @@ def render_template_fields(

mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session)
unmapped_task = self.unmap(mapped_kwargs)
context_update_for_unmapped(context, unmapped_task)

self._do_render_template_fields(
parent=unmapped_task,
template_fields=self.template_fields,
Expand All @@ -769,4 +771,3 @@ def render_template_fields(
seen_oids=seen_oids,
session=session,
)
return unmapped_task
12 changes: 8 additions & 4 deletions airflow/models/taskinstance.py
Expand Up @@ -2190,10 +2190,14 @@ def render_templates(self, context: Context | None = None) -> Operator:
"""
if not context:
context = self.get_template_context()
rendered_task = self.task.render_template_fields(context)
if rendered_task is None: # Compatibility -- custom renderer, assume unmapped.
return self.task
original_task, self.task = self.task, rendered_task
original_task = self.task

# If self.task is mapped, this call replaces self.task to point to the
# unmapped BaseOperator created by this function! This is because the
# MappedOperator is useless for template rendering, and we need to be
# able to access the unmapped task instead.
original_task.render_template_fields(context)

return original_task

def render_k8s_pod_yaml(self) -> dict | None:
Expand Down
27 changes: 26 additions & 1 deletion airflow/utils/context.py
Expand Up @@ -22,13 +22,26 @@
import copy
import functools
import warnings
from typing import Any, Container, ItemsView, Iterator, KeysView, Mapping, MutableMapping, ValuesView
from typing import (
TYPE_CHECKING,
Any,
Container,
ItemsView,
Iterator,
KeysView,
Mapping,
MutableMapping,
ValuesView,
)

import lazy_object_proxy

from airflow.exceptions import RemovedInAirflow3Warning
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.models.baseoperator import BaseOperator

# NOTE: Please keep this in sync with Context in airflow/utils/context.pyi.
KNOWN_CONTEXT_KEYS = {
"conf",
Expand Down Expand Up @@ -291,3 +304,15 @@ def _create_value(k: str, v: Any) -> Any:
return lazy_object_proxy.Proxy(factory)

return {k: _create_value(k, v) for k, v in source._context.items()}


def context_update_for_unmapped(context: Context, task: BaseOperator) -> None:
"""Update context after task unmapping.
Since ``get_template_context()`` is called before unmapping, the context
contains information about the mapped task. We need to do some in-place
updates to ensure the template context reflects the unmapped task instead.
:meta private:
"""
context["task"] = context["ti"].task = task
10 changes: 6 additions & 4 deletions tests/decorators/test_python.py
Expand Up @@ -756,11 +756,13 @@ def fn(arg1, arg2):

mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session)
mapped_ti.map_index = 0
op = mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert op

assert op.op_kwargs['arg1'] == "{{ ds }}"
assert op.op_kwargs['arg2'] == "fn"
assert mapped_ti.task.is_mapped
mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert not mapped_ti.task.is_mapped

assert mapped_ti.task.op_kwargs['arg1'] == "{{ ds }}"
assert mapped_ti.task.op_kwargs['arg2'] == "fn"


def test_task_decorator_has_wrapped_attr():
Expand Down
21 changes: 12 additions & 9 deletions tests/models/test_mappedoperator.py
Expand Up @@ -305,12 +305,14 @@ def __init__(self, value, arg1, **kwargs):

mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
mapped_ti.map_index = 0
op = mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(op, MyOperator)

assert op.value == "{{ ds }}", "Should not be templated!"
assert op.arg1 == "{{ ds }}", "Should not be templated!"
assert op.arg2 == "a"
assert isinstance(mapped_ti.task, MappedOperator)
mapped.render_template_fields(context=mapped_ti.get_template_context(session=session))
assert isinstance(mapped_ti.task, MyOperator)

assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!"
assert mapped_ti.task.arg2 == "a"


def test_mapped_render_nested_template_fields(dag_maker, session):
Expand Down Expand Up @@ -430,10 +432,11 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses
ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session)
ti.refresh_from_task(mapped)
ti.map_index = map_index
op = mapped.render_template_fields(context=ti.get_template_context(session=session))
assert isinstance(op, MockOperator)
assert op.arg1 == expected
assert op.arg2 == "a"
assert isinstance(ti.task, MappedOperator)
mapped.render_template_fields(context=ti.get_template_context(session=session))
assert isinstance(ti.task, MockOperator)
assert ti.task.arg1 == expected
assert ti.task.arg2 == "a"


def test_xcomarg_property_of_mapped_operator(dag_maker):
Expand Down

0 comments on commit e98b0e8

Please sign in to comment.