
Commit

Add test for mapped task dependencies check including removed task in test_taskinstance.py
ephraimbuddy committed Sep 20, 2022
1 parent eed077a commit 2584bcb
Showing 1 changed file with 115 additions and 1 deletion.
116 changes: 115 additions & 1 deletion tests/models/test_taskinstance.py
@@ -1065,7 +1065,7 @@ def test_depends_on_past(self, dag_maker):
    # Parameterized tests to check for the correct firing
    # of the trigger_rule under various circumstances
    # Numeric fields are in order:
    # successes, skipped, failed, upstream_failed, done
    # successes, skipped, failed, upstream_failed, done, removed
    @pytest.mark.parametrize(
        "trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
        "flag_upstream_failed,expect_state,expect_completed",
@@ -1156,6 +1156,120 @@ def test_check_task_dependencies(
        assert completed == expect_completed
        assert ti.state == expect_state

    # Parameterized tests to check for the correct firing
    # of the trigger_rule under various circumstances of a mapped task
    # Numeric fields are in order:
    # successes, skipped, failed, upstream_failed, done, removed
    @pytest.mark.parametrize(
        "trigger_rule,successes,skipped,failed,upstream_failed,done,removed,"
        "flag_upstream_failed,expect_state,expect_completed",
        [
            #
            # Tests for all_success
            #
            ['all_success', 5, 0, 0, 0, 0, 0, True, None, True],
            ['all_success', 2, 0, 0, 0, 0, 0, True, None, False],
            ['all_success', 2, 0, 1, 0, 0, 0, True, State.UPSTREAM_FAILED, False],
            ['all_success', 2, 1, 0, 0, 0, 0, True, State.SKIPPED, False],
            ['all_success', 3, 0, 0, 0, 0, 2, True, State.REMOVED, True],  # ti.map_index >= successes
            #
            # Tests for one_success
            #
            ['one_success', 5, 0, 0, 0, 5, 0, True, None, True],
            ['one_success', 2, 0, 0, 0, 2, 0, True, None, True],
            ['one_success', 2, 0, 1, 0, 3, 0, True, None, True],
            ['one_success', 2, 1, 0, 0, 3, 0, True, None, True],
            ['one_success', 0, 5, 0, 0, 5, 0, True, State.SKIPPED, False],
            ['one_success', 0, 4, 1, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
            ['one_success', 0, 3, 1, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
            ['one_success', 0, 4, 0, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
            ['one_success', 0, 0, 5, 0, 5, 0, True, State.UPSTREAM_FAILED, False],
            ['one_success', 0, 0, 4, 1, 5, 0, True, State.UPSTREAM_FAILED, False],
            ['one_success', 0, 0, 0, 5, 5, 0, True, State.UPSTREAM_FAILED, False],
            #
            # Tests for all_failed
            #
            ['all_failed', 5, 0, 0, 0, 5, 0, True, State.SKIPPED, False],
            ['all_failed', 0, 0, 5, 0, 5, 0, True, None, True],
            ['all_failed', 2, 0, 0, 0, 2, 0, True, State.SKIPPED, False],
            ['all_failed', 2, 0, 1, 0, 3, 0, True, State.SKIPPED, False],
            ['all_failed', 2, 1, 0, 0, 3, 0, True, State.SKIPPED, False],
            ['all_failed', 2, 1, 0, 0, 4, 1, True, State.SKIPPED, False],  # One removed
            #
            # Tests for one_failed
            #
            ['one_failed', 5, 0, 0, 0, 0, 0, True, None, False],
            ['one_failed', 2, 0, 0, 0, 0, 0, True, None, False],
            ['one_failed', 2, 0, 1, 0, 0, 0, True, None, True],
            ['one_failed', 2, 1, 0, 0, 3, 0, True, None, False],
            ['one_failed', 2, 3, 0, 0, 5, 0, True, State.SKIPPED, False],
            ['one_failed', 2, 2, 0, 0, 5, 1, True, State.SKIPPED, False],  # One removed
            #
            # Tests for done
            #
            ['all_done', 5, 0, 0, 0, 5, 0, True, None, True],
            ['all_done', 2, 0, 0, 0, 2, 0, True, None, False],
            ['all_done', 2, 0, 1, 0, 3, 0, True, None, False],
            ['all_done', 2, 1, 0, 0, 3, 0, True, None, False],
        ],
    )
    def test_check_task_dependencies_for_mapped(
        self,
        trigger_rule: str,
        successes: int,
        skipped: int,
        failed: int,
        removed: int,
        upstream_failed: int,
        done: int,
        flag_upstream_failed: bool,
        expect_state: State,
        expect_completed: bool,
        dag_maker,
        session,
    ):
        from airflow.decorators import task

        @task
        def do_something(i):
            return 1

        @task(trigger_rule=trigger_rule)
        def do_something_else(i):
            return 1

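        # do_something is mapped over five inputs; do_something_else maps over its
        # output and runs with the parametrized trigger rule.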
        with dag_maker(dag_id='test_dag'):
            nums = do_something.expand(i=[i + 1 for i in range(5)])
            do_something_else.expand(i=nums)

        dr = dag_maker.create_dagrun()

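        # The dag run starts with a single unexpanded TI for do_something_else;
        # use it as map_index 0 and create the remaining mapped TIs by hand.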
        ti = dr.get_task_instance('do_something_else', session=session)
        ti.map_index = 0
        for map_index in range(1, 5):
            ti = TaskInstance(ti.task, run_id=dr.run_id, map_index=map_index)
            ti.dag_run = dr
            session.add(ti)
        session.flush()
        downstream = ti.task
        ti = dr.get_task_instance(task_id='do_something_else', map_index=3, session=session)
        ti.task = downstream
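        # Evaluate the trigger rule for one of the mapped TIs (map_index=3)
        # against the parametrized upstream counts.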
        dep_results = TriggerRuleDep()._evaluate_trigger_rule(
            ti=ti,
            successes=successes,
            skipped=skipped,
            failed=failed,
            removed=removed,
            upstream_failed=upstream_failed,
            done=done,
            dep_context=DepContext(),
            flag_upstream_failed=flag_upstream_failed,
        )
        completed = all(dep.passed for dep in dep_results)

        assert completed == expect_completed
        assert ti.state == expect_state

    def test_respects_prev_dagrun_dep(self, create_task_instance):
        ti = create_task_instance()
        failing_status = [TIDepStatus('test fail status name', False, 'test fail reason')]
