From 2584bcbc32cf273683f50845cdee8807b7ac7b6c Mon Sep 17 00:00:00 2001 From: Ephraim Anierobi Date: Tue, 20 Sep 2022 13:00:44 +0100 Subject: [PATCH] Add test for mapped task dependencies check including removed task in test_taskinstance.py --- tests/models/test_taskinstance.py | 116 +++++++++++++++++++++++++++++- 1 file changed, 115 insertions(+), 1 deletion(-) diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index c3c094bbdfdf0..b1e36dcbaa8e9 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -1065,7 +1065,7 @@ def test_depends_on_past(self, dag_maker): # Parameterized tests to check for the correct firing # of the trigger_rule under various circumstances # Numeric fields are in order: - # successes, skipped, failed, upstream_failed, done + # successes, skipped, failed, upstream_failed, done, removed @pytest.mark.parametrize( "trigger_rule,successes,skipped,failed,upstream_failed,done,removed," "flag_upstream_failed,expect_state,expect_completed", @@ -1156,6 +1156,120 @@ def test_check_task_dependencies( assert completed == expect_completed assert ti.state == expect_state + # Parameterized tests to check for the correct firing + # of the trigger_rule under various circumstances of mapped task + # Numeric fields are in order: + # successes, skipped, failed, upstream_failed, done,removed + @pytest.mark.parametrize( + "trigger_rule,successes,skipped,failed,upstream_failed,done,removed," + "flag_upstream_failed,expect_state,expect_completed", + [ + # + # Tests for all_success + # + ['all_success', 5, 0, 0, 0, 0, 0, True, None, True], + ['all_success', 2, 0, 0, 0, 0, 0, True, None, False], + ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False], + ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False], + ['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True], # ti.map_index >=successes + # + # Tests for one_success + # + ['one_success', 5, 0, 0, 0, 5, 0, True, None, True], + ['one_success', 2, 0, 0, 0, 2, 0, True, None, True], + ['one_success', 2, 0, 1, 0, 3, 0, True, None, True], + ['one_success', 2, 1, 0, 0, 3, 0, True, None, True], + ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False], + ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False], + # + # Tests for all_failed + # + ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False], + ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True], + ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False], + ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False], + ['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False], # One removed + # + # Tests for one_failed + # + ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False], + ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True], + ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False], + ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False], + ['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False], # One removed + # + # Tests for done + # + ['all_done', 5, 0, 0, 0, 5, 0, True, None, True], + ['all_done', 2, 0, 0, 0, 2, 0, True, None, False], + ['all_done', 2, 0, 1, 0, 3, 0, True, None, False], + ['all_done', 2, 1, 0, 0, 3, 0, True, None, False], + ], + ) + def test_check_task_dependencies_for_mapped( + self, + trigger_rule: str, + successes: int, + skipped: int, + failed: int, + removed: int, + upstream_failed: int, + done: int, + flag_upstream_failed: bool, + expect_state: State, + expect_completed: bool, + dag_maker, + session, + ): + from airflow.decorators import task + + @task + def do_something(i): + return 1 + + @task(trigger_rule=trigger_rule) + def do_something_else(i): + return 1 + + with dag_maker(dag_id='test_dag'): + nums = do_something.expand(i=[i + 1 for i in range(5)]) + do_something_else.expand(i=nums) + + dr = dag_maker.create_dagrun() + + ti = dr.get_task_instance('do_something_else', session=session) + ti.map_index = 0 + for map_index in range(1, 5): + ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index) + ti.dag_run = dr + session.add(ti) + session.flush() + downstream = ti.task + ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session) + ti.task = downstream + dep_results = TriggerRuleDep()._evaluate_trigger_rule( + ti=ti, + successes=successes, + skipped=skipped, + failed=failed, + removed=removed, + upstream_failed=upstream_failed, + done=done, + dep_context=DepContext(), + flag_upstream_failed=flag_upstream_failed, + ) + completed = all(dep.passed for dep in dep_results) + + assert completed == expect_completed + assert ti.state == expect_state + def test_respects_prev_dagrun_dep(self, create_task_instance): ti = create_task_instance() failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]