diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 2fb60195ef911..9be82976ae2c8 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -14,10 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + from __future__ import annotations +import contextlib import inspect -from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, overload +from typing import TYPE_CHECKING, Any, Callable, Iterator, Mapping, Sequence, Union, overload from sqlalchemy import func from sqlalchemy.orm import Session @@ -35,6 +37,11 @@ from airflow.models.dag import DAG from airflow.models.operator import Operator +# Callable objects contained by MapXComArg. We only accept callables from +# the user, but deserialize them into strings in a serialized XComArg for +# safety (those callables are arbitrary user code). +MapCallables = Sequence[Union[Callable[[Any], Any], str]] + class XComArg(DependencyMixin): """Reference to an XCom value pushed from another operator. @@ -322,15 +329,39 @@ def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: raise XComNotFound(context["ti"].dag_id, task_id, self.key) +def _get_callable_name(f: Callable | str) -> str: + """Try to "describe" a callable by getting its name.""" + if callable(f): + return f.__name__ + # Parse the source to find whatever is behind "def". For safety, we don't + # want to evaluate the code in any meaningful way! 
+ with contextlib.suppress(Exception): + kw, name, _ = f.lstrip().split(None, 2) + if kw == "def": + return name + return "" + + class _MapResult(Sequence): - def __init__(self, value: Sequence | dict, callables: Sequence[Callable[[Any], Any]]) -> None: + def __init__(self, value: Sequence | dict, callables: MapCallables) -> None: self.value = value self.callables = callables def __getitem__(self, index: Any) -> Any: value = self.value[index] - for f in self.callables: - value = f(value) + + # In the worker, we can access all actual callables. Call them. + callables = [f for f in self.callables if callable(f)] + if len(callables) == len(self.callables): + for f in callables: + value = f(value) + return value + + # In the scheduler, we don't have access to the actual callables, nor do + # we want to run them since they're arbitrary code. This builds a string to + # represent the call chain in the UI or logs instead. + for v in self.callables: + value = f"{_get_callable_name(v)}({value})" return value def __len__(self) -> int: @@ -342,9 +373,11 @@ class MapXComArg(XComArg): This is based on an XComArg, but also applies a series of "transforms" that convert the pulled XCom value. 
+ + :meta private: """ - def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> None: + def __init__(self, arg: XComArg, callables: MapCallables) -> None: for c in callables: if getattr(c, "_airflow_is_task_decorator", False): raise ValueError("map() argument must be a plain function, not a @task operator") @@ -352,12 +385,13 @@ def __init__(self, arg: XComArg, callables: Sequence[Callable[[Any], Any]]) -> N self.callables = callables def __repr__(self) -> str: - return f"{self.arg!r}.map([{len(self.callables)} functions])" + map_calls = "".join(f".map({_get_callable_name(f)})" for f in self.callables) + return f"{self.arg!r}{map_calls}" def _serialize(self) -> dict[str, Any]: return { "arg": serialize_xcom_arg(self.arg), - "callables": [inspect.getsource(c) for c in self.callables], + "callables": [inspect.getsource(c) if callable(c) else c for c in self.callables], } @classmethod diff --git a/tests/models/test_xcom_arg.py b/tests/models/test_xcom_arg.py index 18cbe87de1f8f..1f9a342c026d1 100644 --- a/tests/models/test_xcom_arg.py +++ b/tests/models/test_xcom_arg.py @@ -211,14 +211,14 @@ def pull(value): # Run "push_letters" and "push_numbers". decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis and all(ti.task_id.startswith("push_") for ti in decision.schedulable_tis) + assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["push_letters", "push_numbers"] for ti in decision.schedulable_tis: ti.run(session=session) session.commit() # Run "pull". decision = dr.task_instance_scheduling_decisions(session=session) - assert decision.schedulable_tis and all(ti.task_id == "pull" for ti in decision.schedulable_tis) + assert sorted(ti.task_id for ti in decision.schedulable_tis) == ["pull"] * len(expected_results) for ti in decision.schedulable_tis: ti.run(session=session)