Skip to content

Commit

Permalink
Implement expand_kwargs() against a literal list
Browse files Browse the repository at this point in the history
Each item in the literal list may be a dict (operator kwargs) or an XComArg
(resolved to operator kwargs at runtime).
  • Loading branch information
uranusjr committed Aug 29, 2022
1 parent f483166 commit e28183d
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 28 deletions.
8 changes: 6 additions & 2 deletions airflow/decorators/base.py
Expand Up @@ -337,8 +337,12 @@ def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
if isinstance(kwargs, Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
Expand Down
33 changes: 21 additions & 12 deletions airflow/models/expandinput.py
Expand Up @@ -18,7 +18,6 @@

from __future__ import annotations

import collections
import collections.abc
import functools
import operator
Expand Down Expand Up @@ -74,7 +73,7 @@ class DictOfListsExpandInput(NamedTuple):

value: dict[str, OperatorExpandArgument]

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
from airflow.models.xcom_arg import XComArg

Expand All @@ -83,7 +82,7 @@ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
def get_parse_time_mapped_ti_count(self) -> int | None:
if not self.value:
return 0
literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()]
literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
if len(literal_values) != len(self.value):
return None # Non-literal type encountered, so give up.
return functools.reduce(operator.mul, literal_values, 1)
Expand Down Expand Up @@ -149,7 +148,7 @@ def _find_index_for_this_field(index: int) -> int:

def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
literal_keys = {k for k, _ in self.iter_parse_time_resolved_kwargs()}
literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
return data, resolved_oids

Expand All @@ -166,15 +165,16 @@ class ListOfDictsExpandInput(NamedTuple):
This is created from ``expand_kwargs(xcom_arg)`` or from ``expand_kwargs([...])``
with a literal list whose items are dicts or XComArgs.
"""

value: XComArg

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
return ()
value: OperatorExpandKwargsArgument

def get_parse_time_mapped_ti_count(self) -> int | None:
if isinstance(self.value, collections.abc.Sized):
return len(self.value)
return None

def get_total_map_length(self, run_id: str, *, session: Session) -> int:
if isinstance(self.value, collections.abc.Sized):
return len(self.value)
length = self.value.get_task_map_length(run_id, session=session)
if length is None:
raise NotFullyPopulated({"expand_kwargs() argument"})
Expand All @@ -184,12 +184,21 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
mappings = self.value.resolve(context, session)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]

mapping: Any
if isinstance(self.value, collections.abc.Sized):
mapping = self.value[map_index]
if not isinstance(mapping, collections.abc.Mapping):
mapping = mapping.resolve(context, session)
else:
mappings = self.value.resolve(context, session)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]

if not isinstance(mapping, collections.abc.Mapping):
raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

for key in mapping:
if not isinstance(key, str):
raise ValueError(
Expand Down
8 changes: 6 additions & 2 deletions airflow/models/mappedoperator.py
Expand Up @@ -196,8 +196,12 @@ def expand(self, **mapped_kwargs: OperatorExpandArgument) -> "MappedOperator":
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> "MappedOperator":
from airflow.models.xcom_arg import XComArg

if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
if isinstance(kwargs, collections.abc.Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, collections.abc.Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
Expand Down
51 changes: 39 additions & 12 deletions tests/models/test_dagrun.py
Expand Up @@ -1824,21 +1824,48 @@ def test_schedule_tis_map_index(dag_maker, session):


def test_mapped_expand_kwargs(dag_maker):
with dag_maker() as dag:
with dag_maker():

@task
def task_1():
return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]
def task_0():
return {"arg1": "a", "arg2": "b"}

MockOperator.partial(task_id="task_2").expand_kwargs(task_1())
@task
def task_1(args_0):
return [args_0, {"arg1": "y"}, {"arg2": "z"}]

dr: DagRun = dag_maker.create_dagrun()
assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1
args_0 = task_0()
args_list = task_1(args_0=args_0)

ti1 = dr.get_task_instance("task_1")
ti1.refresh_from_task(dag.get_task("task_1"))
ti1.run()
MockOperator.partial(task_id="task_2").expand_kwargs(args_list)
MockOperator.partial(task_id="task_3").expand_kwargs(
[{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}],
)
MockOperator.partial(task_id="task_4").expand_kwargs([args_0, {"arg1": "y"}, {"arg2": "z"}])

dr.task_instance_scheduling_decisions()
ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
assert ti_states == {0: None, 1: None, 2: None}
dr: DagRun = dag_maker.create_dagrun()
tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}

# task_2 is not expanded yet since it relies on one single XCom input.
# task_3 and task_4 received a literal list and can be expanded right away.
# task_4 relies on an XCom input in the list, but can also be expanded.
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_2") == [-1]
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_3") == [0, 1, 2]
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_4") == [0, 1, 2]

tis[("task_0", -1)].run()
tis[("task_1", -1)].run()

# With the upstreams available, everything should get expanded now.
decision = dr.task_instance_scheduling_decisions()
assert {(ti.task_id, ti.map_index): ti.state for ti in decision.schedulable_tis} == {
("task_2", 0): None,
("task_2", 1): None,
("task_2", 2): None,
("task_3", 0): None,
("task_3", 1): None,
("task_3", 2): None,
("task_4", 0): None,
("task_4", 1): None,
("task_4", 2): None,
}

0 comments on commit e28183d

Please sign in to comment.