Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement expand_kwargs() against a literal list #25925

Merged
merged 4 commits into from Sep 1, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 12 additions & 10 deletions airflow/decorators/base.py
Expand Up @@ -18,7 +18,6 @@
import inspect
import re
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Expand Down Expand Up @@ -57,6 +56,8 @@
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
)
from airflow.models.mappedoperator import (
MappedOperator,
Expand All @@ -73,9 +74,6 @@
from airflow.utils.task_group import TaskGroup, TaskGroupContext
from airflow.utils.types import NOTSET

if TYPE_CHECKING:
from airflow.models.mappedoperator import Mappable


def validate_python_callable(python_callable: Any) -> None:
"""
Expand Down Expand Up @@ -329,7 +327,7 @@ def _validate_arg_names(self, func: ValidationSource, kwargs: Dict[str, Any]):
names = ", ".join(repr(n) for n in kwargs_left)
raise TypeError(f"{func}() got unexpected keyword arguments {names}")

def expand(self, **map_kwargs: "Mappable") -> XComArg:
def expand(self, **map_kwargs: OperatorExpandArgument) -> XComArg:
if not map_kwargs:
raise TypeError("no arguments to expand against")
self._validate_arg_names("expand", map_kwargs)
Expand All @@ -338,9 +336,13 @@ def expand(self, **map_kwargs: "Mappable") -> XComArg:
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(map_kwargs), strict=False)

def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
if isinstance(kwargs, Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> XComArg:
Expand Down Expand Up @@ -500,10 +502,10 @@ def __wrapped__(self) -> Callable[FParams, FReturn]:
def partial(self, **kwargs: Any) -> "Task[FParams, FReturn]":
...

def expand(self, **kwargs: "Mappable") -> XComArg:
def expand(self, **kwargs: OperatorExpandArgument) -> XComArg:
...

def expand_kwargs(self, kwargs: XComArg, *, strict: bool = True) -> XComArg:
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> XComArg:
...

def override(self, **kwargs: Any) -> "Task[FParams, FReturn]":
Expand Down
4 changes: 2 additions & 2 deletions airflow/decorators/task_group.py
Expand Up @@ -29,7 +29,7 @@

if TYPE_CHECKING:
from airflow.models.dag import DAG
from airflow.models.mappedoperator import Mappable
from airflow.models.mappedoperator import OperatorExpandArgument

F = TypeVar("F", bound=Callable)
R = TypeVar("R")
Expand Down Expand Up @@ -100,7 +100,7 @@ class Group(Generic[F]):
function: F

# Return value should match F's return type, but that's impossible to declare.
def expand(self, **kwargs: "Mappable") -> Any:
def expand(self, **kwargs: "OperatorExpandArgument") -> Any:
...

def partial(self, **kwargs: Any) -> "Group[F]":
Expand Down
59 changes: 30 additions & 29 deletions airflow/models/expandinput.py
Expand Up @@ -18,11 +18,10 @@

from __future__ import annotations

import collections
import collections.abc
import functools
import operator
from typing import TYPE_CHECKING, Any, Iterable, Mapping, NamedTuple, Sequence, Sized, Union
from typing import TYPE_CHECKING, Any, Dict, Iterable, Mapping, NamedTuple, Sequence, Sized, Union

from airflow.compat.functools import cache
from airflow.utils.context import Context
Expand All @@ -34,9 +33,13 @@

ExpandInput = Union["DictOfListsExpandInput", "ListOfDictsExpandInput"]

# BaseOperator.expand() can be called on an XComArg, sequence, or dict (not any
# mapping since we need the value to be ordered).
Mappable = Union["XComArg", Sequence, dict]
# Each keyword argument to expand() can be an XComArg, sequence, or dict (not
# any mapping since we need the value to be ordered).
OperatorExpandArgument = Union["XComArg", Sequence, Dict[str, Any]]

# The single argument of expand_kwargs() can be an XComArg, or a list with each
# element being either an XComArg or a dict.
OperatorExpandKwargsArgument = Union["XComArg", Sequence[Union["XComArg", Mapping[str, Any]]]]


# For isinstance() check.
Expand Down Expand Up @@ -68,13 +71,9 @@ class DictOfListsExpandInput(NamedTuple):
This is created from ``expand(**kwargs)``.
"""

value: dict[str, Mappable]

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving."""
return self.value
value: dict[str, OperatorExpandArgument]

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
def _iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
"""Generate kwargs with values available on parse-time."""
from airflow.models.xcom_arg import XComArg

Expand All @@ -83,7 +82,7 @@ def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
def get_parse_time_mapped_ti_count(self) -> int | None:
if not self.value:
return 0
literal_values = [len(v) for _, v in self.iter_parse_time_resolved_kwargs()]
literal_values = [len(v) for _, v in self._iter_parse_time_resolved_kwargs()]
if len(literal_values) != len(self.value):
return None # None-literal type encountered, so give up.
return functools.reduce(operator.mul, literal_values, 1)
Expand Down Expand Up @@ -149,7 +148,7 @@ def _find_index_for_this_field(index: int) -> int:

def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any], set[int]]:
data = {k: self._expand_mapped_field(k, v, context, session=session) for k, v in self.value.items()}
literal_keys = {k for k, _ in self.iter_parse_time_resolved_kwargs()}
literal_keys = {k for k, _ in self._iter_parse_time_resolved_kwargs()}
resolved_oids = {id(v) for k, v in data.items() if k not in literal_keys}
return data, resolved_oids

Expand All @@ -166,23 +165,16 @@ class ListOfDictsExpandInput(NamedTuple):
This is created from ``expand_kwargs(xcom_arg)``.
"""

value: XComArg

def get_unresolved_kwargs(self) -> dict[str, Any]:
"""Get the kwargs dict that can be inferred without resolving.

Since the list-of-dicts case relies entirely on run-time XCom, there's
no kwargs structure available, so this just returns an empty dict.
"""
return {}

def iter_parse_time_resolved_kwargs(self) -> Iterable[tuple[str, Sized]]:
return ()
value: OperatorExpandKwargsArgument

def get_parse_time_mapped_ti_count(self) -> int | None:
if isinstance(self.value, collections.abc.Sized):
return len(self.value)
return None

def get_total_map_length(self, run_id: str, *, session: Session) -> int:
if isinstance(self.value, collections.abc.Sized):
return len(self.value)
length = self.value.get_task_map_length(run_id, session=session)
if length is None:
raise NotFullyPopulated({"expand_kwargs() argument"})
Expand All @@ -192,12 +184,21 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
map_index = context["ti"].map_index
if map_index < 0:
raise RuntimeError("can't resolve task-mapping argument without expanding")
mappings = self.value.resolve(context, session)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]

