diff --git a/airflow/ti_deps/deps/trigger_rule_dep.py b/airflow/ti_deps/deps/trigger_rule_dep.py index 72fa783a8ebfb..691dc3e3a5482 100644 --- a/airflow/ti_deps/deps/trigger_rule_dep.py +++ b/airflow/ti_deps/deps/trigger_rule_dep.py @@ -59,6 +59,7 @@ def _get_states_count_upstream_ti(task, finished_tis): counter.get(State.SKIPPED, 0), counter.get(State.FAILED, 0), counter.get(State.UPSTREAM_FAILED, 0), + counter.get(State.REMOVED, 0), sum(counter.values()), ) @@ -73,7 +74,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext): yield self._passing_status(reason="The task had a always trigger rule set.") return # see if the task name is in the task upstream for our task - successes, skipped, failed, upstream_failed, done = self._get_states_count_upstream_ti( + successes, skipped, failed, upstream_failed, removed, done = self._get_states_count_upstream_ti( task=ti.task, finished_tis=dep_context.ensure_finished_tis(ti.get_dagrun(session), session) ) @@ -83,6 +84,7 @@ def _get_dep_statuses(self, ti, session, dep_context: DepContext): skipped=skipped, failed=failed, upstream_failed=upstream_failed, + removed=removed, done=done, flag_upstream_failed=dep_context.flag_upstream_failed, dep_context=dep_context, @@ -122,6 +124,7 @@ def _evaluate_trigger_rule( skipped, failed, upstream_failed, + removed, done, flag_upstream_failed, dep_context: DepContext, @@ -152,6 +155,7 @@ def _evaluate_trigger_rule( "successes": successes, "skipped": skipped, "failed": failed, + "removed": removed, "upstream_failed": upstream_failed, "done": done, } @@ -162,6 +166,9 @@ def _evaluate_trigger_rule( changed = ti.set_state(State.UPSTREAM_FAILED, session) elif skipped: changed = ti.set_state(State.SKIPPED, session) + elif removed and successes and ti.map_index > -1: + if ti.map_index >= successes: + changed = ti.set_state(State.REMOVED, session) elif trigger_rule == TR.ALL_FAILED: if successes or skipped: changed = ti.set_state(State.SKIPPED, session) @@ -189,6 +196,7 @@ def _evaluate_trigger_rule( elif trigger_rule == TR.ALL_SKIPPED: if successes or failed: changed = ti.set_state(State.SKIPPED, session) + if changed: dep_context.have_changed_ti_states = True @@ -212,6 +220,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.ALL_SUCCESS: num_failures = upstream - successes + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( @@ -223,6 +233,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.ALL_FAILED: num_successes = upstream - failed - upstream_failed + if ti.map_index > -1: + num_successes -= removed if num_successes > 0: yield self._failing_status( reason=( @@ -244,6 +256,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.NONE_FAILED: num_failures = upstream - successes - skipped + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( @@ -255,6 +269,8 @@ def _evaluate_trigger_rule( ) elif trigger_rule == TR.NONE_FAILED_MIN_ONE_SUCCESS: num_failures = upstream - successes - skipped + if ti.map_index > -1: + num_failures -= removed if num_failures > 0: yield self._failing_status( reason=( diff --git a/tests/models/test_dagrun.py b/tests/models/test_dagrun.py index 16b892f76e119..50e9e9a3d8eb3 100644 --- a/tests/models/test_dagrun.py +++ b/tests/models/test_dagrun.py @@ -1905,3 +1905,43 @@ def say_hi(): dr.update_state(session=session) assert dr.state == DagRunState.SUCCESS assert tis['add_one__1'].state == TaskInstanceState.SKIPPED + + +def test_schedulable_task_exist_when_rerun_removed_upstream_mapped_task(session, dag_maker): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task + def do_something_else(i): + return 1 + + with dag_maker(): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + task = ti.task + for map_index in range(1, 5): + ti = TI(task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'do_something': + if ti.map_index > 2: + ti.state = TaskInstanceState.REMOVED + else: + ti.state = TaskInstanceState.SUCCESS + session.merge(ti) + session.commit() + # The Upstream is done with 2 removed tis and 3 success tis + (tis, _) = dr.update_state() + assert len(tis) + assert dr.state != DagRunState.FAILED diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 0589e96d3289c..b1e36dcbaa8e9 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1065,55 +1065,55 @@ def test_depends_on_past(self, dag_maker): # Parameterized tests to check for the correct firing # of the trigger_rule under various circumstances # Numeric fields are in order: - # successes, skipped, failed, upstream_failed, done + # successes, skipped, failed, upstream_failed, done, removed @pytest.mark.parametrize( - "trigger_rule,successes,skipped,failed,upstream_failed,done," + "trigger_rule,successes,skipped,failed,upstream_failed,done,removed," "flag_upstream_failed,expect_state,expect_completed", [ # # Tests for all_success # - ['all_success', 5, 0, 0, 0, 0, True, None, True], - ['all_success', 2, 0, 0, 0, 0, True, None, False], - ['all_success', 2, 0, 1, 0, 0, True, State.UPSTREAM_FAILED, False], - ['all_success', 2, 1, 0, 0, 0, True, State.SKIPPED, False], + ['all_success', 5, 0, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False], # # Tests for one_success # - ['one_success', 5, 0, 0, 0, 5, True, None, True], - ['one_success', 2, 0, 0, 0, 2, True, None, True], - ['one_success', 2, 0, 1, 0, 3, True, None, True], - ['one_success', 2, 1, 0, 0, 3, True, None, True], - ['one_success', 0, 5, 0, 0, 5, True, State.SKIPPED, False], - ['one_success', 0, 4, 1, 0, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 3, 1, 1, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 4, 0, 1, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 0, 5, 0, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 0, 4, 1, 5, True, State.UPSTREAM_FAILED, False], - ['one_success', 0, 0, 0, 5, 5, True, State.UPSTREAM_FAILED, False], + ['one_success', 5, 0, 0, 0, 5, 0, True, None, True], + ['one_success', 2, 0, 0, 0, 2, 0, True, None, True], + ['one_success', 2, 0, 1, 0, 3, 0, True, None, True], + ['one_success', 2, 1, 0, 0, 3, 0, True, None, True], + ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False], # # Tests for all_failed # - ['all_failed', 5, 0, 0, 0, 5, True, State.SKIPPED, False], - ['all_failed', 0, 0, 5, 0, 5, True, None, True], - ['all_failed', 2, 0, 0, 0, 2, True, State.SKIPPED, False], - ['all_failed', 2, 0, 1, 0, 3, True, State.SKIPPED, False], - ['all_failed', 2, 1, 0, 0, 3, True, State.SKIPPED, False], + ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False], + ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True], + ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False], + ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False], # # Tests for one_failed # - ['one_failed', 5, 0, 0, 0, 0, True, None, False], - ['one_failed', 2, 0, 0, 0, 0, True, None, False], - ['one_failed', 2, 0, 1, 0, 0, True, None, True], - ['one_failed', 2, 1, 0, 0, 3, True, None, False], - ['one_failed', 2, 3, 0, 0, 5, True, State.SKIPPED, False], + ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True], + ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False], + ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False], # # Tests for done # - ['all_done', 5, 0, 0, 0, 5, True, None, True], - ['all_done', 2, 0, 0, 0, 2, True, None, False], - ['all_done', 2, 0, 1, 0, 3, True, None, False], - ['all_done', 2, 1, 0, 0, 3, True, None, False], + ['all_done', 5, 0, 0, 0, 5, 0, True, None, True], + ['all_done', 2, 0, 0, 0, 2, 0, True, None, False], + ['all_done', 2, 0, 1, 0, 3, 0, True, None, False], + ['all_done', 2, 1, 0, 0, 3, 0, True, None, False], ], ) def test_check_task_dependencies( @@ -1122,6 +1122,7 @@ def test_check_task_dependencies( successes: int, skipped: int, failed: int, + removed: int, upstream_failed: int, done: int, flag_upstream_failed: bool, @@ -1144,6 +1145,121 @@ def test_check_task_dependencies( successes=successes, skipped=skipped, failed=failed, + removed=removed, + upstream_failed=upstream_failed, + done=done, + dep_context=DepContext(), + flag_upstream_failed=flag_upstream_failed, + ) + completed = all(dep.passed for dep in dep_results) + + assert completed == expect_completed + assert ti.state == expect_state + + # Parameterized tests to check for the correct firing + # of the trigger_rule under various circumstances of mapped task + # Numeric fields are in order: + # successes, skipped, failed, upstream_failed, done,removed + @pytest.mark.parametrize( + "trigger_rule,successes,skipped,failed,upstream_failed,done,removed," + "flag_upstream_failed,expect_state,expect_completed", + [ + # + # Tests for all_success + # + ['all_success', 5, 0, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False], + ['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True], # ti.map_index >=successes + # + # Tests for one_success + # + ['one_success', 5, 0, 0, 0, 5, 0, True, None, True], + ['one_success', 2, 0, 0, 0, 2, 0, True, None, True], + ['one_success', 2, 0, 1, 0, 3, 0, True, None, True], + ['one_success', 2, 1, 0, 0, 3, 0, True, None, True], + ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False], + # + # Tests for all_failed + # + ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False], + ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True], + ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False], + ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False], # One removed + # + # Tests for one_failed + # + ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True], + ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False], + ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False], # One removed + # + # Tests for done + # + ['all_done', 5, 0, 0, 0, 5, 0, True, None, True], + ['all_done', 2, 0, 0, 0, 2, 0, True, None, False], + ['all_done', 2, 0, 1, 0, 3, 0, True, None, False], + ['all_done', 2, 1, 0, 0, 3, 0, True, None, False], + ], + ) + def test_check_task_dependencies_for_mapped( + self, + trigger_rule: str, + successes: int, + skipped: int, + failed: int, + removed: int, + upstream_failed: int, + done: int, + flag_upstream_failed: bool, + expect_state: State, + expect_completed: bool, + dag_maker, + session, + ): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task(trigger_rule=trigger_rule) + def do_something_else(i): + return 1 + + with dag_maker(dag_id='test_dag'): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + for map_index in range(1, 5): + ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + downstream = ti.task + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = downstream + dep_results = TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=successes, + skipped=skipped, + failed=failed, + removed=removed, upstream_failed=upstream_failed, done=done, dep_context=DepContext(), diff --git a/tests/ti_deps/deps/test_trigger_rule_dep.py b/tests/ti_deps/deps/test_trigger_rule_dep.py index fc6a4d546c1b7..4deeebe254143 100644 --- a/tests/ti_deps/deps/test_trigger_rule_dep.py +++ b/tests/ti_deps/deps/test_trigger_rule_dep.py @@ -25,12 +25,13 @@ from airflow import settings from airflow.models import DAG from airflow.models.baseoperator import BaseOperator +from airflow.models.taskinstance import TaskInstance from airflow.operators.empty import EmptyOperator from airflow.ti_deps.dep_context import DepContext from airflow.ti_deps.deps.trigger_rule_dep import TriggerRuleDep from airflow.utils import timezone from airflow.utils.session import create_session -from airflow.utils.state import State +from airflow.utils.state import State, TaskInstanceState from airflow.utils.trigger_rule import TriggerRule from tests.models import DEFAULT_DATE from tests.test_utils.db import clear_db_runs @@ -53,6 +54,46 @@ def _get_task_instance(trigger_rule=TriggerRule.ALL_SUCCESS, state=None, upstrea return _get_task_instance +@pytest.fixture +def get_mapped_task_dagrun(session, dag_maker): + def _get_dagrun(trigger_rule=TriggerRule.ALL_SUCCESS, state=State.SUCCESS): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task(trigger_rule=trigger_rule) + def do_something_else(i): + return 1 + + with dag_maker(dag_id='test_dag'): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + for map_index in range(1, 5): + ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + tis = dr.get_task_instances() + for ti in tis: + if ti.task_id == 'do_something': + if ti.map_index > 2: + ti.state = TaskInstanceState.REMOVED + else: + ti.state = state + session.merge(ti) + session.commit() + return dr, ti.task + + return _get_dagrun + + class TestTriggerRuleDep: def test_no_upstream_tasks(self, get_task_instance): """ @@ -79,6 +120,7 @@ def test_one_success_tr_success(self, get_task_instance): successes=1, skipped=2, failed=2, + removed=0, upstream_failed=2, done=2, flag_upstream_failed=False, @@ -99,6 +141,7 @@ def test_one_success_tr_failure(self, get_task_instance): successes=0, skipped=2, failed=2, + removed=0, upstream_failed=2, done=2, flag_upstream_failed=False, @@ -120,6 +163,7 @@ def test_one_failure_tr_failure(self, get_task_instance): successes=2, skipped=0, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -141,6 +185,7 @@ def test_one_failure_tr_success(self, get_task_instance): successes=0, skipped=2, failed=2, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -156,6 +201,7 @@ def test_one_failure_tr_success(self, get_task_instance): successes=0, skipped=2, failed=0, + removed=0, upstream_failed=2, done=2, flag_upstream_failed=False, @@ -176,6 +222,7 @@ def test_all_success_tr_success(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -196,6 +243,7 @@ def test_all_success_tr_failure(self, get_task_instance): successes=1, skipped=0, failed=1, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -217,6 +265,7 @@ def test_all_success_tr_skip(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -239,6 +288,7 @@ def test_all_success_tr_skip_flag_upstream(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -261,6 +311,7 @@ def test_none_failed_tr_success(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -281,6 +332,7 @@ def test_none_failed_tr_skipped(self, get_task_instance): successes=0, skipped=2, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -304,6 +356,7 @@ def test_none_failed_tr_failure(self, get_task_instance): successes=1, skipped=1, failed=1, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -327,6 +380,7 @@ def test_none_failed_min_one_success_tr_success(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -349,6 +403,7 @@ def test_none_failed_min_one_success_tr_skipped(self, get_task_instance): successes=0, skipped=2, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -373,6 +428,7 @@ def test_none_failed_min_one_success_tr_failure(self, session, get_task_instance successes=1, skipped=1, failed=1, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -394,6 +450,7 @@ def test_all_failed_tr_success(self, get_task_instance): successes=0, skipped=0, failed=2, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -414,6 +471,7 @@ def test_all_failed_tr_failure(self, get_task_instance): successes=2, skipped=0, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -435,6 +493,7 @@ def test_all_done_tr_success(self, get_task_instance): successes=2, skipped=0, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -455,6 +514,7 @@ def test_all_skipped_tr_failure(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -479,6 +539,7 @@ def test_all_skipped_tr_success(self, get_task_instance): successes=0, skipped=3, failed=0, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -495,6 +556,7 @@ def test_all_skipped_tr_success(self, get_task_instance): successes=0, skipped=3, failed=0, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=True, @@ -515,6 +577,7 @@ def test_all_done_tr_failure(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -539,6 +602,7 @@ def test_none_skipped_tr_success(self, get_task_instance): successes=2, skipped=0, failed=1, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=False, @@ -555,6 +619,7 @@ def test_none_skipped_tr_success(self, get_task_instance): successes=0, skipped=0, failed=3, + removed=0, upstream_failed=0, done=3, flag_upstream_failed=True, @@ -577,6 +642,7 @@ def test_none_skipped_tr_failure(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=False, @@ -594,6 +660,7 @@ def test_none_skipped_tr_failure(self, get_task_instance): successes=1, skipped=1, failed=0, + removed=0, upstream_failed=0, done=2, flag_upstream_failed=True, @@ -611,6 +678,7 @@ def test_none_skipped_tr_failure(self, get_task_instance): successes=0, skipped=0, failed=0, + removed=0, upstream_failed=0, done=0, flag_upstream_failed=False, @@ -633,6 +701,7 @@ def test_unknown_tr(self, get_task_instance): successes=1, skipped=0, failed=0, + removed=0, upstream_failed=0, done=1, flag_upstream_failed=False, @@ -693,10 +762,128 @@ def test_get_states_count_upstream_ti(self): # check handling with cases that tasks are triggered from backfill with no finished tasks finished_tis = DepContext().ensure_finished_tis(ti_op2.dag_run, session) - assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op2) == (1, 0, 0, 0, 1) + assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op2) == (1, 0, 0, 0, 0, 1) finished_tis = dr.get_task_instances(state=State.finished, session=session) - assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op4) == (1, 0, 1, 0, 2) - assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op5) == (2, 0, 1, 0, 3) + assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op4) == (1, 0, 1, 0, 0, 2) + assert get_states_count_upstream_ti(finished_tis=finished_tis, task=op5) == (2, 0, 1, 0, 0, 3) dr.update_state() assert State.SUCCESS == dr.state + + def test_mapped_task_upstream_removed_with_all_success_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test ALL_SUCCESS trigger rule with mapped task upstream removed + """ + dr, task = get_mapped_task_dagrun() + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=True, # marks the task as removed if upstream is removed + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0 + assert ti.state == TaskInstanceState.REMOVED + + def test_mapped_task_upstream_removed_with_all_failed_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test ALL_FAILED trigger rule with mapped task upstream removed + """ + + dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.ALL_FAILED, state=State.FAILED) + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=0, + skipped=0, + failed=3, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=False, + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0 + + def test_mapped_task_upstream_removed_with_none_failed_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test NONE_FAILED trigger rule with mapped task upstream removed + """ + dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.NONE_FAILED) + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=False, + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0 + + def test_mapped_task_upstream_removed_with_none_failed_min_one_success_trigger_rules( + self, session, get_mapped_task_dagrun + ): + """ + Test NONE_FAILED_MIN_ONE_SUCCESS trigger rule with mapped task upstream removed + """ + dr, task = get_mapped_task_dagrun(trigger_rule=TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS) + + # ti with removed upstream ti + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = task + + dep_statuses = tuple( + TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=3, + skipped=0, + failed=0, + removed=2, + upstream_failed=0, + done=5, + flag_upstream_failed=False, + dep_context=DepContext(), + session=session, + ) + ) + + assert len(dep_statuses) == 0