Support clearing and updating state of individual mapped task instances (#22958)

* Allow marking/clearing mapped task instances from the UI

* Refactor to straighten up types

* Accept multiple map_index params from the front end

This allows setting multiple instances of the same task to SUCCESS or
FAILED in one request. The request is translated into multiple task
specifier tuples (task_id, map_index) before being passed to
set_state() (see the usage sketch in the mark_tasks.py section below).

Also made some drive-by improvements: adding types and cleaning up some
formatting.

* Introduce tuple_().in_() shim for MSSQL compat (a sketch of the shim follows the get_all_dag_task_query hunk below)

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
Co-authored-by: Tzu-ping Chung <tp@astronomer.io>
3 people committed Apr 20, 2022
1 parent eb26510 commit 4fa718e
Showing 10 changed files with 378 additions and 181 deletions.
48 changes: 33 additions & 15 deletions airflow/api/common/mark_tasks.py
@@ -18,7 +18,7 @@
 """Marks tasks APIs."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
 from sqlalchemy import or_
 from sqlalchemy.orm import contains_eager
@@ -32,6 +32,7 @@
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType

@@ -78,7 +79,7 @@ def _create_dagruns(
 @provide_session
 def set_state(
     *,
-    tasks: Iterable[Operator],
+    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
@@ -96,7 +97,8 @@ def set_state(
     tasks that did not exist. It will not create dag runs that are missing
     on the schedule (but it will ask for subdag dag runs if needed).
-    :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set
+    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
+        task.task.dag needs to be set
     :param run_id: the run_id of the dagrun to start looking from
     :param execution_date: the execution date from which to start looking (deprecated)
     :param upstream: Mark all parents (upstream tasks)
@@ -118,7 +120,7 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    task_dags = {task.dag for task in tasks}
+    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
     if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
@@ -131,8 +133,14 @@
         raise ValueError("Received tasks with no run_id")
 
     dag_run_ids = get_run_ids(dag, run_id, future, past)
-
-    task_ids = list(find_task_relatives(tasks, downstream, upstream))
+    task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
+    task_ids = [task_id for task_id, _ in task_id_map_index_list]
+    # A map_index of None means no map_index was supplied for that task,
+    # so fall back to plain task_id specifiers for the whole list.
+    for _, index in task_id_map_index_list:
+        if index is None:
+            task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
+            break
 
     confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -143,7 +151,7 @@
 
     # now look for the task instances that are affected
 
-    qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates)
+    qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, confirmed_dates)
 
     if commit:
         tis_altered = qry_dag.with_for_update().all()
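A hypothetical usage sketch of the new (task, map_index) specifier form — not part of the commit; `mapped_task` and the run_id value are assumptions for illustration:

    # Mark only map indexes 0 and 3 of a mapped task as SUCCESS, leaving
    # the other expanded instances of the same task untouched.
    # `mapped_task` is assumed to be a task produced by .expand(...).
    from airflow.api.common.mark_tasks import set_state
    from airflow.utils.state import TaskInstanceState

    altered = set_state(
        tasks=[(mapped_task, 0), (mapped_task, 3)],  # (task, map_index) tuples
        run_id="scheduled__2022-04-20T00:00:00+00:00",  # hypothetical run_id
        state=TaskInstanceState.SUCCESS,
        commit=True,
    )

Plain Operator objects are still accepted in the same list, which is what keeps the old single-task call sites working.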
@@ -179,20 +187,26 @@ def get_all_dag_task_query(
     dag: DAG,
     session: SASession,
     state: TaskInstanceState,
-    task_ids: List[str],
+    task_ids: Union[List[str], List[Tuple[str, int]]],
     confirmed_dates: Iterable[datetime],
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
+    is_string_list = isinstance(task_ids[0], str)
     qry_dag = (
         session.query(TaskInstance)
         .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
             DagRun.execution_date.in_(confirmed_dates),
-            TaskInstance.task_id.in_(task_ids),
         )
-        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
-        .options(contains_eager(TaskInstance.dag_run))
     )
+
+    if is_string_list:
+        qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
+    else:
+        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
+    qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
+        contains_eager(TaskInstance.dag_run)
+    )
    return qry_dag
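The tuple_in_condition helper used above lives in airflow/utils/sqlalchemy.py, one of the 10 changed files but not shown in this excerpt. A plausible sketch of the shim, inferred from its call sites here (the exact code in the commit may differ):

    from typing import Any, Iterable, Tuple

    from sqlalchemy import and_, false, or_, tuple_
    from sqlalchemy.sql.elements import ColumnElement
    from sqlalchemy.sql.operators import ColumnOperators

    from airflow import settings


    def tuple_in_condition(
        columns: Tuple[ColumnElement, ...],
        collection: Iterable[Any],
    ) -> ColumnOperators:
        """Generate a tuple-IN condition that also compiles on MSSQL.

        MSSQL cannot handle (a, b) IN ((1, 2), (3, 4)), so on that
        backend the comparison is expanded into
        (a = 1 AND b = 2) OR (a = 3 AND b = 4) instead.
        """
        if settings.engine.dialect.name != "mssql":
            return tuple_(*columns).in_(collection)
        clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
        if not clauses:
            return false()
        return or_(*clauses)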

@@ -270,14 +284,18 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR
 
 def find_task_relatives(tasks, downstream, upstream):
     """Yield task ids and optionally ancestor and descendant ids."""
-    for task in tasks:
-        yield task.task_id
+    for item in tasks:
+        if isinstance(item, tuple):
+            task, map_index = item
+        else:
+            task, map_index = item, None
+        yield task.task_id, map_index
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
-                yield relative.task_id
+                yield relative.task_id, map_index
         if upstream:
             for relative in task.get_flat_relatives(upstream=True):
-                yield relative.task_id
+                yield relative.task_id, map_index


@provide_session
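A quick illustration of find_task_relatives' new yield shape (the task objects and ids are assumed): relatives expanded from a (task, map_index) specifier inherit that map_index, while plain task specifiers carry None.

    # Assuming task "extract" has a single downstream task "load":
    list(find_task_relatives([(extract, 2)], downstream=True, upstream=False))
    # -> [("extract", 2), ("load", 2)]
    list(find_task_relatives([extract], downstream=True, upstream=False))
    # -> [("extract", None), ("load", None)]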
47 changes: 17 additions & 30 deletions airflow/jobs/scheduler_job.py
@@ -28,7 +28,7 @@
 from datetime import timedelta
 from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
 
-from sqlalchemy import and_, func, not_, or_, text, tuple_
+from sqlalchemy import func, not_, or_, text
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import load_only, selectinload
 from sqlalchemy.orm.session import Session, make_transient
@@ -55,7 +55,13 @@
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
-from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import (
+    is_lock_not_available_error,
+    prohibit_commit,
+    skip_locked,
+    tuple_in_condition,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType

@@ -321,17 +327,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
             query = query.filter(not_(TI.dag_id.in_(starved_dags)))
 
         if starved_tasks:
-            if settings.engine.dialect.name == 'mssql':
-                task_filter = or_(
-                    and_(
-                        TaskInstance.dag_id == dag_id,
-                        TaskInstance.task_id == task_id,
-                    )
-                    for (dag_id, task_id) in starved_tasks
-                )
-            else:
-                task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks)
-
+            task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
             query = query.filter(not_(task_filter))
 
         query = query.limit(max_tis)
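For context on what the shim consolidates, a standalone sketch (plain SQLAlchemy; the table definition and values are assumptions) of what the two dialect branches render to:

    from sqlalchemy import Column, MetaData, String, Table, and_, not_, or_, tuple_

    ti = Table("task_instance", MetaData(), Column("dag_id", String), Column("task_id", String))
    starved = [("dag_a", "t1"), ("dag_b", "t2")]

    # Most backends: renders roughly as
    #   (dag_id, task_id) NOT IN (('dag_a', 't1'), ('dag_b', 't2'))
    print(not_(tuple_(ti.c.dag_id, ti.c.task_id).in_(starved)))

    # MSSQL fallback: renders roughly as
    #   NOT ((dag_id = 'dag_a' AND task_id = 't1') OR (dag_id = 'dag_b' AND task_id = 't2'))
    print(not_(or_(*(and_(ti.c.dag_id == d, ti.c.task_id == t) for d, t in starved))))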
@@ -980,24 +976,15 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
         # as DagModel.dag_id and DagModel.next_dagrun
         # This list is used to verify if the DagRun already exist so that we don't attempt to create
         # duplicate dag runs
-
-        if session.bind.dialect.name == 'mssql':
-            existing_dagruns_filter = or_(
-                *(
-                    and_(
-                        DagRun.dag_id == dm.dag_id,
-                        DagRun.execution_date == dm.next_dagrun,
-                    )
-                    for dm in dag_models
-                )
-            )
-        else:
-            existing_dagruns_filter = tuple_(DagRun.dag_id, DagRun.execution_date).in_(
-                [(dm.dag_id, dm.next_dagrun) for dm in dag_models]
-            )
-
         existing_dagruns = (
-            session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all()
+            session.query(DagRun.dag_id, DagRun.execution_date)
+            .filter(
+                tuple_in_condition(
+                    (DagRun.dag_id, DagRun.execution_date),
+                    ((dm.dag_id, dm.next_dagrun) for dm in dag_models),
+                ),
+            )
+            .all()
         )

active_runs_of_dags = defaultdict(
