diff --git a/airflow/models/abstractoperator.py b/airflow/models/abstractoperator.py index a6a0ce021dc21..7f3b3397add50 100644 --- a/airflow/models/abstractoperator.py +++ b/airflow/models/abstractoperator.py @@ -373,15 +373,15 @@ def render_template_fields( self, context: Context, jinja_env: jinja2.Environment | None = None, - ) -> BaseOperator | None: + ) -> None: """Template all attributes listed in template_fields. If the operator is mapped, this should return the unmapped, fully rendered, and map-expanded operator. The mapped operator should not be - modified. + modified. However, ``context`` will be modified in-place to reference + the unmapped operator for template rendering. - If the operator is not mapped, this should modify the operator in-place - and return either *None* (for backwards compatibility) or *self*. + If the operator is not mapped, this should modify the operator in-place. """ raise NotImplementedError() diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py index 74c1b064ab168..640f4cac2dd0e 100644 --- a/airflow/models/baseoperator.py +++ b/airflow/models/baseoperator.py @@ -1179,7 +1179,7 @@ def render_template_fields( self, context: Context, jinja_env: jinja2.Environment | None = None, - ) -> BaseOperator | None: + ) -> None: """Template all attributes listed in template_fields. This mutates the attributes in-place and is irreversible. @@ -1190,7 +1190,6 @@ def render_template_fields( if not jinja_env: jinja_env = self.get_template_env() self._do_render_template_fields(self, self.template_fields, context, jinja_env, set()) - return self @provide_session def clear( diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index e307726cd6272..8d1c3c45591d1 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -58,7 +58,7 @@ from airflow.ti_deps.deps.base_ti_dep import BaseTIDep from airflow.ti_deps.deps.mapped_task_expanded import MappedTaskIsExpanded from airflow.typing_compat import Literal -from airflow.utils.context import Context +from airflow.utils.context import Context, context_update_for_unmapped from airflow.utils.helpers import is_container from airflow.utils.operator_resources import Resources from airflow.utils.state import State, TaskInstanceState @@ -748,7 +748,7 @@ def render_template_fields( self, context: Context, jinja_env: jinja2.Environment | None = None, - ) -> BaseOperator | None: + ) -> None: if not jinja_env: jinja_env = self.get_template_env() @@ -761,6 +761,8 @@ def render_template_fields( mapped_kwargs, seen_oids = self._expand_mapped_kwargs(context, session) unmapped_task = self.unmap(mapped_kwargs) + context_update_for_unmapped(context, unmapped_task) + self._do_render_template_fields( parent=unmapped_task, template_fields=self.template_fields, @@ -769,4 +771,3 @@ def render_template_fields( seen_oids=seen_oids, session=session, ) - return unmapped_task diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 52562e72ad312..190542063c071 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -2185,10 +2185,14 @@ def render_templates(self, context: Context | None = None) -> Operator: """ if not context: context = self.get_template_context() - rendered_task = self.task.render_template_fields(context) - if rendered_task is None: # Compatibility -- custom renderer, assume unmapped. - return self.task - original_task, self.task = self.task, rendered_task + original_task = self.task + + # If self.task is mapped, this call replaces self.task to point to the + # unmapped BaseOperator created by this function! This is because the + # MappedOperator is useless for template rendering, and we need to be + # able to access the unmapped task instead. + original_task.render_template_fields(context) + return original_task def render_k8s_pod_yaml(self) -> dict | None: diff --git a/airflow/utils/context.py b/airflow/utils/context.py index 6e29f97cbcdd4..f356e7495f828 100644 --- a/airflow/utils/context.py +++ b/airflow/utils/context.py @@ -22,13 +22,26 @@ import copy import functools import warnings -from typing import Any, Container, ItemsView, Iterator, KeysView, Mapping, MutableMapping, ValuesView +from typing import ( + TYPE_CHECKING, + Any, + Container, + ItemsView, + Iterator, + KeysView, + Mapping, + MutableMapping, + ValuesView, +) import lazy_object_proxy from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.types import NOTSET +if TYPE_CHECKING: + from airflow.models.baseoperator import BaseOperator + # NOTE: Please keep this in sync with Context in airflow/utils/context.pyi. KNOWN_CONTEXT_KEYS = { "conf", @@ -291,3 +304,15 @@ def _create_value(k: str, v: Any) -> Any: return lazy_object_proxy.Proxy(factory) return {k: _create_value(k, v) for k, v in source._context.items()} + + +def context_update_for_unmapped(context: Context, task: BaseOperator) -> None: + """Update context after task unmapping. + + Since ``get_template_context()`` is called before unmapping, the context + contains information about the mapped task. We need to do some in-place + updates to ensure the template context reflects the unmapped task instead. + + :meta private: + """ + context["task"] = context["ti"].task = task diff --git a/tests/decorators/test_python.py b/tests/decorators/test_python.py index 1cb18c0e75122..3dc151668231e 100644 --- a/tests/decorators/test_python.py +++ b/tests/decorators/test_python.py @@ -756,11 +756,13 @@ def fn(arg1, arg2): mapped_ti: TaskInstance = dr.get_task_instance(mapped.operator.task_id, session=session) mapped_ti.map_index = 0 - op = mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert op - assert op.op_kwargs['arg1'] == "{{ ds }}" - assert op.op_kwargs['arg2'] == "fn" + assert mapped_ti.task.is_mapped + mapped.operator.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert not mapped_ti.task.is_mapped + + assert mapped_ti.task.op_kwargs['arg1'] == "{{ ds }}" + assert mapped_ti.task.op_kwargs['arg2'] == "fn" def test_task_decorator_has_wrapped_attr(): diff --git a/tests/models/test_mappedoperator.py b/tests/models/test_mappedoperator.py index 8c2f5107d2fa7..1faf42be3d85f 100644 --- a/tests/models/test_mappedoperator.py +++ b/tests/models/test_mappedoperator.py @@ -305,12 +305,14 @@ def __init__(self, value, arg1, **kwargs): mapped_ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) mapped_ti.map_index = 0 - op = mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) - assert isinstance(op, MyOperator) - assert op.value == "{{ ds }}", "Should not be templated!" - assert op.arg1 == "{{ ds }}", "Should not be templated!" - assert op.arg2 == "a" + assert isinstance(mapped_ti.task, MappedOperator) + mapped.render_template_fields(context=mapped_ti.get_template_context(session=session)) + assert isinstance(mapped_ti.task, MyOperator) + + assert mapped_ti.task.value == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.arg1 == "{{ ds }}", "Should not be templated!" + assert mapped_ti.task.arg2 == "a" def test_mapped_render_nested_template_fields(dag_maker, session): @@ -430,10 +432,11 @@ def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, ses ti: TaskInstance = dr.get_task_instance(mapped.task_id, session=session) ti.refresh_from_task(mapped) ti.map_index = map_index - op = mapped.render_template_fields(context=ti.get_template_context(session=session)) - assert isinstance(op, MockOperator) - assert op.arg1 == expected - assert op.arg2 == "a" + assert isinstance(ti.task, MappedOperator) + mapped.render_template_fields(context=ti.get_template_context(session=session)) + assert isinstance(ti.task, MockOperator) + assert ti.task.arg1 == expected + assert ti.task.arg2 == "a" def test_xcomarg_property_of_mapped_operator(dag_maker):