diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 92737e8f63f46..80f60d32f7229 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -337,8 +337,12 @@ def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg: return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: - if not isinstance(kwargs, XComArg): - raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}") + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index f023cbaff8746..7cda3f07f3801 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -18,7 +18,6 @@ from __future__ import annotations -import collections import collections.abc import functools import operator @@ -74,7 +73,7 @@ class DictOfListsExpandInput(NamedTuple): value: dict[str, OperatorExpandArgument] - def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: """Generate kwargs with values available on parse-time.""" from airflow.models.xcom_arg import XComArg @@ -83,7 +82,7 @@ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: def get_parse_time_mapped_ti_count(self) -> int | None: if not self.value: return 0 - literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()] + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] if len(literal_values) != len(self.value): return None # None-literal type encountered, so give up. return functools.reduce(operator.mul, literal_values, 1) @@ -149,7 +148,7 @@ def _find_index_for_this_field(index: int) -> int: def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()} - literal_keys = {k for k, _ in self.iter_parse_time_resolved_kwargs()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} return data, resolved_oids @@ -166,15 +165,16 @@ class ListOfDictsExpandInput(NamedTuple): This is created from ``expand_kwargs(xcom_arg)``. """ - value: XComArg - - def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: - return () + value: OperatorExpandKwargsArgument def get_parse_time_mapped_ti_count(self) -> int | None: + if isinstance(self.value, collections.abc.Sized): + return len(self.value) return None def get_total_map_length(self, run_id: str, *, session: Session) -> int: + if isinstance(self.value, collections.abc.Sized): + return len(self.value) length = self.value.get_task_map_length(run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) @@ -184,12 +184,21 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any] map_index = context["ti"].map_index if map_index < 0: raise RuntimeError("can't resolve task-mapping argument without expanding") - mappings = self.value.resolve(context, session) - if not isinstance(mappings, collections.abc.Sequence): - raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") - mapping = mappings[map_index] + + mapping: Any + if isinstance(self.value, collections.abc.Sized): + mapping = self.value[map_index] + if not isinstance(mapping, collections.abc.Mapping): + mapping = mapping.resolve(context, session) + else: + mappings = self.value.resolve(context, session) + if not isinstance(mappings, collections.abc.Sequence): + raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") + mapping = mappings[map_index] + if not isinstance(mapping, collections.abc.Mapping): raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") + for key in mapping: if not isinstance(key, str): raise ValueError( diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 72b18ee77953c..d67c1fc2f92d7 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -196,8 +196,12 @@ def expand(self, **mapped_kwargs: OperatorExpandArgument) -> "MappedOperator": def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> "MappedOperator": from airflow.models.xcom_arg import XComArg - if not isinstance(kwargs, XComArg): - raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}") + if isinstance(kwargs, collections.abc.Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, collections.abc.Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator": diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 13fb47fdd962a..bd3aabf7d3a27 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -1824,21 +1824,48 @@ def test_schedule_tis_map_index(dag_maker, session): def test_mapped_expand_kwargs(dag_maker): - with dag_maker() as dag: + with dag_maker(): @task - def task_1(): - return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}] + def task_0(): + return {"arg1": "a", "arg2": "b"} - MockOperator.partial(task_id="task_2").expand_kwargs(task_1()) + @task + def task_1(args_0): + return [args_0, {"arg1": "y"}, {"arg2": "z"}] - dr: DagRun = dag_maker.create_dagrun() - assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1 + args_0 = task_0() + args_list = task_1(args_0=args_0) - ti1 = dr.get_task_instance("task_1") - ti1.refresh_from_task(dag.get_task("task_1")) - ti1.run() + MockOperator.partial(task_id="task_2").expand_kwargs(args_list) + MockOperator.partial(task_id="task_3").expand_kwargs( + [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}], + ) + MockOperator.partial(task_id="task_4").expand_kwargs([args_0, {"arg1": "y"}, {"arg2": "z"}]) - dr.task_instance_scheduling_decisions() - ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"} - assert ti_states == {0: None, 1: None, 2: None} + dr: DagRun = dag_maker.create_dagrun() + tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances} + + # task_2 is not expanded yet since it relies on one single XCom input. + # task_3 and task_4 received a pure literal and can expanded right away. + # task_4 relies on an XCom input in the list, but can also be expanded. + assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_2") == [-1] + assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_3") == [0, 1, 2] + assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_4") == [0, 1, 2] + + tis[("task_0", -1)].run() + tis[("task_1", -1)].run() + + # With the upstreams available, everything should get expanded now. + decision = dr.task_instance_scheduling_decisions() + assert {(ti.task_id, ti.map_index): ti.state for ti in decision.schedulable_tis} == { + ("task_2", 0): None, + ("task_2", 1): None, + ("task_2", 2): None, + ("task_3", 0): None, + ("task_3", 1): None, + ("task_3", 2): None, + ("task_4", 0): None, + ("task_4", 1): None, + ("task_4", 2): None, + }