diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py index 35d0422e7d4b1..967b25bb38c2a 100644 --- a/airflow/jobs/local_task_job.py +++ b/airflow/jobs/local_task_job.py @@ -167,6 +167,7 @@ def handle_task_exit(self, return_code: int) -> None: if not self.task_instance.test_mode: if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True): self._run_mini_scheduler_on_child_tasks() + self._update_dagrun_state_for_paused_dag() def on_kill(self): self.task_runner.terminate() @@ -264,3 +265,16 @@ def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: exc_info=True, ) session.rollback() + + @provide_session + def _update_dagrun_state_for_paused_dag(self, session=None): + """ + Checks for paused dags with DagRuns in the running state and + update the DagRun state if possible + """ + dag = self.task_instance.task.dag + if dag.get_is_paused(): + dag_run = self.task_instance.get_dagrun(session=session) + if dag_run: + dag_run.dag = dag + dag_run.update_state(session=session, execute_callbacks=True) diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 905d83ae569e2..77099cb9dc799 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -45,6 +45,7 @@ from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timeout import timeout +from airflow.utils.types import DagRunType from tests.test_utils import db from tests.test_utils.asserts import assert_queries_count from tests.test_utils.config import conf_vars @@ -686,6 +687,43 @@ def test_fast_follow( if scheduler_job.processor_agent: scheduler_job.processor_agent.end() + def test_task_exit_should_update_state_of_finished_dagruns_with_dag_paused(self): + """Test that with DAG paused, DagRun state will update when the tasks finishes the run""" + dag = DAG(dag_id='test_dags', start_date=DEFAULT_DATE) + op1 = PythonOperator(task_id='dummy', dag=dag, owner='airflow', python_callable=lambda: True) + + session = settings.Session() + orm_dag = DagModel( + dag_id=dag.dag_id, + has_task_concurrency_limits=False, + next_dagrun=dag.start_date, + next_dagrun_create_after=dag.following_schedule(DEFAULT_DATE), + is_active=True, + is_paused=True, + ) + session.add(orm_dag) + session.flush() + # Write Dag to DB + dagbag = DagBag(dag_folder="/dev/null", include_examples=False, read_dags_from_db=False) + dagbag.bag_dag(dag, root_dag=dag) + dagbag.sync_to_db() + + dr = dag.create_dagrun( + run_type=DagRunType.SCHEDULED, + state=State.RUNNING, + execution_date=DEFAULT_DATE, + start_date=DEFAULT_DATE, + session=session, + ) + assert dr.state == State.RUNNING + ti = TaskInstance(op1, dr.execution_date) + job1 = LocalTaskJob(task_instance=ti, ignore_ti_state=True, executor=SequentialExecutor()) + job1.task_runner = StandardTaskRunner(job1) + job1.run() + session.add(dr) + session.refresh(dr) + assert dr.state == State.SUCCESS + @pytest.fixture() def clean_db_helper(): @@ -704,12 +742,12 @@ def test_number_of_queries_single_loop(self, mock_get_task_runner, return_codes) task = DummyOperator(task_id='test_state_succeeded1', dag=dag) dag.clear() - dag.create_dagrun(run_id=unique_prefix, state=State.NONE) + dag.create_dagrun(run_id=unique_prefix, execution_date=DEFAULT_DATE, state=State.NONE) ti = TaskInstance(task=task, execution_date=DEFAULT_DATE) mock_get_task_runner.return_value.return_code.side_effects = return_codes job = LocalTaskJob(task_instance=ti, executor=MockExecutor()) - with assert_queries_count(15): + with assert_queries_count(16): job.run()