Skip to content

Commit

Permalink
Fix DAG run state not updated while DAG is paused (#16343)
Browse files Browse the repository at this point in the history
The state of a DAG run does not update while the DAG is paused.
The tasks continue to run if the DAG run was kicked off before
the DAG was paused and eventually finish and are marked correctly.
The DAG run state does not get updated and stays in Running state until the DAG is unpaused.

This change fixes it by running a check on task exit to update state(if possible)
 of the DagRun if the task was able to finish the DagRun while the DAG is paused

Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
  • Loading branch information
ephraimbuddy and ashb committed Jun 17, 2021
1 parent d53371b commit 3834df6
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 2 deletions.
14 changes: 14 additions & 0 deletions airflow/jobs/local_task_job.py
Expand Up @@ -167,6 +167,7 @@ def handle_task_exit(self, return_code: int) -> None:
if not self.task_instance.test_mode:
if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
self._run_mini_scheduler_on_child_tasks()
self._update_dagrun_state_for_paused_dag()

def on_kill(self):
self.task_runner.terminate()
Expand Down Expand Up @@ -264,3 +265,16 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
exc_info=True,
)
session.rollback()

@provide_session
def _update_dagrun_state_for_paused_dag(self, session=None):
"""
Checks for paused dags with DagRuns in the running state and
update the DagRun state if possible
"""
dag = self.task_instance.task.dag
if dag.get_is_paused():
dag_run = self.task_instance.get_dagrun(session=session)
if dag_run:
dag_run.dag = dag
dag_run.update_state(session=session, execute_callbacks=True)
42 changes: 40 additions & 2 deletions tests/jobs/test_local_task_job.py
Expand Up @@ -45,6 +45,7 @@
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from airflow.utils.types import DagRunType
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.config import conf_vars
Expand Down Expand Up @@ -686,6 +687,43 @@ def test_fast_follow(
if scheduler_job.processor_agent:
scheduler_job.processor_agent.end()

def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self):
"""Test that with DAG paused, DagRun state will update when the tasks finishes the run"""
dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE)
op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True)

session = settings.Session()
orm_dag = DagModel(
dag_id=dag.dag_id,
has_task_concurrency_limits=False,
next_dagrun=dag.start_date,
next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE),
is_active=True,
is_paused=True,
)
session.add(orm_dag)
session.flush()
# Write Dag to DB
dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False)
dagbag.bag_dag(dag, root_dag=dag)
dagbag.sync_to_db()

dr = dag.create_dagrun(
run_type=DagRunType.SCHEDULED,
state=State.RUNNING,
execution_date=DEFAULT_DATE,
start_date=DEFAULT_DATE,
session=session,
)
assert dr.state == State.RUNNING
ti = TaskInstance(op1, dr.execution_date)
job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor())
job1.task_runner = StandardTaskRunner(job1)
job1.run()
session.add(dr)
session.refresh(dr)
assert dr.state == State.SUCCESS


@pytest.fixture()
def clean_db_helper():
Expand All @@ -704,12 +742,12 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes)
task = DummyOperator(task_id='test_state_succeeded1', dag=dag)

dag.clear()
dag.create_dagrun(run_id=unique_prefix, state=State.NONE)
dag.create_dagrun(run_id=unique_prefix, execution_date=DEFAULT_DATE, state=State.NONE)

ti = TaskInstance(task=task, execution_date=DEFAULT_DATE)

mock_get_task_runner.return_value.return_code.side_effects = return_codes

job = LocalTaskJob(task_instance=ti, executor=MockExecutor())
with assert_queries_count(15):
with assert_queries_count(16):
job.run()

0 comments on commit 3834df6

Please sign in to comment.