Skip to content

Commit

Permalink
Fix deadlock when mapped task with removed upstream is rerun (#26518)
Browse files Browse the repository at this point in the history
When a DAG with a mapped downstream task that depends on a mapped upstream task with some of its mapped indexes
removed is rerun, we run into a deadlock because the trigger rule evaluation does not account for removed
task instances.

The fix for the deadlock is to account for the removed task instances, where possible, in the trigger rule evaluation.

In this fix, I added a case where, if flag_upstream_failed is set, the downstream of a removed task instance is also marked as removed. That is, if the upstream with index 3 is removed, then the downstream
with index 3 will also be removed, provided flag_upstream_failed is set to True.
  • Loading branch information
ephraimbuddy committed Sep 21, 2022
1 parent a60e3b9 commit e91637f
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 36 deletions.
18 changes: 17 additions & 1 deletion airflow/ti_deps/deps/trigger_rule_dep.py
Expand Up @@ -59,6 +59,7 @@ def _get_states_count_upstream_ti(task, finished_tis):
counter.get(State.SKIPPED, 0),
counter.get(State.FAILED, 0),
counter.get(State.UPSTREAM_FAILED, 0),
counter.get(State.REMOVED, 0),
sum(counter.values()),
)

Expand All @@ -73,7 +74,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext):
yield self._passing_status(reason="The task had a always trigger rule set.")
return
# see if the task name is in the task upstream for our task
successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti(
successes, skipped, failed, upstream_failed, removed, done = self._get_states_count_upstream_ti(
task=ti.task, finished_tis=dep_context.ensure_finished_tis(ti.get_dagrun(session), session)
)

Expand All @@ -83,6 +84,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext):
skipped=skipped,
failed=failed,
upstream_failed=upstream_failed,
removed=removed,
done=done,
flag_upstream_failed=dep_context.flag_upstream_failed,
dep_context=dep_context,
Expand Down Expand Up @@ -122,6 +124,7 @@ def _evaluate_trigger_rule(
skipped,
failed,
upstream_failed,
removed,
done,
flag_upstream_failed,
dep_context: DepContext,
Expand Down Expand Up @@ -152,6 +155,7 @@ def _evaluate_trigger_rule(
"successes": successes,
"skipped": skipped,
"failed": failed,
"removed": removed,
"upstream_failed": upstream_failed,
"done": done,
}
Expand All @@ -162,6 +166,9 @@ def _evaluate_trigger_rule(
changed = ti.set_state(State.UPSTREAM_FAILED, session)
elif skipped:
changed = ti.set_state(State.SKIPPED, session)
elif removed and successes and ti.map_index > -1:
if ti.map_index >= successes:
changed = ti.set_state(State.REMOVED, session)
elif trigger_rule == TR.ALL_FAILED:
if successes or skipped:
changed = ti.set_state(State.SKIPPED, session)
Expand Down Expand Up @@ -189,6 +196,7 @@ def _evaluate_trigger_rule(
elif trigger_rule == TR.ALL_SKIPPED:
if successes or failed:
changed = ti.set_state(State.SKIPPED, session)

if changed:
dep_context.have_changed_ti_states = True

Expand All @@ -212,6 +220,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.ALL_SUCCESS:
num_failures = upstream - successes
if ti.map_index > -1:
num_failures -= removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand All @@ -223,6 +233,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.ALL_FAILED:
num_successes = upstream - failed - upstream_failed
if ti.map_index > -1:
num_successes -= removed
if num_successes > 0:
yield self._failing_status(
reason=(
Expand All @@ -244,6 +256,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.NONE_FAILED:
num_failures = upstream - successes - skipped
if ti.map_index > -1:
num_failures -= removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand All @@ -255,6 +269,8 @@ def _evaluate_trigger_rule(
)
elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS:
num_failures = upstream - successes - skipped
if ti.map_index > -1:
num_failures -= removed
if num_failures > 0:
yield self._failing_status(
reason=(
Expand Down
40 changes: 40 additions & 0 deletions tests/models/test_dagrun.py
Expand Up @@ -1905,3 +1905,43 @@ def say_hi():
dr.update_state(session=session)
assert dr.state == DagRunState.SUCCESS
assert tis['add_one__1'].state == TaskInstanceState.SKIPPED


# Regression test for rerunning a DAG run whose mapped upstream task has some
# REMOVED task instances: updating the DAG run state must still return task
# instances and must not mark the run as FAILED.
# NOTE(review): indentation is not visible in this view; confirm the nesting of
# the statements below against the real file.
def test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session, dag_maker):
from airflow.decorators import task

@task
def do_something(i):
return 1

@task
def do_something_else(i):
return 1

# Mapped upstream expanded over five inputs, feeding a mapped downstream.
with dag_maker():
nums = do_something.expand(i=[i + 1 for i in range(5)])
do_something_else.expand(i=nums)

dr = dag_maker.create_dagrun()

# Re-label the fetched downstream task instance as map index 0, then add
# downstream task instances for map indexes 1 through 4.
ti = dr.get_task_instance('do_something_else', session=session)
ti.map_index = 0
task = ti.task
for map_index in range(1, 5):
ti = TI(task, run_id=dr.run_id, map_index=map_index)
ti.dag_run = dr
session.add(ti)
session.flush()
# Upstream 'do_something' instances with map_index > 2 are set to REMOVED;
# the remaining upstream instances are set to SUCCESS.
tis = dr.get_task_instances()
for ti in tis:
if ti.task_id == 'do_something':
if ti.map_index > 2:
ti.state = TaskInstanceState.REMOVED
else:
ti.state = TaskInstanceState.SUCCESS
session.merge(ti)
session.commit()
# The Upstream is done with 2 removed tis and 3 success tis
(tis, _) = dr.update_state()
# The first element returned by update_state() must be non-empty (presumably
# the schedulable task instances; confirm), and the run must not be FAILED.
assert len(tis)
assert dr.state != DagRunState.FAILED
178 changes: 147 additions & 31 deletions tests/models/test_taskinstance.py
Expand Up @@ -1065,55 +1065,55 @@ def test_depends_on_past(self, dag_maker):
# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done
# successes, skipped, failed, upstream_failed, done, removed
@pytest.mark.parametrize(
"trigger_rule,successes,skipped,failed,upstream_failed,done,"
"trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
"flag_upstream_failed,expect_state,expect_completed",
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False],
['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, True, None, True],
['one_success', 2, 0, 0, 0, 2, True, None, True],
['one_success', 2, 0, 1, 0, 3, True, None, True],
['one_success', 2, 1, 0, 0, 3, True, None, True],
['one_success', 0, 5, 0, 0, 5, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, True, State.UPSTREAM_FAILED, False],
['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, True, None, True],
['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False],
['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, True, None, False],
['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False],
['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, True, None, True],
['all_done', 2, 0, 0, 0, 2, True, None, False],
['all_done', 2, 0, 1, 0, 3, True, None, False],
['all_done', 2, 1, 0, 0, 3, True, None, False],
['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
],
)
def test_check_task_dependencies(
Expand All @@ -1122,6 +1122,7 @@ def test_check_task_dependencies(
successes: int,
skipped: int,
failed: int,
removed: int,
upstream_failed: int,
done: int,
flag_upstream_failed: bool,
Expand All @@ -1144,6 +1145,121 @@ def test_check_task_dependencies(
successes=successes,
skipped=skipped,
failed=failed,
removed=removed,
upstream_failed=upstream_failed,
done=done,
dep_context=DepContext(),
flag_upstream_failed=flag_upstream_failed,
)
completed = all(dep.passed for dep in dep_results)

assert completed == expect_completed
assert ti.state == expect_state

# Parameterized tests to check for the correct firing
# of the trigger_rule under various circumstances of mapped task
# Numeric fields are in order:
# successes, skipped, failed, upstream_failed, done, removed
@pytest.mark.parametrize(
"trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
"flag_upstream_failed,expect_state,expect_completed",
[
#
# Tests for all_success
#
['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True], # ti.map_index >=successes
#
# Tests for one_success
#
['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
#
# Tests for all_failed
#
['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False], # One removed
#
# Tests for one_failed
#
['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False], # One removed
#
# Tests for done
#
['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
],
)
def test_check_task_dependencies_for_mapped(
self,
trigger_rule: str,
successes: int,
skipped: int,
failed: int,
removed: int,
upstream_failed: int,
done: int,
flag_upstream_failed: bool,
expect_state: State,
expect_completed: bool,
dag_maker,
session,
):
from airflow.decorators import task

@task
def do_something(i):
return 1

@task(trigger_rule=trigger_rule)
def do_something_else(i):
return 1

with dag_maker(dag_id='test_dag'):
nums = do_something.expand(i=[i + 1 for i in range(5)])
do_something_else.expand(i=nums)

dr = dag_maker.create_dagrun()

ti = dr.get_task_instance('do_something_else', session=session)
ti.map_index = 0
for map_index in range(1, 5):
ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
ti.dag_run = dr
session.add(ti)
session.flush()
downstream = ti.task
ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session)
ti.task = downstream
dep_results = TriggerRuleDep()._evaluate_trigger_rule(
ti=ti,
successes=successes,
skipped=skipped,
failed=failed,
removed=removed,
upstream_failed=upstream_failed,
done=done,
dep_context=DepContext(),
Expand Down

0 comments on commit e91637f

Please sign in to comment.