From 7914c6cdd466a0194a8ff60165dd4ae507aa3f06 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 5 Jul 2022 17:36:33 +0800 Subject: [PATCH] Disable attrs state management on MappedOperator (#24772) (cherry picked from commit 6fd06fa8c274b39e4ed716f8d347229e017ba8e5) --- airflow/models/mappedoperator.py | 12 +++++++++++- tests/serialization/test_dag_serialization.py | 16 ++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 6b202d2cc6310..5cce0d8e89482 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -239,7 +239,17 @@ def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": return op -@attr.define(kw_only=True) +@attr.define( + kw_only=True, + # Disable custom __getstate__ and __setstate__ generation since it interacts + # badly with Airflow's DAG serialization and pickling. When a mapped task is + # deserialized, subclasses are coerced into MappedOperator, but when it goes + # through DAG pickling, all attributes defined in the subclasses are dropped + # by attrs's custom state management. Since attrs does not do anything too + # special here (the logic is only important for slots=True), we use Python's + # built-in implementation, which works (as proven by good old BaseOperator). + getstate_setstate=False, +) class MappedOperator(AbstractOperator): """Object representing a mapped operator in a DAG.""" diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 6f5b9b49ebcfd..6159cee166a9b 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -24,6 +24,7 @@ import json import multiprocessing import os +import pickle from datetime import datetime, timedelta from glob import glob from unittest import mock @@ -1890,6 +1891,21 @@ def x(arg1, arg2, arg3): "retry_delay": timedelta(seconds=30), } + # Ensure the serialized operator can also be correctly pickled, to ensure + # correct interaction between DAG pickling and serialization. This is done + # here so we don't need to duplicate tests between pickled and non-pickled + # DAGs everywhere else. + pickled = pickle.loads(pickle.dumps(deserialized)) + assert pickled.mapped_op_kwargs == { + "arg2": {"a": 1, "b": 2}, + "arg3": _XComRef("op1", XCOM_RETURN_KEY), + } + assert pickled.partial_kwargs == { + "op_args": [], + "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, + "retry_delay": timedelta(seconds=30), + } + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.parametrize(