From e91637f8894cac19c6b467b6669cbcc13184be70 Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Wed, 21 Sep 2022 13:52:45 +0100 Subject: [PATCH] Fix deadlock when mapped task with removed upstream is rerun (#26518) When a dag with a mapped downstream tasks that depends on a mapped upstream tasks that have some mapped indexes removed is rerun, we run into a deadlock because the trigger rules evaluation is not accounting for removed task instances. The fix for the deadlocks was to account for the removed task instances where possible in the trigger rules In this fix, I added a case where if we set flag_upstream_failed, then for the removed task instance, the downstream of that task instance will be removed. That's if the upstream with index 3 is removed, then downstream with index 3 will also be removed if flag_upstream_failed is set to True. --- airflow/ti_deps/deps/trigger_rule_dep.py | 18 +- tests/models/test_dagrun.py | 40 ++++ tests/models/test_taskinstance.py | 178 ++++++++++++++---- tests/ti_deps/deps/test_trigger_rule_dep.py | 195 +++++++++++++++++++- 4 files changed, 395 insertions(+), 36 deletions(-) diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 72fa783a8ebfb..691dc3e3a5482 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -59,6 +59,7 @@ def _get_states_count_upstream_ti(task, finished_tis): counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), counter.get(State.UPSTREAM_FAILED, 0), + counter.get(State.REMOVED, 0), sum(counter.values()), ) @@ -73,7 +74,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext): yield self._passing_status(reason="The task had a always trigger rule set.") return # see if the task name is in the task upstream for our task - successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti( + successes, skipped, failed, upstream_failed, removed, done = self._get_states_count_upstream_ti( task=ti.task, finished_tis=dep_context.ensure_finished_tis(ti.get_dagrun(session), session) ) @@ -83,6 +84,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext): skipped=skipped, failed=failed, upstream_failed=upstream_failed, + removed=removed, done=done, flag_upstream_failed=dep_context.flag_upstream_failed, dep_context=dep_context, @@ -122,6 +124,7 @@ def _evaluate_trigger_rule( skipped, failed, upstream_failed, + removed, done, flag_upstream_failed, dep_context: DepContext, @@ -152,6 +155,7 @@ def _evaluate_trigger_rule( "successes": successes, "skipped": skipped, "failed": failed, + "removed": removed, "upstream_failed": upstream_failed, "done": done, } @@ -162,6 +166,9 @@ def _evaluate_trigger_rule( changed = ti.set_state(State.UPSTREAM_FAILED, session) elif skipped: changed = ti.set_state(State.SKIPPED, session) + elif removed and successes and ti.map_index > -1: + if ti.map_index >= successes: + changed = ti.set_state(State.REMOVED, session) elif trigger_rule == TR.ALL_FAILED: if successes or skipped: changed = ti.set_state(State.SKIPPED, session) @@ -189,6 +196,7 @@ def _evaluate_trigger_rule( elif trigger_rule == TR.ALL_SKIPPED: if successes or failed: changed = ti.set_state(State.SKIPPED, session) + if changed: dep_context.have_changed_ti_states = True @@ -212,6 +220,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.ALL_SUCCESS: num_failures = upstream - successes + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( @@ -223,6 +233,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.ALL_FAILED: num_successes = upstream - failed - upstream_failed + if ti.map_index > -1: + num_successes -= removed if num_successes > 0: yield self._failing_status( reason=( @@ -244,6 +256,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.NONE_FAILED: num_failures = upstream - successes - skipped + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( @@ -255,6 +269,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: num_failures = upstream - successes - skipped + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 16b892f76e119..50e9e9a3d8eb3 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -1905,3 +1905,43 @@ def say_hi(): dr.update_state(session=session) assert dr.state == DagRunState.SUCCESS assert tis['add_one__1'].state == TaskInstanceState.SKIPPED + + +def test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session, dag_maker): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task + def do_something_else(i): + return 1 + + with dag_maker(): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + task = ti.task + for map_index in range(1, 5): + ti = TI(task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'do_something': + if ti.map_index > 2: + ti.state = TaskInstanceState.REMOVED + else: + ti.state = TaskInstanceState.SUCCESS + session.merge(ti) + session.commit() + # The Upstream is done with 2 removed tis and 3 success tis + (tis, _) = dr.update_state() + assert len(tis) + assert dr.state != DagRunState.FAILED diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 0589e96d3289c..b1e36dcbaa8e9 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1065,55 +1065,55 @@ def test_depends_on_past(self, dag_maker): # Parameterized tests to check for the correct firing # of the trigger_rule under various circumstances # Numeric fields are in order: - # successes, skipped, failed, upstream_failed, done + # successes, skipped, failed, upstream_failed, done, removed @pytest.mark.parametrize( - "trigger_rule,successes,skipped,failed,upstream_failed,done," + "trigger_rule,successes,skipped,failed,upstream_failed,done,removed," "flag_upstream_failed,expect_state,expect_completed", [ # # Tests for all_success # - ['all_success', 5, 0, 0, 0, 0, True, None, True], - ['all_success', 2, 0, 0, 0, 0, True, None, False], - ['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False], - ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False], + ['all_success', 5, 0, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False], # # Tests for one_success # - ['one_success', 5, 0, 0, 0, 5, True, None, True], - ['one_success', 2, 0, 0, 0, 2, True, None, True], - ['one_success', 2, 0, 1, 0, 3, True, None, True], - ['one_success', 2, 1, 0, 0, 3, True, None, True], - ['one_success', 0, 5, 0, 0, 5, True, State.SKIPPED, False], - ['one_success', 0, 4, 1, 0, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 3, 1, 1, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 4, 0, 1, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 0, 5, 0, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 0, 4, 1, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 0, 0, 5, 5, True, State.UPSTREAM_FAILED, False], + ['one_success', 5, 0, 0, 0, 5, 0, True, None, True], + ['one_success', 2, 0, 0, 0, 2, 0, True, None, True], + ['one_success', 2, 0, 1, 0, 3, 0, True, None, True], + ['one_success', 2, 1, 0, 0, 3, 0, True, None, True], + ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False], # # Tests for all_failed # - ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False], - ['all_failed', 0, 0, 5, 0, 5, True, None, True], - ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False], - ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False], - ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False], + ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False], + ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True], + ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False], + ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False], # # Tests for one_failed # - ['one_failed', 5, 0, 0, 0, 0, True, None, False], - ['one_failed', 2, 0, 0, 0, 0, True, None, False], - ['one_failed', 2, 0, 1, 0, 0, True, None, True], - ['one_failed', 2, 1, 0, 0, 3, True, None, False], - ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False], + ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True], + ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False], + ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False], # # Tests for done # - ['all_done', 5, 0, 0, 0, 5, True, None, True], - ['all_done', 2, 0, 0, 0, 2, True, None, False], - ['all_done', 2, 0, 1, 0, 3, True, None, False], - ['all_done', 2, 1, 0, 0, 3, True, None, False], + ['all_done', 5, 0, 0, 0, 5, 0, True, None, True], + ['all_done', 2, 0, 0, 0, 2, 0, True, None, False], + ['all_done', 2, 0, 1, 0, 3, 0, True, None, False], + ['all_done', 2, 1, 0, 0, 3, 0, True, None, False], ], ) def test_check_task_dependencies( @@ -1122,6 +1122,7 @@ def test_check_task_dependencies( successes: int, skipped: int, failed: int, + removed: int, upstream_failed: int, done: int, flag_upstream_failed: bool, @@ -1144,6 +1145,121 @@ def test_check_task_dependencies( successes=successes, skipped=skipped, failed=failed, + removed=removed, + upstream_failed=upstream_failed, + done=done, + dep_context=DepContext(), + flag_upstream_failed=flag_upstream_failed, + ) + completed = all(dep.passed for dep in dep_results) + + assert completed == expect_completed + assert ti.state == expect_state + + # Parameterized tests to check for the correct firing + # of the trigger_rule under various circumstances of mapped task + # Numeric fields are in order: + # successes, skipped, failed, upstream_failed, done,removed + @pytest.mark.parametrize( + "trigger_rule,successes,skipped,failed,upstream_failed,done,removed," + "flag_upstream_failed,expect_state,expect_completed", + [ + # + # Tests for all_success + # + ['all_success', 5, 0, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False], + ['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True], # ti.map_index >=successes + # + # Tests for one_success + # + ['one_success', 5, 0, 0, 0, 5, 0, True, None, True], + ['one_success', 2, 0, 0, 0, 2, 0, True, None, True], + ['one_success', 2, 0, 1, 0, 3, 0, True, None, True], + ['one_success', 2, 1, 0, 0, 3, 0, True, None, True], + ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False], + # + # Tests for all_failed + # + ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False], + ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True], + ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False], + ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False], # One removed + # + # Tests for one_failed + # + ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True], + ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False], + ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False], # One removed + # + # Tests for done + # + ['all_done', 5, 0, 0, 0, 5, 0, True, None, True], + ['all_done', 2, 0, 0, 0, 2, 0, True, None, False], + ['all_done', 2, 0, 1, 0, 3, 0, True, None, False], + ['all_done', 2, 1, 0, 0, 3, 0, True, None, False], + ], + ) + def test_check_task_dependencies_for_mapped( + self, + trigger_rule: str, + successes: int, + skipped: int, + failed: int, + removed: int, + upstream_failed: int, + done: int, + flag_upstream_failed: bool, + expect_state: State, + expect_completed: bool, + dag_maker, + session, + ): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task(trigger_rule=trigger_rule) + def do_something_else(i): + return 1 + + with dag_maker(dag_id='test_dag'): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + for map_index in range(1, 5): + ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + downstream = ti.task + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = downstream + dep_results = TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=successes, + skipped=skipped, + failed=failed, + removed=removed, upstream_failed=upstream_failed, done=done, dep_context=DepContext(), diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index fc6a4d546c1b7..4deeebe254143 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -25,12 +25,13 @@ from airflow import settings from airflow.models import DAG from airflow.models.baseoperator import BaseOperator +from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule from tests.models import DEFAULT_DATE from tests.test_utils.db import clear_db_runs @@ -53,6 +54,46 @@ def _get_task_instance(trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstrea return _get_task_instance +@pytest.fixture +def get_mapped_task_dagrun(session, dag_maker): + def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=State.SUCCESS): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task(trigger_rule=trigger_rule) + def do_something_else(i): + return 1 + + with dag_maker(dag_id='test_dag'): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + for map_index in range(1, 5): + ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'do_something': + if ti.map_index > 2: + ti.state = TaskInstanceState.REMOVED + else: + ti.state = state + session.merge(ti) + session.commit() + return dr, ti.task + + return _get_dagrun + + class TestTriggerRuleDep: def test_no_upstream_tasks(self, get_task_instance): """ @@ -79,6 +120,7 @@ def test_one_success_tr_success(self, get_task_instance): successes=1, skipped=2, failed=2, + removed=0, upstream_failed=2, done=2, flag_upstream_failed=False, @@ -99,6 +141,7 @@ def test_one_success_tr_failure(self, get_task_instance): successes=0, skipped=2, failed=2, + removed=0, upstream_failed=2, done=2, flag_upstream_failed=False, @@ -120,6 +163,7 @@ def test_one_failure_tr_failure(self, get_task_instance): successes=2, skipped=0, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -141,6 +185,7 @@ def test_one_failure_tr_success(self, get_task_instance): successes=0, skipped=2, failed=2, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -156,6 +201,7 @@ def test_one_failure_tr_success(self, get_task_instance): successes=0, skipped=2, failed=0, + removed=0, upstream_failed=2, done=2, flag_upstream_failed=False, @@ -176,6 +222,7 @@ def test_all_success_tr_success(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -196,6 +243,7 @@ def test_all_success_tr_failure(self, get_task_instance): successes=1, skipped=0, failed=1, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -217,6 +265,7 @@ def test_all_success_tr_skip(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -239,6 +288,7 @@ def test_all_success_tr_skip_flag_upstream(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -261,6 +311,7 @@ def test_none_failed_tr_success(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -281,6 +332,7 @@ def test_none_failed_tr_skipped(self, get_task_instance): successes=0, skipped=2, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -304,6 +356,7 @@ def test_none_failed_tr_failure(self, get_task_instance): successes=1, skipped=1, failed=1, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -327,6 +380,7 @@ def test_none_failed_min_one_success_tr_success(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -349,6 +403,7 @@ def test_none_failed_min_one_success_tr_skipped(self, get_task_instance): successes=0, skipped=2, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -373,6 +428,7 @@ def test_none_failed_min_one_success_tr_failure(self, session, get_task_instance successes=1, skipped=1, failed=1, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -394,6 +450,7 @@ def test_all_failed_tr_success(self, get_task_instance): successes=0, skipped=0, failed=2, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -414,6 +471,7 @@ def test_all_failed_tr_failure(self, get_task_instance): successes=2, skipped=0, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -435,6 +493,7 @@ def test_all_done_tr_success(self, get_task_instance): successes=2, skipped=0, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -455,6 +514,7 @@ def test_all_skipped_tr_failure(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -479,6 +539,7 @@ def test_all_skipped_tr_success(self, get_task_instance): successes=0, skipped=3, failed=0, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -495,6 +556,7 @@ def test_all_skipped_tr_success(self, get_task_instance): successes=0, skipped=3, failed=0, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=True, @@ -515,6 +577,7 @@ def test_all_done_tr_failure(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -539,6 +602,7 @@ def test_none_skipped_tr_success(self, get_task_instance): successes=2, skipped=0, failed=1, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -555,6 +619,7 @@ def test_none_skipped_tr_success(self, get_task_instance): successes=0, skipped=0, failed=3, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=True, @@ -577,6 +642,7 @@ def test_none_skipped_tr_failure(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -594,6 +660,7 @@ def test_none_skipped_tr_failure(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -611,6 +678,7 @@ def test_none_skipped_tr_failure(self, get_task_instance): successes=0, skipped=0, failed=0, + removed=0, upstream_failed=0, done=0, flag_upstream_failed=False, @@ -633,6 +701,7 @@ def test_unknown_tr(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -693,10 +762,128 @@ def test_get_states_count_upstream_ti(self): # check handling with cases that tasks are triggered from backfill with no finished tasks finished_tis = DepContext().ensure_finished_tis(ti_op2.dag_run, session) - assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op2) == (1, 0, 0, 0, 1) + assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op2) == (1, 0, 0, 0, 0, 1) finished_tis = dr.get_task_instances(state=State.finished, session=session) - assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op4) == (1, 0, 1, 0, 2) - assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op5) == (2, 0, 1, 0, 3) + assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op4) == (1, 0, 1, 0, 0, 2) + assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op5) == (2, 0, 1, 0, 0, 3) dr.update_state() assert State.SUCCESS == dr.state + + def test_mapped_task_upstream_removed_with_all_success_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test ALL_SUCCESS trigger rule with mapped task upstream removed + """ + dr, task = get_mapped_task_dagrun() + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=True, # marks the task as removed if upstream is removed + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0 + assert ti.state == TaskInstanceState.REMOVED + + def test_mapped_task_upstream_removed_with_all_failed_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test ALL_FAILED trigger rule with mapped task upstream removed + """ + + dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, state=State.FAILED) + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=0, + failed=3, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=False, + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0 + + def test_mapped_task_upstream_removed_with_none_failed_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test NONE_FAILED trigger rule with mapped task upstream removed + """ + dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.NONE_FAILED) + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=False, + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0 + + def test_mapped_task_upstream_removed_with_none_failed_min_one_success_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test NONE_FAILED_MIN_ONE_SUCCESS trigger rule with mapped task upstream removed + """ + dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=False, + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0