Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix TaskInstancePydantic serialization of MappedOperator #39288

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 6 additions & 1 deletion airflow/models/expandinput.py
Expand Up @@ -33,6 +33,7 @@

from airflow.models.operator import Operator
from airflow.models.xcom_arg import XComArg
from airflow.serialization.serialized_objects import _ExpandInputRef
from airflow.typing_compat import TypeGuard
from airflow.utils.context import Context

Expand Down Expand Up @@ -281,7 +282,11 @@ def resolve(self, context: Context, session: Session) -> tuple[Mapping[str, Any]
}


def get_map_type_key(expand_input: ExpandInput | _ExpandInputRef) -> str:
    """Return the serialization key identifying the type of *expand_input*.

    An ``_ExpandInputRef`` (the placeholder produced when a serialized
    MappedOperator is deserialized) already carries its key directly; a
    concrete ``ExpandInput`` is looked up in the ``_EXPAND_INPUT_TYPES``
    registry instead.

    :raises StopIteration: if a concrete type is not in the registry.
    """
    # Imported locally to avoid a circular import between the models and
    # serialization packages at module load time.
    from airflow.serialization.serialized_objects import _ExpandInputRef

    if isinstance(expand_input, _ExpandInputRef):
        return expand_input.key
    return next(k for k, v in _EXPAND_INPUT_TYPES.items() if v == type(expand_input))


Expand Down
7 changes: 6 additions & 1 deletion airflow/models/mappedoperator.py
Expand Up @@ -799,7 +799,12 @@ def get_parse_time_mapped_ti_count(self) -> int:
return parent_count * current_count

def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int:
current_count = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
from airflow.serialization.serialized_objects import _ExpandInputRef

exp_input = self._get_specified_expand_input()
if isinstance(exp_input, _ExpandInputRef):
exp_input = exp_input.deref(self.dag)
current_count = exp_input.get_total_map_length(run_id, session=session)
try:
parent_count = super().get_mapped_ti_count(run_id, session=session)
except NotMapped:
Expand Down
9 changes: 5 additions & 4 deletions airflow/serialization/pydantic/taskinstance.py
Expand Up @@ -49,20 +49,21 @@

def serialize_operator(x: Operator | None) -> dict | None:
    """Serialize an operator for embedding in a Pydantic TaskInstance model.

    Uses ``BaseSerialization.serialize`` with ``use_pydantic_models=True`` so
    that ``MappedOperator`` instances round-trip correctly (the previous
    ``SerializedBaseOperator.serialize_operator`` path did not).

    Returns ``None`` when *x* is falsy (i.e. no task attached).
    """
    if x:
        # Imported locally to avoid a circular import at module load time.
        from airflow.serialization.serialized_objects import BaseSerialization

        return BaseSerialization.serialize(x, use_pydantic_models=True)
    return None


def validated_operator(x: dict[str, Any] | Operator, _info: ValidationInfo) -> Any:
    """Pydantic validator: pass real operators through, deserialize dicts.

    ``BaseOperator`` / ``MappedOperator`` instances (and ``None``) are
    returned unchanged; anything else is assumed to be a serialized payload
    and is deserialized with ``use_pydantic_models=True``, mirroring
    ``serialize_operator``.
    """
    # Imported locally to avoid circular imports between the models and
    # serialization packages at module load time.
    from airflow.models.baseoperator import BaseOperator
    from airflow.models.mappedoperator import MappedOperator

    if isinstance(x, (BaseOperator, MappedOperator)) or x is None:
        return x
    from airflow.serialization.serialized_objects import BaseSerialization

    return BaseSerialization.deserialize(x, use_pydantic_models=True)


