From 64c0bd50155dfdb84671ac35d645b812fafa78a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C4=90=E1=BA=B7ng=20Minh=20D=C5=A9ng?= Date: Wed, 5 Jan 2022 14:42:57 +0700 Subject: [PATCH] bugfix: deferred tasks does not cancel when DAG is marked fail (#20649) --- airflow/api/common/experimental/mark_tasks.py | 119 ++++++++++++------ 1 file changed, 84 insertions(+), 35 deletions(-) diff --git a/airflow/api/common/experimental/mark_tasks.py b/airflow/api/common/experimental/mark_tasks.py index a2a36d9ee6329..c0db5b3a40845 100644 --- a/airflow/api/common/experimental/mark_tasks.py +++ b/airflow/api/common/experimental/mark_tasks.py @@ -17,23 +17,27 @@ # under the License. """Marks tasks APIs.""" -import datetime -from typing import Iterable +from datetime import datetime +from typing import Generator, Iterable, List, Optional -from sqlalchemy import or_ from sqlalchemy.orm import contains_eager +from sqlalchemy.orm.session import Session as SASession +from sqlalchemy.sql.expression import or_ +from airflow import DAG from airflow.models.baseoperator import BaseOperator from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.operators.subdag import SubDagOperator from airflow.utils import timezone -from airflow.utils.session import provide_session +from airflow.utils.session import NEW_SESSION, provide_session from airflow.utils.state import State, TaskInstanceState from airflow.utils.types import DagRunType -def _create_dagruns(dag, execution_dates, state, run_type): +def _create_dagruns( + dag: DAG, execution_dates: List[datetime], state: TaskInstanceState, run_type: DagRunType +) -> List[DagRun]: """ Infers from the dates which dag runs need to be created and does so. @@ -63,15 +67,15 @@ def _create_dagruns(dag, execution_dates, state, run_type): @provide_session def set_state( tasks: Iterable[BaseOperator], - execution_date: datetime.datetime, + execution_date: datetime, upstream: bool = False, downstream: bool = False, future: bool = False, past: bool = False, state: TaskInstanceState = TaskInstanceState.SUCCESS, commit: bool = False, - session=None, -): + session: SASession = NEW_SESSION, +) -> List[TaskInstance]: """ Set the state of a task instance and if needed its relatives. Can set state for future tasks (calculated from execution_date) and retroactively @@ -134,7 +138,9 @@ def set_state( return tis_altered -def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates): +def all_subdag_tasks_query( + sub_dag_run_ids: List[str], session: SASession, state: TaskInstanceState, confirmed_dates: List[datetime] +): """Get *all* tasks of the sub dags""" qry_sub_dag = ( session.query(TaskInstance) @@ -144,7 +150,13 @@ def all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates): return qry_sub_dag -def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates): +def get_all_dag_task_query( + dag: DAG, + session: SASession, + state: TaskInstanceState, + task_ids: List[str], + confirmed_dates: List[datetime], +): """Get all tasks of the main dag that will be affected by a state change""" qry_dag = ( session.query(TaskInstance) @@ -160,7 +172,14 @@ def get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates): return qry_dag -def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates): +def get_subdag_runs( + dag: DAG, + session: SASession, + state: TaskInstanceState, + task_ids: List[str], + commit: bool, + confirmed_dates: List[datetime], +) -> List[str]: """Go through subdag operators and create dag runs. We will only work within the scope of the subdag. We won't propagate to the parent dag, but we will propagate from parent to subdag. @@ -181,7 +200,7 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates): dag_runs = _create_dagruns( current_task.subdag, execution_dates=confirmed_dates, - state=State.RUNNING, + state=TaskInstanceState.RUNNING, run_type=DagRunType.BACKFILL_JOB, ) @@ -192,7 +211,13 @@ def get_subdag_runs(dag, session, state, task_ids, commit, confirmed_dates): return sub_dag_ids -def verify_dagruns(dag_runs, commit, state, session, current_task): +def verify_dagruns( + dag_runs: List[DagRun], + commit: bool, + state: TaskInstanceState, + session: SASession, + current_task: BaseOperator, +): """Verifies integrity of dag_runs. :param dag_runs: dag runs to verify @@ -210,7 +235,7 @@ def verify_dagruns(dag_runs, commit, state, session, current_task): session.merge(dag_run) -def verify_dag_run_integrity(dag, dates): +def verify_dag_run_integrity(dag: DAG, dates: List[datetime]) -> List[datetime]: """ Verify the integrity of the dag runs in case a task was added or removed set the confirmed execution dates as they might be different @@ -225,7 +250,9 @@ def verify_dag_run_integrity(dag, dates): return confirmed_dates -def find_task_relatives(tasks, downstream, upstream): +def find_task_relatives( + tasks: Iterable[BaseOperator], downstream: bool, upstream: bool +) -> Generator[str, None, None]: """Yield task ids and optionally ancestor and descendant ids.""" for task in tasks: yield task.task_id @@ -237,7 +264,7 @@ def find_task_relatives(tasks, downstream, upstream): yield relative.task_id -def get_execution_dates(dag, execution_date, future, past): +def get_execution_dates(dag: DAG, execution_date: datetime, future: bool, past: bool) -> List[datetime]: """Returns dates of DAG execution""" latest_execution_date = dag.get_latest_execution_date() if latest_execution_date is None: @@ -266,7 +293,9 @@ def get_execution_dates(dag, execution_date, future, past): @provide_session -def _set_dag_run_state(dag_id, execution_date, state, session=None): +def _set_dag_run_state( + dag_id: str, execution_date: datetime, state: TaskInstanceState, session: SASession = NEW_SESSION +): """ Helper method that set dag run state in the DB. @@ -279,7 +308,7 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None): session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.execution_date == execution_date).one() ) dag_run.state = state - if state == State.RUNNING: + if state == TaskInstanceState.RUNNING: dag_run.start_date = timezone.utcnow() dag_run.end_date = None else: @@ -288,7 +317,12 @@ def _set_dag_run_state(dag_id, execution_date, state, session=None): @provide_session -def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None): +def set_dag_run_state_to_success( + dag: Optional[DAG], + execution_date: Optional[datetime], + commit: bool = False, + session: SASession = NEW_SESSION, +) -> List[TaskInstance]: """ Set the dag run for a specific execution date and its task instances to success. @@ -306,18 +340,27 @@ def set_dag_run_state_to_success(dag, execution_date, commit=False, session=None # Mark the dag run to success. if commit: - _set_dag_run_state(dag.dag_id, execution_date, State.SUCCESS, session) + _set_dag_run_state(dag.dag_id, execution_date, TaskInstanceState.SUCCESS, session) # Mark all task instances of the dag run to success. for task in dag.tasks: task.dag = dag return set_state( - tasks=dag.tasks, execution_date=execution_date, state=State.SUCCESS, commit=commit, session=session + tasks=dag.tasks, + execution_date=execution_date, + state=TaskInstanceState.SUCCESS, + commit=commit, + session=session, ) @provide_session -def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None): +def set_dag_run_state_to_failed( + dag: Optional[DAG], + execution_date: Optional[datetime], + commit: bool = False, + session: SASession = NEW_SESSION, +) -> List[TaskInstance]: """ Set the dag run for a specific execution date and its running task instances to failed. @@ -335,18 +378,15 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None) # Mark the dag run to failed. if commit: - _set_dag_run_state(dag.dag_id, execution_date, State.FAILED, session) + _set_dag_run_state(dag.dag_id, execution_date, TaskInstanceState.FAILED, session) - # Mark only RUNNING task instances. + # Mark only running task instances. task_ids = [task.task_id for task in dag.tasks] - tis = ( - session.query(TaskInstance) - .filter( - TaskInstance.dag_id == dag.dag_id, - TaskInstance.execution_date == execution_date, - TaskInstance.task_id.in_(task_ids), - ) - .filter(TaskInstance.state == State.RUNNING) + tis = session.query(TaskInstance).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.execution_date == execution_date, + TaskInstance.task_id.in_(task_ids), + TaskInstance.state.in_(State.running), ) task_ids_of_running_tis = [task_instance.task_id for task_instance in tis] @@ -358,12 +398,21 @@ def set_dag_run_state_to_failed(dag, execution_date, commit=False, session=None) tasks.append(task) return set_state( - tasks=tasks, execution_date=execution_date, state=State.FAILED, commit=commit, session=session + tasks=tasks, + execution_date=execution_date, + state=TaskInstanceState.FAILED, + commit=commit, + session=session, ) @provide_session -def set_dag_run_state_to_running(dag, execution_date, commit=False, session=None): +def set_dag_run_state_to_running( + dag: Optional[DAG], + execution_date: Optional[datetime], + commit: bool = False, + session: SASession = NEW_SESSION, +) -> List[TaskInstance]: """ Set the dag run for a specific execution date to running. @@ -380,7 +429,7 @@ def set_dag_run_state_to_running(dag, execution_date, commit=False, session=None # Mark the dag run to running. if commit: - _set_dag_run_state(dag.dag_id, execution_date, State.RUNNING, session) + _set_dag_run_state(dag.dag_id, execution_date, TaskInstanceState.RUNNING, session) # To keep the return type consistent with the other similar functions. return res