From 3af6579d8e732364028bb35965f487a6c45676d9 Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Fri, 1 Jul 2022 12:27:18 +0800 Subject: [PATCH 1/2] Disable attrs state management on MappedOperator The custom __getstate__ and __setstate__ implementation from attrs interacts badly with Airflow's DAG serialization and pickling. When a mapped task is deserialized, subclasses are coerced to MappedOperator. But when the instances go through DAG pickling, all attributes defined in the subclasses are dropped by attrs's custom state management. Since attrs does not do anything too special here (the logic is only important for slots=True), we can use Python's built-in implementation instead. --- airflow/models/mappedoperator.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index 21a265e6e904c..e34b1501bc346 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -240,7 +240,17 @@ def _expand(self, **mapped_kwargs: "Mappable") -> "MappedOperator": return op -@attr.define(kw_only=True) +@attr.define( + kw_only=True, + # Disable custom __getstate__ and __setstate__ generation since it interacts + # badly with Airflow's DAG serialization and pickling. When a mapped task is + # deserialized, subclasses are coerced into MappedOperator, but when it goes + # through DAG pickling, all attributes defined in the subclasses are dropped + # by attrs's custom state management. Since attrs does not do anything too + # special here (the logic is only important for slots=True), we use Python's + # built-in implementation, which works (as proven by good old BaseOperator). 
+ getstate_setstate=False, +) class MappedOperator(AbstractOperator): """Object representing a mapped operator in a DAG.""" From 825c48c37cd81226083dbabbcfd89e49d46fc39f Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Tue, 5 Jul 2022 16:41:54 +0800 Subject: [PATCH 2/2] Test serialized mapped operator against pickling --- tests/serialization/test_dag_serialization.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/serialization/test_dag_serialization.py b/tests/serialization/test_dag_serialization.py index 7d6a43e933b2f..af5a7013661d8 100644 --- a/tests/serialization/test_dag_serialization.py +++ b/tests/serialization/test_dag_serialization.py @@ -24,6 +24,7 @@ import json import multiprocessing import os +import pickle from datetime import datetime, timedelta from glob import glob from unittest import mock @@ -1862,6 +1863,21 @@ def x(arg1, arg2, arg3): "retry_delay": timedelta(seconds=30), } + # Ensure the serialized operator can also be correctly pickled, to ensure + # correct interaction between DAG pickling and serialization. This is done + # here so we don't need to duplicate tests between pickled and non-pickled + # DAGs everywhere else. + pickled = pickle.loads(pickle.dumps(deserialized)) + assert pickled.mapped_op_kwargs == { + "arg2": {"a": 1, "b": 2}, + "arg3": _XComRef("op1", XCOM_RETURN_KEY), + } + assert pickled.partial_kwargs == { + "op_args": [], + "op_kwargs": {"arg1": [1, 2, {"a": "b"}]}, + "retry_delay": timedelta(seconds=30), + } + @pytest.mark.filterwarnings("ignore::DeprecationWarning") @pytest.mark.parametrize(