From 4ae13d9d68c3b4366b54bff8f7bb833c3fe65cef Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Sat, 24 Sep 2022 04:33:01 +0800 Subject: [PATCH] Allow MapXComArg to resolve after serialization (#26591) This is useful for cases where we want to resolve an XCom without running a worker, e.g. to display the value in UI. Since we don't want to actually call the mapper function in this case (the function is arbitrary code, and not running it is the entire point to serialize operators), "resolving" the XComArg in this case would merely produce some kind of quasi-meaningful string representation, instead of the actual value we'd get in the worker. Also note that this only affects a very small number of cases, since once a worker is run for the task instance, RenderedTaskInstanceFields would store the real resolved value and take over UI representation, avoiding this fake resolving logic to be accessed at all. (cherry picked from commit 3e01c0d97aeefce303e1fdb5cef160f192cce4fa) --- airflow/models/xcom_arg.py | 48 ++++++++++++++++++++++++++++++----- tests/models/test_xcom_arg.py | 4 +-- 2 files changed, 43 insertions(+), 9 deletions(-) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 2fb60195ef911..9be82976ae2c8 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations +import contextlib import inspect -from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload from sqlalchemy import func from sqlalchemy.orm import Session @@ -35,6 +37,11 @@ from airflow.models.dag import DAG from airflow.models.operator import Operator +# Callable objects contained by MapXComArg. We only accept callables from +# the user, but deserialize them into strings in a serialized XComArg for +# safety (those callables are arbitrary user code). +MapCallables = Sequence[Union[Callable[[Any], Any], str]] + class XComArg(DependencyMixin): """Reference to an XCom value pushed from another operator. @@ -322,15 +329,39 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: raise XComNotFound(context["ti"].dag_id, task_id, self.key) +def _get_callable_name(f: Callable | str) -> str: + """Try to "describe" a callable by getting its name.""" + if callable(f): + return f.__name__ + # Parse the source to find whatever is behind "def". For safety, we don't + # want to evaluate the code in any meaningful way! + with contextlib.suppress(Exception): + kw, name, _ = f.lstrip().split(None, 2) + if kw == "def": + return name + return "" + + class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: Sequence[Callable[[Any], Any]]) -> None: + def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: self.value = value self.callables = callables def __getitem__(self, index: Any) -> Any: value = self.value[index] - for f in self.callables: - value = f(value) + + # In the worker, we can access all actual callables. Call them. + callables = [f for f in self.callables if callable(f)] + if len(callables) == len(self.callables): + for f in callables: + value = f(value) + return value + + # In the scheduler, we don't have access to the actual callables, nor do + # we want to run it since it's arbitrary code. This builds a string to + # represent the call chain in the UI or logs instead. + for v in self.callables: + value = f"{_get_callable_name(v)}({value})" return value def __len__(self) -> int: @@ -342,9 +373,11 @@ class MapXComArg(XComArg): This is based on an XComArg, but also applies a series of "transforms" that convert the pulled XCom value. + + :meta private: """ - def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None: + def __init__(self, arg: XComArg, callables: MapCallables) -> None: for c in callables: if getattr(c, "_airflow_is_task_decorator", False): raise ValueError("map() argument must be a plain function, not a @task operator") @@ -352,12 +385,13 @@ def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> N self.callables = callables def __repr__(self) -> str: - return f"{self.arg!r}.map([{len(self.callables)} functions])" + map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) + return f"{self.arg!r}{map_calls}" def _serialize(self) -> dict[str, Any]: return { "arg": serialize_xcom_arg(self.arg), - "callables": [inspect.getsource(c) for c in self.callables], + "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], } @classmethod diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py index 18cbe87de1f8f..1f9a342c026d1 100644 --- a/tests/models/test_xcom_arg.py +++ b/tests/models/test_xcom_arg.py @@ -211,14 +211,14 @@ def pull(value): # Run "push_letters" and "push_numbers". decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) + assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["push_letters", "push_numbers"] for ti in decision.schedulable_tis: ti.run(session=session) session.commit() # Run "pull". decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis) + assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["pull"] * len(expected_results) for ti in decision.schedulable_tis: ti.run(session=session)