Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Serialize pod_override to JSON before pickling executor_config #24356

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 8 additions & 17 deletions airflow/models/taskinstance.py
Expand Up @@ -58,7 +58,6 @@
ForeignKeyConstraint,
Index,
Integer,
PickleType,
String,
and_,
false,
Expand Down Expand Up @@ -120,7 +119,13 @@
from airflow.utils.platform import getuser
from airflow.utils.retries import run_with_db_retries
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
from airflow.utils.sqlalchemy import (
ExecutorConfigType,
ExtendedJSON,
UtcDateTime,
tuple_in_condition,
with_row_locks,
)
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.timeout import timeout

Expand Down Expand Up @@ -400,20 +405,6 @@ def key(self) -> "TaskInstanceKey":
return self


def _executor_config_comparator(x, y):
"""
The TaskInstance.executor_config attribute is a pickled object that may contain
kubernetes objects. If the installed library version has changed since the
object was originally pickled, due to the underlying ``__eq__`` method on these
objects (which converts them to JSON), we may encounter attribute errors. In this
case we should replace the stored object.
"""
try:
return x == y
except AttributeError:
return False


class TaskInstance(Base, LoggingMixin):
"""
Task instances store the state of a task instance. This table is the
Expand Down Expand Up @@ -456,7 +447,7 @@ class TaskInstance(Base, LoggingMixin):
queued_dttm = Column(UtcDateTime)
queued_by_job_id = Column(Integer)
pid = Column(Integer)
executor_config = Column(PickleType(pickler=dill, comparator=_executor_config_comparator))
jedcunningham marked this conversation as resolved.
Show resolved Hide resolved
executor_config = Column(ExecutorConfigType(pickler=dill))

external_executor_id = Column(StringID())

Expand Down
63 changes: 62 additions & 1 deletion airflow/utils/sqlalchemy.py
Expand Up @@ -23,7 +23,7 @@

import pendulum
from dateutil import relativedelta
from sqlalchemy import TIMESTAMP, and_, event, false, nullsfirst, or_, tuple_
from sqlalchemy import TIMESTAMP, PickleType, and_, event, false, nullsfirst, or_, tuple_
from sqlalchemy.dialects import mssql, mysql
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm.session import Session
Expand All @@ -33,6 +33,7 @@

from airflow import settings
from airflow.configuration import conf
from airflow.serialization.enums import Encoding

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -146,6 +147,66 @@ def process_result_value(self, value, dialect):
return BaseSerialization._deserialize(value)


class ExecutorConfigType(PickleType):
    """
    Adds special handling for K8s executor config. If we unpickle a k8s object that was
    pickled under an earlier k8s library version, then the unpickled object may throw an error
    when to_dict is called. To be more tolerant of version changes we convert to JSON using
    Airflow's serializer before pickling.
    """

    def bind_processor(self, dialect):

        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().bind_processor(dialect)

        def process(value):
            # Serialize any k8s pod_override with Airflow's serializer so the stored
            # value survives kubernetes library upgrades.  Work on a shallow copy:
            # the previous in-place assignment mutated the caller's executor_config
            # dict as a side effect of merely persisting it.
            if isinstance(value, dict) and 'pod_override' in value:
                value = {
                    **value,
                    'pod_override': BaseSerialization()._serialize(value['pod_override']),
                }
            return super_process(value)

        return process

    def result_processor(self, dialect, coltype):
        from airflow.serialization.serialized_objects import BaseSerialization

        super_process = super().result_processor(dialect, coltype)

        def process(value):
            value = super_process(value)  # unpickle

            if isinstance(value, dict) and 'pod_override' in value:
                pod_override = value['pod_override']

                # Only deserialize pod_override if it was serialized with Airflow's
                # BaseSerialization; older rows may still hold a raw pickled k8s
                # object or an arbitrary dict, which we pass through untouched.
                if isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
                    value['pod_override'] = BaseSerialization()._deserialize(pod_override)
            return value

        return process

    def compare_values(self, x, y):
        """
        The TaskInstance.executor_config attribute is a pickled object that may contain
        kubernetes objects. If the installed library version has changed since the
        object was originally pickled, due to the underlying ``__eq__`` method on these
        objects (which converts them to JSON), we may encounter attribute errors. In this
        case we should replace the stored object.

        From https://github.com/apache/airflow/pull/24356 we use our serializer to store
        k8s objects, but there could still be raw pickled k8s objects in the database,
        stored from earlier version, so we still compare them defensively here.
        """
        if self.comparator:
            return self.comparator(x, y)
        else:
            try:
                return x == y
            except AttributeError:
                return False


class Interval(TypeDecorator):
"""Base class representing a time interval."""

Expand Down
21 changes: 1 addition & 20 deletions tests/models/test_taskinstance.py
Expand Up @@ -57,7 +57,7 @@
)
from airflow.models.serialized_dag import SerializedDagModel
from airflow.models.taskfail import TaskFail
from airflow.models.taskinstance import TaskInstance, _executor_config_comparator
from airflow.models.taskinstance import TaskInstance
from airflow.models.taskmap import TaskMap
from airflow.models.xcom import XCOM_RETURN_KEY
from airflow.operators.bash import BashOperator
Expand Down Expand Up @@ -2848,22 +2848,3 @@ def get_extra_env():

echo_task = dag.get_task("echo")
assert "get_extra_env" in echo_task.upstream_task_ids


def test_executor_config_comparator():
    """
    When comparison raises AttributeError, return False.
    This can happen when executor config contains kubernetes objects pickled
    under older kubernetes library version.
    """

    class RaisesOnEq:
        def __eq__(self, other):
            raise AttributeError('hello')

    obj = RaisesOnEq()
    # Sanity check: comparing these objects directly really does raise.
    with pytest.raises(AttributeError):
        assert obj == obj
    assert _executor_config_comparator(obj, obj) is False
    assert _executor_config_comparator('a', 'a') is True
90 changes: 89 additions & 1 deletion tests/utils/test_sqlalchemy.py
Expand Up @@ -17,20 +17,29 @@
# under the License.
#
import datetime
import pickle
import unittest
from copy import copy
from unittest import mock
from unittest.mock import MagicMock

import pytest
from kubernetes.client import models as k8s
from parameterized import parameterized
from pytest import param
from sqlalchemy.exc import StatementError

from airflow import settings
from airflow.models import DAG
from airflow.serialization.enums import Encoding
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import Session
from airflow.utils.sqlalchemy import nowait, prohibit_commit, skip_locked, with_row_locks
from airflow.utils.sqlalchemy import ExecutorConfigType, nowait, prohibit_commit, skip_locked, with_row_locks
from airflow.utils.state import State
from airflow.utils.timezone import utcnow

TEST_POD = k8s.V1Pod(spec=k8s.V1PodSpec(containers=[k8s.V1Container(name="base")]))


class TestSqlAlchemyUtils(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -226,3 +235,82 @@ def test_prohibit_commit_specific_session_only(self):
def tearDown(self):
self.session.close()
settings.engine.dispose()


class TestExecutorConfigType:
    """Tests for ``ExecutorConfigType``: bind/result processing and value comparison."""

    @pytest.mark.parametrize(
        'input',
        ['anything', {'pod_override': TEST_POD}],
    )
    def test_bind_processor(self, input):
        """
        The returned bind processor should pickle the object as is, unless it is a dictionary with
        a pod_override node, in which case it should run it through BaseSerialization.
        """
        config_type = ExecutorConfigType()
        # Minimal dialect stand-in; dbapi=None keeps PickleType on the pure-Python path.
        mock_dialect = MagicMock()
        mock_dialect.dbapi = None
        process = config_type.bind_processor(mock_dialect)
        expected = copy(input)
        if 'pod_override' in input:
            # The pod_override value should have been serialized before pickling.
            expected['pod_override'] = BaseSerialization()._serialize(input['pod_override'])
        assert pickle.loads(process(input)) == expected

    @pytest.mark.parametrize(
        'input',
        [
            param(
                pickle.dumps('anything'),
                id='anything',
            ),
            param(
                pickle.dumps({'pod_override': BaseSerialization()._serialize(TEST_POD)}),
                id='serialized_pod',
            ),
            param(
                pickle.dumps({'pod_override': TEST_POD}),
                id='old_pickled_raw_pod',
            ),
            param(
                pickle.dumps({'pod_override': {"name": "hi"}}),
                id='arbitrary_dict',
            ),
        ],
    )
    def test_result_processor(self, input):
        """
        The returned result processor should unpickle the stored value as is, unless it is a
        dictionary with a pod_override node whose value was serialized with BaseSerialization,
        in which case pod_override is deserialized back into an object.
        """
        config_type = ExecutorConfigType()
        mock_dialect = MagicMock()
        mock_dialect.dbapi = None
        process = config_type.result_processor(mock_dialect, None)
        result = process(input)
        expected = pickle.loads(input)
        pod_override = isinstance(expected, dict) and expected.get('pod_override')
        if pod_override and isinstance(pod_override, dict) and pod_override.get(Encoding.TYPE):
            # We should only deserialize a pod_override with BaseSerialization if
            # it was serialized with BaseSerialization (which is the behavior added in #24356)
            expected['pod_override'] = BaseSerialization()._deserialize(expected['pod_override'])
        assert result == expected

    def test_compare_values(self):
        """
        When comparison raises AttributeError, return False.
        This can happen when executor config contains kubernetes objects pickled
        under older kubernetes library version.
        """

        class MockAttrError:
            def __eq__(self, other):
                raise AttributeError('hello')

        a = MockAttrError()
        with pytest.raises(AttributeError):
            # just verify for ourselves that comparing directly will throw AttributeError
            assert a == a

        instance = ExecutorConfigType()
        assert instance.compare_values(a, a) is False
        assert instance.compare_values('a', 'a') is True