Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix reducing mapped length of a mapped task at runtime after a clear #25531

Merged
merged 2 commits into from Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 10 additions & 2 deletions airflow/models/dagrun.py
Expand Up @@ -736,7 +736,7 @@ def _filter_tis_and_exclude_removed(dag: "DAG", tis: List[TI]) -> Iterable[TI]:
yield ti

tis = list(_filter_tis_and_exclude_removed(self.get_dag(), tis))
missing_indexes = self._find_missing_task_indexes(tis, session=session)
missing_indexes = self._revise_mapped_task_indexes(tis, session=session)
if missing_indexes:
self.verify_integrity(missing_indexes=missing_indexes, session=session)

Expand Down Expand Up @@ -1158,7 +1158,7 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _find_missing_task_indexes(
def _revise_mapped_task_indexes(
self,
tis: Iterable[TI],
*,
Expand All @@ -1183,6 +1183,14 @@ def _find_missing_task_indexes(
existing_indexes[task].append(ti.map_index)
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
new_length = task.run_time_mapped_ti_count(self.run_id, session=session) or 0

if ti.map_index >= new_length:
self.log.debug(
"Removing task '%s' as the map_index is longer than the resolved mapping list (%d)",
ti,
new_length,
)
ti.state = State.REMOVED
new_indexes[task] = range(new_length)
missing_indexes: Dict[MappedOperator, Sequence[int]] = defaultdict(list)
for k, v in existing_indexes.items():
Expand Down
64 changes: 64 additions & 0 deletions tests/models/test_dagrun.py
Expand Up @@ -1227,6 +1227,70 @@ def task_2(arg2):
]


def test_mapped_literal_length_reduction_at_runtime_adds_removed_state(dag_maker, session):
    """
    Test that when the length of a mapped literal shrinks at runtime, the task instances
    whose map_index no longer exists are marked as removed.

    The upstream task returns the value of the ``arg1`` Variable, so changing that
    Variable between the two runs changes the length of the list that ``task_2`` is
    expanded over (three items first, then two).
    """
    from airflow.models import Variable

    # Initial upstream value: three items, matching the three map indexes asserted below.
    Variable.set(key='arg1', value=[1, 2, 3])

    @task
    def task_1():
        # The returned list is read at run time, so it follows later Variable.set calls.
        return Variable.get('arg1', deserialize_json=True)

    with dag_maker(session=session) as dag:

        @task
        def task_2(arg2):
            ...

        task_2.expand(arg2=task_1())

    # Run the upstream task first, then let the dag run make its scheduling decisions.
    dr = dag_maker.create_dagrun()
    ti = dr.get_task_instance(task_id='task_1')
    ti.run()
    dr.task_instance_scheduling_decisions()
    tis = dr.get_task_instances()
    # Keep only task instances with map_index >= 0; task_1's own instance is
    # presumably excluded by this filter (unmapped index) -- confirm.
    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
    assert sorted(indices) == [
        (0, State.NONE),
        (1, State.NONE),
        (2, State.NONE),
    ]

    # Now "clear" and "reduce" the length of literal
    dag.clear()
    Variable.set(key='arg1', value=[1, 2])

    with dag:
        # NOTE(review): the trailing `.operator` is a bare attribute access whose value
        # is discarded -- confirm whether it is needed.
        task_2.expand(arg2=task_1()).operator

    # At this point, we need to test that the change works on the serialized
    # DAG (which is what the scheduler operates on)
    serialized_dag = SerializedDAG.from_dict(SerializedDAG.to_dict(dag))

    dr.dag = serialized_dag

    # Run the first task again to get the new lengths
    ti = dr.get_task_instance(task_id='task_1')
    # The task is looked up on the original `dag`, not on `serialized_dag`.
    task1 = dag.get_task('task_1')
    ti.refresh_from_task(task1)
    ti.run()

    # this would be called by the localtask job
    dr.task_instance_scheduling_decisions()
    tis = dr.get_task_instances()

    # Index 2 is still returned, but it is now expected to be in the REMOVED state
    # because the upstream list has only two items.
    indices = [(ti.map_index, ti.state) for ti in tis if ti.map_index >= 0]
    assert sorted(indices) == [
        (0, State.NONE),
        (1, State.NONE),
        (2, TaskInstanceState.REMOVED),
    ]


@pytest.mark.need_serialized_dag
def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
literal = [1, 2, 3, 4]
Expand Down