mapping: Any
if isinstance(self.value, collections.abc.Sized):
mapping = self.value[map_index]
if not isinstance(mapping, collections.abc.Mapping):
mapping = mapping.resolve(context, session)
else:
mappings = self.value.resolve(context, session)
if not isinstance(mappings, collections.abc.Sequence):
raise ValueError(f"expand_kwargs() expects a list[dict], not {_describe_type(mappings)}")
mapping = mappings[map_index]

if not isinstance(mapping, collections.abc.Mapping):
raise ValueError(f"expand_kwargs() expects a list[dict], not list[{_describe_type(mapping)}]")

for key in mapping:
if not isinstance(key, str):
raise ValueError(
Expand Down
15 changes: 10 additions & 5 deletions airflow/models/mappedoperator.py
Expand Up @@ -64,8 +64,9 @@
DictOfListsExpandInput,
ExpandInput,
ListOfDictsExpandInput,
Mappable,
NotFullyPopulated,
OperatorExpandArgument,
OperatorExpandKwargsArgument,
get_mappable_types,
)
from airflow.models.pool import Pool
Expand Down Expand Up @@ -184,7 +185,7 @@ def __del__(self):
task_id = f"at {hex(id(self))}"
warnings.warn(f"Task {task_id} was never mapped!")

def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
def expand(self, **mapped_kwargs: OperatorExpandArgument) -> "MappedOperator":
if not mapped_kwargs:
raise TypeError("no arguments to expand against")
validate_mapping_kwargs(self.operator_class, "expand", mapped_kwargs)
Expand All @@ -193,11 +194,15 @@ def expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator":
# to False to skip the checks on execution.
return self._expand(DictOfListsExpandInput(mapped_kwargs), strict=False)

