Skip to content

Commit

Permalink
Allow MapXComArg to resolve after serialization (#26591)
Browse files Browse the repository at this point in the history
This is useful for cases where we want to resolve an XCom without
running a worker, e.g. to display the value in UI.

Since we don't want to actually call the mapper function in this case
(the function is arbitrary code, and not running it is the entire point
of serializing operators), "resolving" the XComArg in this case merely
produces a quasi-meaningful string representation instead of the actual
value we'd get in the worker.

Also note that this only affects a very small number of cases, since
once a worker has run the task instance, RenderedTaskInstanceFields
stores the real resolved value and takes over the UI representation,
so this fake resolving logic is never reached at all.

(cherry picked from commit 3e01c0d)
  • Loading branch information
uranusjr authored and jedcunningham committed Sep 27, 2022
1 parent ef400fe commit 4ae13d9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 9 deletions.
48 changes: 41 additions & 7 deletions airflow/models/xcom_arg.py
Expand Up @@ -14,10 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import contextlib
import inspect
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, overload
from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload

from sqlalchemy import func
from sqlalchemy.orm import Session
Expand All @@ -35,6 +37,11 @@
from airflow.models.dag import DAG
from airflow.models.operator import Operator

# Callable objects contained by MapXComArg. We only accept callables from
# the user, but deserialize them into strings in a serialized XComArg for
# safety (those callables are arbitrary user code).
MapCallables = Sequence[Union[Callable[[Any], Any], str]]


class XComArg(DependencyMixin):
"""Reference to an XCom value pushed from another operator.
Expand Down Expand Up @@ -322,15 +329,39 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any:
raise XComNotFound(context["ti"].dag_id, task_id, self.key)


def _get_callable_name(f: Callable | str) -> str:
"""Try to "describe" a callable by getting its name."""
if callable(f):
return f.__name__
# Parse the source to find whatever is behind "def". For safety, we don't
# want to evaluate the code in any meaningful way!
with contextlib.suppress(Exception):
kw, name, _ = f.lstrip().split(None, 2)
if kw == "def":
return name
return "<function>"


class _MapResult(Sequence):
def __init__(self, value: Sequence | dict, callables: Sequence[Callable[[Any], Any]]) -> None:
def __init__(self, value: Sequence | dict, callables: MapCallables) -> None:
self.value = value
self.callables = callables

def __getitem__(self, index: Any) -> Any:
value = self.value[index]
for f in self.callables:
value = f(value)

# In the worker, we can access all actual callables. Call them.
callables = [f for f in self.callables if callable(f)]
if len(callables) == len(self.callables):
for f in callables:
value = f(value)
return value

# In the scheduler, we don't have access to the actual callables, nor do
# we want to run them since they are arbitrary code. This builds a string
# to represent the call chain in the UI or logs instead.
for v in self.callables:
value = f"{_get_callable_name(v)}({value})"
return value

def __len__(self) -> int:
Expand All @@ -342,22 +373,25 @@ class MapXComArg(XComArg):
This is based on an XComArg, but also applies a series of "transforms" that
convert the pulled XCom value.
:meta private:
"""

def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None:
def __init__(self, arg: XComArg, callables: MapCallables) -> None:
for c in callables:
if getattr(c, "_airflow_is_task_decorator", False):
raise ValueError("map() argument must be a plain function, not a @task operator")
self.arg = arg
self.callables = callables

def __repr__(self) -> str:
return f"{self.arg!r}.map([{len(self.callables)} functions])"
map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables)
return f"{self.arg!r}{map_calls}"

def _serialize(self) -> dict[str, Any]:
return {
"arg": serialize_xcom_arg(self.arg),
"callables": [inspect.getsource(c) for c in self.callables],
"callables": [inspect.getsource(c) if callable(c) else c for c in self.callables],
}

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions tests/models/test_xcom_arg.py
Expand Up @@ -211,14 +211,14 @@ def pull(value):

# Run "push_letters" and "push_numbers".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis)
assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["push_letters", "push_numbers"]
for ti in decision.schedulable_tis:
ti.run(session=session)
session.commit()

# Run "pull".
decision = dr.task_instance_scheduling_decisions(session=session)
assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis)
assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["pull"] * len(expected_results)
for ti in decision.schedulable_tis:
ti.run(session=session)

Expand Down

0 comments on commit 4ae13d9

Please sign in to comment.