diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 4b7f9ebf7b9aa..d72a7edd4ad6a 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -60,7 +60,6 @@
     ForeignKeyConstraint,
     Index,
     Integer,
-    PickleType,
     String,
     and_,
     false,
@@ -123,7 +122,13 @@
 from airflow.utils.platform import getuser
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
+from airflow.utils.sqlalchemy import (
+    ExecutorConfigType,
+    ExtendedJSON,
+    UtcDateTime,
+    tuple_in_condition,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.timeout import timeout
 
@@ -430,20 +435,6 @@ def key(self) -> "TaskInstanceKey":
         return self
 
 
-def _executor_config_comparator(x, y):
-    """
-    The TaskInstance.executor_config attribute is a pickled object that may contain
-    kubernetes objects. If the installed library version has changed since the
-    object was originally pickled, due to the underlying ``__eq__`` method on these
-    objects (which converts them to JSON), we may encounter attribute errors. In this
-    case we should replace the stored object.
-    """
-    try:
-        return x == y
-    except AttributeError:
-        return False
-
-
 class TaskInstance(Base, LoggingMixin):
     """
     Task instances store the state of a task instance. This table is the
@@ -486,7 +477,7 @@ class TaskInstance(Base, LoggingMixin):
     queued_dttm = Column(UtcDateTime)
     queued_by_job_id = Column(Integer)
     pid = Column(Integer)
-    executor_config = Column(PickleType(pickler=dill, comparator=_executor_config_comparator))
+    executor_config = Column(ExecutorConfigType(pickler=dill))
 
     external_executor_id = Column(String(ID_LEN, **COLLATION_ARGS))
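For context (not part of this patch): the kind of value stored in this column is a task's `executor_config`, typically carrying a `pod_override` entry holding a kubernetes client object, which is exactly what makes plain pickling fragile across kubernetes library upgrades. A minimal, illustrative sketch — the task id, callable, and image are made up:

```python
# Illustrative only: how a DAG author typically populates executor_config.
from kubernetes.client import models as k8s

from airflow.operators.python import PythonOperator

run_on_custom_pod = PythonOperator(
    task_id="run_on_custom_pod",  # hypothetical task
    python_callable=lambda: print("hello"),
    executor_config={
        # This V1Pod is what ExecutorConfigType special-cases below.
        "pod_override": k8s.V1Pod(
            spec=k8s.V1PodSpec(
                containers=[k8s.V1Container(name="base", image="python:3.9-slim")]
            )
        )
    },
)
```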
+ """ + + def bind_processor(self, dialect): + + from airflow.serialization.serialized_objects import BaseSerialization + + super_process = super().bind_processor(dialect) + + def process(value): + if isinstance(value, dict) and 'pod_override' in value: + value['pod_override'] = BaseSerialization()._serialize(value['pod_override']) + return super_process(value) + + return process + + def result_processor(self, dialect, coltype): + from airflow.serialization.serialized_objects import BaseSerialization + + super_process = super().result_processor(dialect, coltype) + + def process(value): + value = super_process(value) # unpickle + + if isinstance(value, dict) and 'pod_override' in value: + pod_override = value['pod_override'] + + # If pod_override was serialized with Airflow's BaseSerialization, deserialize it + if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE): + value['pod_override'] = BaseSerialization()._deserialize(pod_override) + return value + + return process + + def compare_values(self, x, y): + """ + The TaskInstance.executor_config attribute is a pickled object that may contain + kubernetes objects. If the installed library version has changed since the + object was originally pickled, due to the underlying ``__eq__`` method on these + objects (which converts them to JSON), we may encounter attribute errors. In this + case we should replace the stored object. + + From https://github.com/apache/airflow/pull/24356 we use our serializer to store + k8s objects, but there could still be raw pickled k8s objects in the database, + stored from earlier version, so we still compare them defensively here. + """ + if self.comparator: + return self.comparator(x, y) + else: + try: + return x == y + except AttributeError: + return False + + class Interval(TypeDecorator): """Base class representing a time interval.""" diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 74ef87489ffd3..bcb1fa4b45677 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -57,12 +57,7 @@ XCom, ) from airflow.models.taskfail import TaskFail -from airflow.models.taskinstance import ( - TaskInstance, - _executor_config_comparator, - load_error_file, - set_error_file, -) +from airflow.models.taskinstance import TaskInstance, load_error_file, set_error_file from airflow.models.taskmap import TaskMap from airflow.models.xcom import XCOM_RETURN_KEY from airflow.operators.bash import BashOperator @@ -3124,22 +3119,3 @@ def get_extra_env(): echo_task = dag.get_task("echo") assert "get_extra_env" in echo_task.upstream_task_ids - - -def test_executor_config_comparator(): - """ - When comparison raises AttributeError, return False. - This can happen when executor config contains kubernetes objects pickled - under older kubernetes library version. - """ - - class MockAttrError: - def __eq__(self, other): - raise AttributeError('hello') - - a = MockAttrError() - with pytest.raises(AttributeError): - # just verify for ourselves that this throws - assert a == a - assert _executor_config_comparator(a, a) is False - assert _executor_config_comparator('a', 'a') is True diff --git a/tests/utils/test_sqlalchemy.py b/tests/utils/test_sqlalchemy.py index 250bd37756c48..0038577b8abc8 100644 --- a/tests/utils/test_sqlalchemy.py +++ b/tests/utils/test_sqlalchemy.py @@ -17,20 +17,29 @@ # under the License. 
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index 74ef87489ffd3..bcb1fa4b45677 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -57,12 +57,7 @@
     XCom,
 )
 from airflow.models.taskfail import TaskFail
-from airflow.models.taskinstance import (
-    TaskInstance,
-    _executor_config_comparator,
-    load_error_file,
-    set_error_file,
-)
+from airflow.models.taskinstance import TaskInstance, load_error_file, set_error_file
 from airflow.models.taskmap import TaskMap
 from airflow.models.xcom import XCOM_RETURN_KEY
 from airflow.operators.bash import BashOperator
@@ -3124,22 +3119,3 @@ def get_extra_env():
 
     echo_task = dag.get_task("echo")
     assert "get_extra_env" in echo_task.upstream_task_ids
-
-
-def test_executor_config_comparator():
-    """
-    When comparison raises AttributeError, return False.
-    This can happen when executor config contains kubernetes objects pickled
-    under older kubernetes library version.
-    """
-
-    class MockAttrError:
-        def __eq__(self, other):
-            raise AttributeError('hello')
-
-    a = MockAttrError()
-    with pytest.raises(AttributeError):
-        # just verify for ourselves that this throws
-        assert a == a
-    assert _executor_config_comparator(a, a) is False
-    assert _executor_config_comparator('a', 'a') is True
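The module-level comparator and its test are removed above because the logic now lives on `ExecutorConfigType` (exercised in the new tests below). One compatibility case worth spelling out: rows written before this change hold raw-pickled k8s objects, which the new result processor leaves untouched. A sketch, using the same mock-dialect trick as the tests that follow:

```python
# Back-compat sketch: a pre-existing row with a raw-pickled V1Pod has no
# Encoding.TYPE marker, so the result processor returns it as unpickled.
import pickle
from unittest.mock import MagicMock

from kubernetes.client import models as k8s

from airflow.utils.sqlalchemy import ExecutorConfigType

old_row = pickle.dumps({"pod_override": k8s.V1Pod()})

mock_dialect = MagicMock()
mock_dialect.dbapi = None  # let PickleType fall back to plain bytes handling
process = ExecutorConfigType().result_processor(mock_dialect, None)

value = process(old_row)
assert isinstance(value["pod_override"], k8s.V1Pod)
```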
+ """ + + class MockAttrError: + def __eq__(self, other): + raise AttributeError('hello') + + a = MockAttrError() + with pytest.raises(AttributeError): + # just verify for ourselves that comparing directly will throw AttributeError + assert a == a + + instance = ExecutorConfigType() + assert instance.compare_values(a, a) is False + assert instance.compare_values('a', 'a') is True