From 237f2240b0b011c9d97858ba8f5c63795f692cdf Mon Sep 17 00:00:00 2001 From: "Robert J. McGinness" Date: Mon, 26 Sep 2022 05:02:55 -0400 Subject: [PATCH] Fix xcom arg.py .zip bug (#26636) (cherry picked from commit f219bfbe22e662a8747af19d688bbe843e1a953d) --- airflow/models/xcom_arg.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/airflow/models/xcom_arg.py b/airflow/models/xcom_arg.py index 9be82976ae2c8..b70f26cd4fc72 100644 --- a/airflow/models/xcom_arg.py +++ b/airflow/models/xcom_arg.py @@ -31,7 +31,7 @@ from airflow.utils.context import Context from airflow.utils.edgemodifier import EdgeModifier from airflow.utils.session import NEW_SESSION, provide_session -from airflow.utils.types import NOTSET +from airflow.utils.types import NOTSET, ArgNotSet if TYPE_CHECKING: from airflow.models.dag import DAG @@ -322,7 +322,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: def resolve(self, context: Context, session: Session = NEW_SESSION) -> Any: task_id = self.operator.task_id result = context["ti"].xcom_pull(task_ids=task_id, key=str(self.key), default=NOTSET, session=session) - if result is not NOTSET: + if not isinstance(result, ArgNotSet): return result if self.key == XCOM_RETURN_KEY: return None @@ -437,7 +437,7 @@ def __getitem__(self, index: Any) -> Any: def __len__(self) -> int: lengths = (len(v) for v in self.values) - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return min(lengths) return max(lengths) @@ -460,13 +460,13 @@ def __repr__(self) -> str: args_iter = iter(self.args) first = repr(next(args_iter)) rest = ", ".join(repr(arg) for arg in args_iter) - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return f"{first}.zip({rest})" return f"{first}.zip({rest}, fillvalue={self.fillvalue!r})" def _serialize(self) -> dict[str, Any]: args = [serialize_xcom_arg(arg) for arg in self.args] - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return {"args": args} return {"args": args, "fillvalue": self.fillvalue} @@ -486,7 +486,7 @@ def get_task_map_length(self, run_id: str, *, session: Session) -> int | None: ready_lengths = [length for length in all_lengths if length is not None] if len(ready_lengths) != len(self.args): return None # If any of the referenced XComs is not ready, we are not ready either. - if self.fillvalue is NOTSET: + if isinstance(self.fillvalue, ArgNotSet): return min(ready_lengths) return max(ready_lengths)