def expand_kwargs(self, kwargs: "XComArg", *, strict: bool = True) -> "MappedOperator":
def expand_kwargs(self, kwargs: OperatorExpandKwargsArgument, *, strict: bool = True) -> "MappedOperator":
from airflow.models.xcom_arg import XComArg

if not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg object, not {type(kwargs).__name__}")
if isinstance(kwargs, collections.abc.Sequence):
for item in kwargs:
if not isinstance(item, (XComArg, collections.abc.Mapping)):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
elif not isinstance(kwargs, XComArg):
raise TypeError(f"expected XComArg or list[dict], not {type(kwargs).__name__}")
return self._expand(ListOfDictsExpandInput(kwargs), strict=strict)

def _expand(self, expand_input: ExpandInput, *, strict: bool) -> "MappedOperator":
Expand Down
51 changes: 39 additions & 12 deletions tests/models/test_dagrun.py
Expand Up @@ -1824,21 +1824,48 @@ def test_schedule_tis_map_index(dag_maker, session):


def test_mapped_expand_kwargs(dag_maker):
with dag_maker() as dag:
with dag_maker():

@task
def task_1():
return [{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}]
def task_0():
return {"arg1": "a", "arg2": "b"}

MockOperator.partial(task_id="task_2").expand_kwargs(task_1())
@task
def task_1(args_0):
return [args_0, {"arg1": "y"}, {"arg2": "z"}]

dr: DagRun = dag_maker.create_dagrun()
assert len([ti for ti in dr.get_task_instances() if ti.task_id == "task_2"]) == 1
args_0 = task_0()
args_list = task_1(args_0=args_0)

ti1 = dr.get_task_instance("task_1")
ti1.refresh_from_task(dag.get_task("task_1"))
ti1.run()
MockOperator.partial(task_id="task_2").expand_kwargs(args_list)
MockOperator.partial(task_id="task_3").expand_kwargs(
[{"arg1": "a", "arg2": "b"}, {"arg1": "y"}, {"arg2": "z"}],
)
MockOperator.partial(task_id="task_4").expand_kwargs([args_0, {"arg1": "y"}, {"arg2": "z"}])

dr.task_instance_scheduling_decisions()
ti_states = {ti.map_index: ti.state for ti in dr.get_task_instances() if ti.task_id == "task_2"}
assert ti_states == {0: None, 1: None, 2: None}
dr: DagRun = dag_maker.create_dagrun()
tis = {(ti.task_id, ti.map_index): ti for ti in dr.task_instances}

# task_2 is not expanded yet since it relies on one single XCom input.
# task_3 and task_4 received a pure literal and can expanded right away.
# task_4 relies on an XCom input in the list, but can also be expanded.
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_2") == [-1]
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_3") == [0, 1, 2]
assert sorted(map_index for (task_id, map_index) in tis if task_id == "task_4") == [0, 1, 2]

tis[("task_0", -1)].run()
tis[("task_1", -1)].run()

# With the upstreams available, everything should get expanded now.
decision = dr.task_instance_scheduling_decisions()
assert {(ti.task_id, ti.map_index): ti.state for ti in decision.schedulable_tis} == {
("task_2", 0): None,
("task_2", 1): None,
("task_2", 2): None,
("task_3", 0): None,
("task_3", 1): None,
("task_3", 2): None,
("task_4", 0): None,
("task_4", 1): None,
("task_4", 2): None,
}