Skip to content

Commit

Permalink
Implement expand_kwargs() against a literal list (#25925)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Sep 1, 2022
1 parent a032b8a commit 4791443
Show file tree
Hide file tree
Showing 5 changed files with 93 additions and 58 deletions.
22 changes: 12 additions & 10 deletions airflow/decorators/base.py
Expand Up @@ -18,7 +18,6 @@
import inspect
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Expand Down Expand Up @@ -57,6 +56,8 @@
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import (
MappedOperator,
Expand All @@ -73,9 +74,6 @@
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.models.mappedoperator import Mappable


def validate_python_callable(python_callable: Any) -> None:
"""
Expand Down Expand Up @@ -329,7 +327,7 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
names = ", ".join(repr(n) for n in kwargs_left)
raise TypeError(f"{func}() got unexpected keyword arguments {names}")

def expand(self, **map_kwargs: "Mappable") -> XComArg:
def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
if not map_kwargs:
raise TypeError("no arguments to expand against")
self._validate_arg_names("expand", map_kwargs)
Expand All @@ -338,9 +336,13 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
if isinstance(kwargs, Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
Expand Down Expand Up @@ -500,10 +502,10 @@ def __wrapped__(self) -> Callable[FParams, FReturn]:
def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]":
...

def expand(self, **kwargs: "Mappable") -> XComArg:
def expand(self, **kwargs: OperatorExpandArgument) -> XComArg:
...

def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
...

def override(self, **kwargs: Any) -> "Task[FParams, FReturn]":
Expand Down
4 changes: 2 additions & 2 deletions airflow/decorators/task_group.py
Expand Up @@ -29,7 +29,7 @@

if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.mappedoperator import Mappable
from airflow.models.mappedoperator import OperatorExpandArgument

F = TypeVar("F", bound=Callable)
R = TypeVar("R")
Expand Down Expand Up @@ -100,7 +100,7 @@ class Group(Generic[F]):
function: F

# Return value should match F's return type, but that's impossible to declare.
def expand(self, **kwargs: "Mappable") -> Any:
def expand(self, **kwargs: "OperatorExpandArgument") -> Any:
...

def partial(self, **kwargs: Any) -> "Group[F]":
Expand Down
59 changes: 30 additions & 29 deletions airflow/models/expandinput.py
Expand Up @@ -18,11 +18,10 @@

from __future__ import annotations

import collections
import collections.abc
import functools
import operator
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union

from airflow.compat.functools import cache
from airflow.utils.context import Context
Expand All @@ -34,9 +33,13 @@

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
# mapping since we need the value to be ordered).
Mappable = Union["XComArg", Sequence, dict]
# Each keyword argument to expand() can be an XComArg, sequence, or dict (not
# any mapping since we need the value to be ordered).
OperatorExpandArgument = Union["XComArg", Sequence, Dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a list with each
# element being either an XComArg or a dict.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]


# For isinstance() check.
Expand Down Expand Up @@ -68,13 +71,9 @@ class DictOfListsExpandInput(NamedTuple):
This is created from ``expand(**kwargs)``.
"""

value: dict[str, Mappable]

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving."""
return self.value
value: dict[str, OperatorExpandArgument]

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
from airflow.models.xcom_arg import XComArg

Expand All @@ -83,7 +82,7 @@ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
def get_parse_time_mapped_ti_count(self) -> int | None:
if not self.value:
return 0
literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()]
literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
if len(literal_values) != len(self.value):
return None # None-literal type encountered, so give up.
return functools.reduce(operator.mul, literal_values, 1)
Expand Down Expand Up @@ -149,7 +148,7 @@ def _find_index_for_this_field(index: int) -> int:

def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
literal_keys = {k for k, _ in self.iter_parse_time_resolved_kwargs()}
literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
return data, resolved_oids

Expand All @@ -166,23 +165,16 @@ class ListOfDictsExpandInput(NamedTuple):
This is created from ``expand_kwargs(xcom_arg)``.
"""

value: XComArg

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving.
Since the list-of-dicts case relies entirely on run-time XCom, there's
no kwargs structure available, so this just returns an empty dict.
"""
return {}

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
return ()
value: OperatorExpandKwargsArgument

def get_parse_time_mapped_ti_count(self) -> int | None:
if isinstance(self.value, collections.abc.Sized):
return len(self.value)
return None

def get_total_map_length(self, run_id: str, *, session: Session) -> int:
if isinstance(self.value, collections.abc.Sized):
return len(self.value)
length = self.value.get_task_map_length(run_id, session=session)
if length is None:
raise NotFullyPopulated({"expand_kwargs() argument"})
Expand All @@ -192,12 +184,21 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
mappings = self.value.resolve(context, session)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]

mapping: Any
if isinstance(self.value, collections.abc.Sized):
mapping = self.value[map_index]
if not isinstance(mapping, collections.abc.Mapping):
mapping = mapping.resolve(context, session)
else:
mappings = self.value.resolve(context, session)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]

if not isinstance(mapping, collections.abc.Mapping):
raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

for key in mapping:
if not isinstance(key, str):
raise ValueError(
Expand Down
15 changes: 10 additions & 5 deletions airflow/models/mappedoperator.py
Expand Up @@ -64,8 +64,9 @@
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
Mappable,
NotFullyPopulated,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
get_mappable_types,
)
from airflow.models.pool import Pool
Expand Down Expand Up @@ -184,7 +185,7 @@ def __del__(self):
task_id = f"at {hex(id(self))}"
warnings.warn(f"Task {task_id} was never mapped!")

def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
def expand(self, **mapped_kwargs: OperatorExpandArgument) -> "MappedOperator":
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
Expand All @@ -193,11 +194,15 @@ def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator":
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> "MappedOperator":
from airflow.models.xcom_arg import XComArg

if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
if isinstance(kwargs, collections.abc.Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, collections.abc.Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
Expand Down
51 changes: 39 additions & 12 deletions tests/models/test_dagrun.py
Expand Up @@ -1824,21 +1824,48 @@ def test_schedule_tis_map_index(dag_maker, session):


def test_mapped_expand_kwargs(dag_maker):
with dag_maker() as dag:
with dag_maker():

@task
def task_1():
return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]
def task_0():
return {"arg1": "a", "arg2": "b"}

MockOperator.partial(task_id="task_2").expand_kwargs(task_1())
@task
def task_1(args_0):
return [args_0, {"arg1": "y"}, {"arg2": "z"}]

dr: DagRun = dag_maker.create_dagrun()
assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1
args_0 = task_0()
args_list = task_1(args_0=args_0)

ti1 = dr.get_task_instance("task_1")
ti1.refresh_from_task(dag.get_task("task_1"))
ti1.run()
MockOperator.partial(task_id="task_2").expand_kwargs(args_list)
MockOperator.partial(task_id="task_3").expand_kwargs(
[{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}],
)
MockOperator.partial(task_id="task_4").expand_kwargs([args_0, {"arg1": "y"}, {"arg2": "z"}])

dr.task_instance_scheduling_decisions()
ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
assert ti_states == {0: None, 1: None, 2: None}
dr: DagRun = dag_maker.create_dagrun()
tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}

# task_2 is not expanded yet since it relies on one single XCom input.
# task_3 and task_4 received a pure literal and can expanded right away.
# task_4 relies on an XCom input in the list, but can also be expanded.
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_2") == [-1]
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_3") == [0, 1, 2]
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_4") == [0, 1, 2]

tis[("task_0", -1)].run()
tis[("task_1", -1)].run()

# With the upstreams available, everything should get expanded now.
decision = dr.task_instance_scheduling_decisions()
assert {(ti.task_id, ti.map_index): ti.state for ti in decision.schedulable_tis} == {
("task_2", 0): None,
("task_2", 1): None,
("task_2", 2): None,
("task_3", 0): None,
("task_3", 1): None,
("task_3", 2): None,
("task_4", 0): None,
("task_4", 1): None,
("task_4", 2): None,
}

0 comments on commit 4791443

Please sign in to comment.