diff --git a/airflow/decorators/base.py b/airflow/decorators/base.py index 3631187ac0c49..80f60d32f7229 100644 --- a/airflow/decorators/base.py +++ b/airflow/decorators/base.py @@ -18,7 +18,6 @@ import inspect import re from typing import ( - TYPE_CHECKING, Any, Callable, ClassVar, @@ -57,6 +56,8 @@ DictOfListsExpandInput, ExpandInput, ListOfDictsExpandInput, + OperatorExpandArgument, + OperatorExpandKwargsArgument, ) from airflow.models.mappedoperator import ( MappedOperator, @@ -73,9 +74,6 @@ from airflow.utils.task_group import TaskGroup, TaskGroupContext from airflow.utils.types import NOTSET -if TYPE_CHECKING: - from airflow.models.mappedoperator import Mappable - def validate_python_callable(python_callable: Any) -> None: """ @@ -329,7 +327,7 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]): names = ", ".join(repr(n) for n in kwargs_left) raise TypeError(f"{func}() got unexpected keyword arguments {names}") - def expand(self, **map_kwargs: "Mappable") -> XComArg: + def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg: if not map_kwargs: raise TypeError("no arguments to expand against") self._validate_arg_names("expand", map_kwargs) @@ -338,9 +336,13 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg: # to False to skip the checks on execution. 
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False) - def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg: - if not isinstance(kwargs, XComArg): - raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}") + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: + if isinstance(kwargs, Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}[{type(item).__name__}]") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg: @@ -500,10 +502,10 @@ def __wrapped__(self) -> Callable[FParams, FReturn]: def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]": ... - def expand(self, **kwargs: "Mappable") -> XComArg: + def expand(self, **kwargs: OperatorExpandArgument) -> XComArg: ... - def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg: + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg: ... def override(self, **kwargs: Any) -> "Task[FParams, FReturn]": diff --git a/airflow/decorators/task_group.py b/airflow/decorators/task_group.py index 59adad95ea17e..957317578e43d 100644 --- a/airflow/decorators/task_group.py +++ b/airflow/decorators/task_group.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from airflow.models.dag import DAG - from airflow.models.mappedoperator import Mappable + from airflow.models.mappedoperator import OperatorExpandArgument F = TypeVar("F", bound=Callable) R = TypeVar("R") @@ -100,7 +100,7 @@ class Group(Generic[F]): function: F # Return value should match F's return type, but that's impossible to declare. 
- def expand(self, **kwargs: "Mappable") -> Any: + def expand(self, **kwargs: "OperatorExpandArgument") -> Any: ... def partial(self, **kwargs: Any) -> "Group[F]": diff --git a/airflow/models/expandinput.py b/airflow/models/expandinput.py index fd77e358b97fa..7cda3f07f3801 100644 --- a/airflow/models/expandinput.py +++ b/airflow/models/expandinput.py @@ -18,11 +18,10 @@ from __future__ import annotations -import collections import collections.abc import functools import operator -from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union +from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union from airflow.compat.functools import cache from airflow.utils.context import Context @@ -34,9 +33,13 @@ ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"] -# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any -# mapping since we need the value to be ordered). -Mappable = Union["XComArg", Sequence, dict] +# Each keyword argument to expand() can be an XComArg, sequence, or dict (not +# any mapping since we need the value to be ordered). +OperatorExpandArgument = Union["XComArg", Sequence, Dict[str, Any]] + +# The single argument of expand_kwargs() can be an XComArg, or a list with each +# element being either an XComArg or a dict. +OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]] # For isinstance() check. @@ -68,13 +71,9 @@ class DictOfListsExpandInput(NamedTuple): This is created from ``expand(**kwargs)``. 
""" - value: dict[str, Mappable] - - def get_unresolved_kwargs(self) -> dict[str, Any]: - """Get the kwargs dict that can be inferred without resolving.""" - return self.value + value: dict[str, OperatorExpandArgument] - def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: + def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: """Generate kwargs with values available on parse-time.""" from airflow.models.xcom_arg import XComArg @@ -83,7 +82,7 @@ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: def get_parse_time_mapped_ti_count(self) -> int | None: if not self.value: return 0 - literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()] + literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()] if len(literal_values) != len(self.value): return None # None-literal type encountered, so give up. return functools.reduce(operator.mul, literal_values, 1) @@ -149,7 +148,7 @@ def _find_index_for_this_field(index: int) -> int: def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]: data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()} - literal_keys = {k for k, _ in self.iter_parse_time_resolved_kwargs()} + literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()} resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys} return data, resolved_oids @@ -166,23 +165,16 @@ class ListOfDictsExpandInput(NamedTuple): This is created from ``expand_kwargs(xcom_arg)``. """ - value: XComArg - - def get_unresolved_kwargs(self) -> dict[str, Any]: - """Get the kwargs dict that can be inferred without resolving. - - Since the list-of-dicts case relies entirely on run-time XCom, there's - no kwargs structure available, so this just returns an empty dict. 
- """ - return {} - - def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]: - return () + value: OperatorExpandKwargsArgument def get_parse_time_mapped_ti_count(self) -> int | None: + if isinstance(self.value, collections.abc.Sized): + return len(self.value) return None def get_total_map_length(self, run_id: str, *, session: Session) -> int: + if isinstance(self.value, collections.abc.Sized): + return len(self.value) length = self.value.get_task_map_length(run_id, session=session) if length is None: raise NotFullyPopulated({"expand_kwargs() argument"}) @@ -192,12 +184,21 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any] map_index = context["ti"].map_index if map_index < 0: raise RuntimeError("can't resolve task-mapping argument without expanding") - mappings = self.value.resolve(context, session) - if not isinstance(mappings, collections.abc.Sequence): - raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") - mapping = mappings[map_index] + + mapping: Any + if isinstance(self.value, collections.abc.Sized): + mapping = self.value[map_index] + if not isinstance(mapping, collections.abc.Mapping): + mapping = mapping.resolve(context, session) + else: + mappings = self.value.resolve(context, session) + if not isinstance(mappings, collections.abc.Sequence): + raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}") + mapping = mappings[map_index] + if not isinstance(mapping, collections.abc.Mapping): raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]") + for key in mapping: if not isinstance(key, str): raise ValueError( diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 82d7eea870161..d1214fcc027a3 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -64,8 +64,9 @@ DictOfListsExpandInput, ExpandInput, ListOfDictsExpandInput, - Mappable, 
NotFullyPopulated, + OperatorExpandArgument, + OperatorExpandKwargsArgument, get_mappable_types, ) from airflow.models.pool import Pool @@ -184,7 +185,7 @@ def __del__(self): task_id = f"at {hex(id(self))}" warnings.warn(f"Task {task_id} was never mapped!") - def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": + def expand(self, **mapped_kwargs: OperatorExpandArgument) -> "MappedOperator": if not mapped_kwargs: raise TypeError("no arguments to expand against") validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs) @@ -193,11 +194,15 @@ def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": # to False to skip the checks on execution. return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False) - def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator": + def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> "MappedOperator": from airflow.models.xcom_arg import XComArg - if not isinstance(kwargs, XComArg): - raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}") + if isinstance(kwargs, collections.abc.Sequence): + for item in kwargs: + if not isinstance(item, (XComArg, collections.abc.Mapping)): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}[{type(item).__name__}]") + elif not isinstance(kwargs, XComArg): + raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}") return self._expand(ListOfDictsExpandInput(kwargs), strict=strict) def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator": diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 13fb47fdd962a..bd3aabf7d3a27 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -1824,21 +1824,48 @@ def test_schedule_tis_map_index(dag_maker, session): def test_mapped_expand_kwargs(dag_maker): - with dag_maker() as dag: + with dag_maker(): @task - def task_1(): - return [{"arg1": "a", 
"arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}] + def task_0(): + return {"arg1": "a", "arg2": "b"} - MockOperator.partial(task_id="task_2").expand_kwargs(task_1()) + @task + def task_1(args_0): + return [args_0, {"arg1": "y"}, {"arg2": "z"}] - dr: DagRun = dag_maker.create_dagrun() - assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1 + args_0 = task_0() + args_list = task_1(args_0=args_0) - ti1 = dr.get_task_instance("task_1") - ti1.refresh_from_task(dag.get_task("task_1")) - ti1.run() + MockOperator.partial(task_id="task_2").expand_kwargs(args_list) + MockOperator.partial(task_id="task_3").expand_kwargs( + [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}], + ) + MockOperator.partial(task_id="task_4").expand_kwargs([args_0, {"arg1": "y"}, {"arg2": "z"}]) - dr.task_instance_scheduling_decisions() - ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"} - assert ti_states == {0: None, 1: None, 2: None} + dr: DagRun = dag_maker.create_dagrun() + tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances} + + # task_2 is not expanded yet since it relies on one single XCom input. + # task_3 received a pure literal and can be expanded right away. + # task_4 relies on an XCom input in the list, but can also be expanded. + assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_2") == [-1] + assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_3") == [0, 1, 2] + assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_4") == [0, 1, 2] + + tis[("task_0", -1)].run() + tis[("task_1", -1)].run() + + # With the upstreams available, everything should get expanded now. 
+ decision = dr.task_instance_scheduling_decisions() + assert {(ti.task_id, ti.map_index): ti.state for ti in decision.schedulable_tis} == { + ("task_2", 0): None, + ("task_2", 1): None, + ("task_2", 2): None, + ("task_3", 0): None, + ("task_3", 1): None, + ("task_3", 2): None, + ("task_4", 0): None, + ("task_4", 1): None, + ("task_4", 2): None, + }