PydanticOperator = Annotated[
Expand Down
4 changes: 3 additions & 1 deletion airflow/serialization/serialized_objects.py
Expand Up @@ -1176,6 +1176,8 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
v = {arg: cls.deserialize(value) for arg, value in v.items()}
elif k in {"expand_input", "op_kwargs_expand_input"}:
v = _ExpandInputRef(v["type"], cls.deserialize(v["value"]))
elif k == "operator_class":
v = {k_: cls.deserialize(v_, use_pydantic_models=True) for k_, v_ in v.items()}
elif (
k in cls._decorated_fields
or k not in op.get_serialized_fields()
Expand All @@ -1191,7 +1193,7 @@ def populate_operator(cls, op: Operator, encoded_op: dict[str, Any]) -> None:
setattr(op, k, v)

for k in op.get_serialized_fields() - encoded_op.keys() - cls._CONSTRUCTOR_PARAMS.keys():
# TODO: refactor deserialization of BaseOperator and MappedOperaotr (split it out), then check
# TODO: refactor deserialization of BaseOperator and MappedOperator (split it out), then check
# could go away.
if not hasattr(op, k):
setattr(op, k, None)
Expand Down
64 changes: 64 additions & 0 deletions tests/serialization/test_pydantic_models.py
Expand Up @@ -22,8 +22,11 @@
import pytest
from dateutil import relativedelta

from airflow.decorators import task
from airflow.decorators.python import _PythonDecoratedOperator
from airflow.jobs.job import Job
from airflow.jobs.local_task_job_runner import LocalTaskJobRunner
from airflow.models import MappedOperator
from airflow.models.dag import DagModel
from airflow.models.dataset import (
DagScheduleDatasetReference,
Expand All @@ -36,6 +39,7 @@
from airflow.serialization.pydantic.dataset import DatasetEventPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.serialized_objects import BaseSerialization
from airflow.settings import _ENABLE_AIP_44
from airflow.utils import timezone
from airflow.utils.state import State
Expand Down Expand Up @@ -68,6 +72,66 @@ def test_serializing_pydantic_task_instance(session, create_task_instance):
assert deserialized_model.next_kwargs == {"foo": "bar"}


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_deserialize_ti_mapped_op_reserialized_with_refresh_from_task(session, dag_maker):
    """Round-trip a mapped-operator TI through ``BaseSerialization``.

    Verifies that ``operator_class`` — a real class on the live task —
    becomes a plain dict after one serialize/deserialize cycle, and that it
    then stays stable (same dict) across further serialize / deserialize /
    ``refresh_from_task`` cycles instead of breaking re-serialization.
    """
    # Expected dict form of ``operator_class`` after the first round trip.
    op_class_dict_expected = {
        "_task_type": "_PythonDecoratedOperator",
        "downstream_task_ids": [],
        "_operator_name": "@task",
        "ui_fgcolor": "#000",
        "ui_color": "#ffefeb",
        "template_fields": ["templates_dict", "op_args", "op_kwargs"],
        "template_fields_renderers": {"templates_dict": "json", "op_args": "py", "op_kwargs": "py"},
        "template_ext": [],
        "task_id": "target",
    }

    with dag_maker():

        @task
        def source():
            return [1, 2, 3]

        @task
        def target(val=None):
            print(val)

        # source() >> target()
        # Expanding over source()'s output makes ``target`` a MappedOperator.
        target.expand(val=source())
    dr = dag_maker.create_dagrun()
    ti = dr.task_instances[1]

    # roundtrip task
    ser_task = BaseSerialization.serialize(ti.task, use_pydantic_models=True)
    deser_task = BaseSerialization.deserialize(ser_task, use_pydantic_models=True)
    ti.task.operator_class
    # this is part of the problem!
    # Live task holds the real class; the deserialized copy holds a dict.
    assert isinstance(ti.task.operator_class, type)
    assert isinstance(deser_task.operator_class, dict)

    assert ti.task.operator_class == _PythonDecoratedOperator
    ti.refresh_from_task(deser_task)
    # roundtrip ti
    sered = BaseSerialization.serialize(ti, use_pydantic_models=True)
    desered = BaseSerialization.deserialize(sered, use_pydantic_models=True)

    # ``operator_class`` is not re-embedded in the serialized TI payload.
    assert "operator_class" not in sered["__var"]["task"]

    assert desered.task.__class__ == MappedOperator

    assert desered.task.operator_class == op_class_dict_expected

    desered.refresh_from_task(deser_task)

    assert desered.task.__class__ == MappedOperator

    assert isinstance(desered.task.operator_class, dict)

    # A second full round trip must keep ``operator_class`` byte-stable.
    resered = BaseSerialization.serialize(desered, use_pydantic_models=True)
    deresered = BaseSerialization.deserialize(resered, use_pydantic_models=True)
    assert deresered.task.operator_class == desered.task.operator_class == op_class_dict_expected


@pytest.mark.skipif(not _ENABLE_AIP_44, reason="AIP-44 is disabled")
def test_serializing_pydantic_dagrun(session, create_task_instance):
dag_id = "test-dag"
Expand Down