diff --git a/airflow/api/common/mark_tasks.py b/airflow/api/common/mark_tasks.py
index d11f490247f44..1d4709fb82b9f 100644
--- a/airflow/api/common/mark_tasks.py
+++ b/airflow/api/common/mark_tasks.py
@@ -18,7 +18,7 @@
 """Marks tasks APIs."""
 
 from datetime import datetime
-from typing import TYPE_CHECKING, Iterable, Iterator, List, NamedTuple, Optional, Tuple
+from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
 
 from sqlalchemy import or_
 from sqlalchemy.orm import contains_eager
@@ -32,6 +32,7 @@
 from airflow.utils import timezone
 from airflow.utils.helpers import exactly_one
 from airflow.utils.session import NEW_SESSION, provide_session
+from airflow.utils.sqlalchemy import tuple_in_condition
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -78,7 +79,7 @@ def _create_dagruns(
 @provide_session
 def set_state(
     *,
-    tasks: Iterable[Operator],
+    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
     run_id: Optional[str] = None,
     execution_date: Optional[datetime] = None,
     upstream: bool = False,
@@ -96,7 +97,8 @@ def set_state(
     tasks that did not exist. It will not create dag runs that are missing
     on the schedule (but it will for subdag dag runs if needed).
 
-    :param tasks: the iterable of tasks from which to work. task.task.dag needs to be set
+    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
+        ``task.dag`` needs to be set
     :param run_id: the run_id of the dagrun to start looking from
     :param execution_date: the execution date from which to start looking (deprecated)
     :param upstream: Mark all parents (upstream tasks)
@@ -118,7 +120,7 @@ def set_state(
     if execution_date and not timezone.is_localized(execution_date):
         raise ValueError(f"Received non-localized date {execution_date}")
 
-    task_dags = {task.dag for task in tasks}
+    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
    if len(task_dags) > 1:
         raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
     dag = next(iter(task_dags))
@@ -131,8 +133,14 @@ def set_state(
         raise ValueError("Received tasks with no run_id")
 
     dag_run_ids = get_run_ids(dag, run_id, future, past)
-
-    task_ids = list(find_task_relatives(tasks, downstream, upstream))
+    task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
+    task_ids = [task_id for task_id, _ in task_id_map_index_list]
+    # A map_index of None means no map_index was supplied for that task, so
+    # fall back to a plain list of task ids and filter on task_id only.
+    for _, index in task_id_map_index_list:
+        if index is None:
+            task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
+            break
 
     confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
     confirmed_dates = [info.logical_date for info in confirmed_infos]
@@ -143,7 +151,7 @@ def set_state(
 
     # now look for the task instances that are affected
 
-    qry_dag = get_all_dag_task_query(dag, session, state, task_ids, confirmed_dates)
+    qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, confirmed_dates)
 
     if commit:
         tis_altered = qry_dag.with_for_update().all()
@@ -179,20 +187,26 @@ def get_all_dag_task_query(
     dag: DAG,
     session: SASession,
     state: TaskInstanceState,
-    task_ids: List[str],
+    task_ids: Union[List[str], List[Tuple[str, int]]],
     confirmed_dates: Iterable[datetime],
 ):
     """Get all tasks of the main dag that will be affected by a state change"""
+    is_string_list = isinstance(task_ids[0], str)
     qry_dag = (
         session.query(TaskInstance)
         .join(TaskInstance.dag_run)
         .filter(
             TaskInstance.dag_id == dag.dag_id,
             DagRun.execution_date.in_(confirmed_dates),
-            TaskInstance.task_id.in_(task_ids),
         )
-        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
-        .options(contains_eager(TaskInstance.dag_run))
+    )
+
+    if is_string_list:
+        qry_dag = qry_dag.filter(TaskInstance.task_id.in_(task_ids))
+    else:
+        qry_dag = qry_dag.filter(tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids))
+    qry_dag = qry_dag.filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state)).options(
+        contains_eager(TaskInstance.dag_run)
     )
     return qry_dag
 
@@ -270,14 +284,18 @@ def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagR
 
 def find_task_relatives(tasks, downstream, upstream):
     """Yield task ids and optionally ancestor and descendant ids."""
-    for task in tasks:
-        yield task.task_id
+    for item in tasks:
+        if isinstance(item, tuple):
+            task, map_index = item
+        else:
+            task, map_index = item, None
+        yield task.task_id, map_index
         if downstream:
             for relative in task.get_flat_relatives(upstream=False):
-                yield relative.task_id
+                yield relative.task_id, map_index
         if upstream:
             for relative in task.get_flat_relatives(upstream=True):
-                yield relative.task_id
+                yield relative.task_id, map_index
 
 
 @provide_session
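With this change `set_state` accepts either plain operators or `(operator, map_index)` pairs. A minimal usage sketch, assuming a loaded DAG `my_dag` with a mapped task `process` and an existing run with the run id shown (all three names are hypothetical):

```python
# Hedged sketch only: "my_dag", "process", and the run_id are hypothetical.
from airflow.api.common.mark_tasks import set_state
from airflow.models import DagBag
from airflow.utils.state import TaskInstanceState

dag = DagBag().get_dag("my_dag")
task = dag.get_task("process")

# Mark only map indexes 0 and 2 of the mapped task; other expanded
# instances of "process" keep their current state.
altered = set_state(
    tasks=[(task, 0), (task, 2)],
    run_id="manual__2022-01-01T00:00:00+00:00",
    state=TaskInstanceState.SUCCESS,
    commit=True,
)
print(f"altered {len(altered)} task instances")
```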
diff --git a/airflow/jobs/scheduler_job.py b/airflow/jobs/scheduler_job.py
index e0b8c437ac37a..ac1d25833b5aa 100644
--- a/airflow/jobs/scheduler_job.py
+++ b/airflow/jobs/scheduler_job.py
@@ -28,7 +28,7 @@
 from datetime import timedelta
 from typing import Collection, DefaultDict, Dict, Iterator, List, Optional, Set, Tuple
 
-from sqlalchemy import and_, func, not_, or_, text, tuple_
+from sqlalchemy import func, not_, or_, text
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import load_only, selectinload
 from sqlalchemy.orm.session import Session, make_transient
@@ -55,7 +55,13 @@
 from airflow.utils.event_scheduler import EventScheduler
 from airflow.utils.retries import MAX_DB_RETRIES, retry_db_transaction, run_with_db_retries
 from airflow.utils.session import create_session, provide_session
-from airflow.utils.sqlalchemy import is_lock_not_available_error, prohibit_commit, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import (
+    is_lock_not_available_error,
+    prohibit_commit,
+    skip_locked,
+    tuple_in_condition,
+    with_row_locks,
+)
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import DagRunType
 
@@ -321,17 +327,7 @@ def _executable_task_instances_to_queued(self, max_tis: int, session: Session =
                 query = query.filter(not_(TI.dag_id.in_(starved_dags)))
 
             if starved_tasks:
-                if settings.engine.dialect.name == 'mssql':
-                    task_filter = or_(
-                        and_(
-                            TaskInstance.dag_id == dag_id,
-                            TaskInstance.task_id == task_id,
-                        )
-                        for (dag_id, task_id) in starved_tasks
-                    )
-                else:
-                    task_filter = tuple_(TaskInstance.dag_id, TaskInstance.task_id).in_(starved_tasks)
-
+                task_filter = tuple_in_condition((TaskInstance.dag_id, TaskInstance.task_id), starved_tasks)
                 query = query.filter(not_(task_filter))
 
             query = query.limit(max_tis)
@@ -980,24 +976,15 @@ def _create_dag_runs(self, dag_models: Collection[DagModel], session: Session) -
         # as DagModel.dag_id and DagModel.next_dagrun
         # This list is used to verify if the DagRun already exists so that we don't attempt to create
         # duplicate dag runs
-
-        if session.bind.dialect.name == 'mssql':
-            existing_dagruns_filter = or_(
-                *(
-                    and_(
-                        DagRun.dag_id == dm.dag_id,
-                        DagRun.execution_date == dm.next_dagrun,
-                    )
-                    for dm in dag_models
-                )
-            )
-        else:
-            existing_dagruns_filter = tuple_(DagRun.dag_id, DagRun.execution_date).in_(
-                [(dm.dag_id, dm.next_dagrun) for dm in dag_models]
-            )
-
         existing_dagruns = (
-            session.query(DagRun.dag_id, DagRun.execution_date).filter(existing_dagruns_filter).all()
+            session.query(DagRun.dag_id, DagRun.execution_date)
+            .filter(
+                tuple_in_condition(
+                    (DagRun.dag_id, DagRun.execution_date),
+                    ((dm.dag_id, dm.next_dagrun) for dm in dag_models),
+                ),
+            )
+            .all()
         )
 
         active_runs_of_dags = defaultdict(
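Both scheduler call sites now delegate dialect handling to the new `tuple_in_condition` helper (added in `airflow/utils/sqlalchemy.py` below). For reference, the two filter shapes it unifies, written out with plain SQLAlchemy constructs; the table definition and values here are illustrative, not Airflow's real models:

```python
from sqlalchemy import Column, MetaData, String, Table, and_, or_, tuple_

ti = Table(
    "task_instance",
    MetaData(),
    Column("dag_id", String),
    Column("task_id", String),
)
starved_tasks = [("dag_a", "t1"), ("dag_b", "t2")]

# Most backends: one composite IN over both columns.
in_form = tuple_(ti.c.dag_id, ti.c.task_id).in_(starved_tasks)

# MSSQL (no tuple-IN support): the same predicate expanded to OR-of-AND.
or_form = or_(*(and_(ti.c.dag_id == d, ti.c.task_id == t) for d, t in starved_tasks))
```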
diff --git a/airflow/models/dag.py b/airflow/models/dag.py
index 4013f40bdcd59..83860ba59146f 100644
--- a/airflow/models/dag.py
+++ b/airflow/models/dag.py
@@ -39,6 +39,7 @@
     Iterable,
     List,
     Optional,
+    Sequence,
     Set,
     Tuple,
     Type,
@@ -51,7 +52,7 @@
 import pendulum
 from dateutil.relativedelta import relativedelta
 from pendulum.tz.timezone import Timezone
-from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, or_
+from sqlalchemy import Boolean, Column, ForeignKey, Index, Integer, String, Text, func, not_, or_
 from sqlalchemy.orm import backref, joinedload, relationship
 from sqlalchemy.orm.query import Query
 from sqlalchemy.orm.session import Session
@@ -84,7 +85,7 @@
 from airflow.utils.helpers import exactly_one, validate_key
 from airflow.utils.log.logging_mixin import LoggingMixin
 from airflow.utils.session import NEW_SESSION, provide_session
-from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, with_row_locks
+from airflow.utils.sqlalchemy import Interval, UtcDateTime, skip_locked, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.types import NOTSET, ArgNotSet, DagRunType, EdgeInfoType
@@ -1340,43 +1341,33 @@ def get_task_instances(
             start_date = (timezone.utcnow() - timedelta(30)).replace(
                 hour=0, minute=0, second=0, microsecond=0
             )
-
-        if state is None:
-            state = []
-
-        return (
-            cast(
-                Query,
-                self._get_task_instances(
-                    task_ids=None,
-                    start_date=start_date,
-                    end_date=end_date,
-                    run_id=None,
-                    state=state,
-                    include_subdags=False,
-                    include_parentdag=False,
-                    include_dependent_dags=False,
-                    exclude_task_ids=cast(List[str], []),
-                    session=session,
-                ),
-            )
-            .order_by(DagRun.execution_date)
-            .all()
+        query = self._get_task_instances(
+            task_ids=None,
+            start_date=start_date,
+            end_date=end_date,
+            run_id=None,
+            state=state or (),
+            include_subdags=False,
+            include_parentdag=False,
+            include_dependent_dags=False,
+            exclude_task_ids=(),
+            session=session,
         )
+        return cast(Query, query).order_by(DagRun.execution_date).all()
 
     @overload
     def _get_task_instances(
         self,
         *,
-        task_ids,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
     ) -> Iterable[TaskInstance]:
@@ -1386,16 +1377,16 @@ def _get_task_instances(
     def _get_task_instances(
         self,
         *,
-        task_ids,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         as_pk_tuple: Literal[True],
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         session: Session,
         dag_bag: Optional["DagBag"] = ...,
         recursion_depth: int = ...,
@@ -1407,16 +1398,16 @@ def _get_task_instances(
     def _get_task_instances(
         self,
         *,
-        task_ids,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         as_pk_tuple: Literal[True, None] = None,
         start_date: Optional[datetime],
         end_date: Optional[datetime],
         run_id: Optional[str],
-        state: Union[TaskInstanceState, List[TaskInstanceState]],
+        state: Union[TaskInstanceState, Sequence[TaskInstanceState]],
         include_subdags: bool,
         include_parentdag: bool,
         include_dependent_dags: bool,
-        exclude_task_ids: Collection[str],
+        exclude_task_ids: Union[Collection[str], Collection[Tuple[str, int]], None],
         session: Session,
         dag_bag: Optional["DagBag"] = None,
         recursion_depth: int = 0,
@@ -1435,7 +1426,7 @@ def _get_task_instances(
 
         # Do we want full objects, or just the primary columns?
         if as_pk_tuple:
-            tis = session.query(TI.dag_id, TI.task_id, TI.run_id)
+            tis = session.query(TI.dag_id, TI.task_id, TI.run_id, TI.map_index)
         else:
             tis = session.query(TaskInstance)
         tis = tis.join(TaskInstance.dag_run)
@@ -1454,8 +1445,13 @@ def _get_task_instances(
             tis = tis.filter(TaskInstance.run_id == run_id)
         if start_date:
             tis = tis.filter(DagRun.execution_date >= start_date)
-        if task_ids:
-            tis = tis.filter(TaskInstance.task_id.in_(task_ids))
+
+        if task_ids is None:
+            pass  # Disable filter if not set.
+        elif isinstance(next(iter(task_ids), None), str):
+            tis = tis.filter(TI.task_id.in_(task_ids))
+        else:
+            tis = tis.filter(tuple_in_condition((TI.task_id, TI.map_index), task_ids))
 
         # This allows allow_trigger_in_future config to take effect, rather than mandating exec_date <= UTC
         if end_date or not self.allow_future_exec_dates:
@@ -1593,25 +1589,29 @@ def _get_task_instances(
             if as_pk_tuple:
                 result.update(TaskInstanceKey(*cols) for cols in tis.all())
             else:
-                result.update(ti.key for ti in tis.all())
-
-            if exclude_task_ids:
-                result = set(
-                    filter(
-                        lambda key: key.task_id not in exclude_task_ids,
-                        result,
-                    )
-                )
+                result.update(ti.key for ti in tis)
+
+            if exclude_task_ids is not None:
+                result = {
+                    task
+                    for task in result
+                    if task.task_id not in exclude_task_ids
+                    and (task.task_id, task.map_index) not in exclude_task_ids
+                }
 
         if as_pk_tuple:
             return result
-        elif result:
+        if result:
             # We've been asked for objects, let's combine it all back into a result set
-            tis = tis.with_entities(TI.dag_id, TI.task_id, TI.run_id)
-
-            tis = session.query(TI).filter(TI.filter_for_tis(result))
-        elif exclude_task_ids:
-            tis = tis.filter(TI.task_id.notin_(list(exclude_task_ids)))
+            ti_filters = TI.filter_for_tis(result)
+            if ti_filters is not None:
+                tis = session.query(TI).filter(ti_filters)
+        elif exclude_task_ids is None:
+            pass  # Disable filter if not set.
+        elif isinstance(next(iter(exclude_task_ids), None), str):
+            tis = tis.filter(TI.task_id.notin_(exclude_task_ids))
+        else:
+            tis = tis.filter(not_(tuple_in_condition((TI.task_id, TI.map_index), exclude_task_ids)))
 
         return tis
 
@@ -1620,6 +1620,7 @@ def set_task_instance_state(
         self,
         *,
         task_id: str,
+        map_indexes: Optional[Collection[int]] = None,
         execution_date: Optional[datetime] = None,
         run_id: Optional[str] = None,
         state: TaskInstanceState,
@@ -1635,6 +1636,8 @@ def set_task_instance_state(
         in failed or upstream_failed state.
 
         :param task_id: Task ID of the TaskInstance
+        :param map_indexes: Only set TaskInstance if its map_index matches.
+            If None (default), all mapped TaskInstances of the task are set.
         :param execution_date: Execution date of the TaskInstance
         :param run_id: The run_id of the TaskInstance
         :param state: State to set the TaskInstance to
@@ -1660,8 +1663,17 @@ def set_task_instance_state(
         task = self.get_task(task_id)
         task.dag = self
 
+        tasks_to_set_state: Union[List[Operator], List[Tuple[Operator, int]]]
+        task_ids_to_exclude_from_clear: Union[Set[str], Set[Tuple[str, int]]]
+        if map_indexes is None:
+            tasks_to_set_state = [task]
+            task_ids_to_exclude_from_clear = {task_id}
+        else:
+            tasks_to_set_state = [(task, map_index) for map_index in map_indexes]
+            task_ids_to_exclude_from_clear = {(task_id, map_index) for map_index in map_indexes}
+
         altered = set_state(
-            tasks=[task],
+            tasks=tasks_to_set_state,
             execution_date=execution_date,
             run_id=run_id,
             upstream=upstream,
@@ -1696,7 +1708,7 @@ def set_task_instance_state(
             only_failed=True,
             session=session,
             # Exclude the task itself from being cleared
-            exclude_task_ids={task_id},
+            exclude_task_ids=task_ids_to_exclude_from_clear,
         )
 
         return altered
@@ -1754,7 +1766,7 @@ def set_dag_runs_state(
     @provide_session
     def clear(
         self,
-        task_ids=None,
+        task_ids: Union[Collection[str], Collection[Tuple[str, int]], None] = None,
         start_date: Optional[datetime] = None,
         end_date: Optional[datetime] = None,
         only_failed: bool = False,
@@ -1769,13 +1781,13 @@ def clear(
         recursion_depth: int = 0,
         max_recursion_depth: Optional[int] = None,
         dag_bag: Optional["DagBag"] = None,
-        exclude_task_ids: FrozenSet[str] = frozenset({}),
+        exclude_task_ids: Union[FrozenSet[str], FrozenSet[Tuple[str, int]], None] = frozenset(),
     ) -> Union[int, Iterable[TaskInstance]]:
         """
         Clears a set of task instances associated with the current dag for
         a specified date range.
 
-        :param task_ids: List of task ids to clear
+        :param task_ids: List of task ids or (``task_id``, ``map_index``) tuples to clear
         :param start_date: The minimum execution_date to clear
         :param end_date: The maximum execution_date to clear
         :param only_failed: Only clear failed tasks
@@ -1789,7 +1801,8 @@ def clear(
         :param dry_run: Find the tasks to clear but don't clear them.
         :param session: The sqlalchemy session to use
         :param dag_bag: The DagBag used to find the dag's subdags (Optional)
-        :param exclude_task_ids: A set of ``task_id`` that should not be cleared
+        :param exclude_task_ids: A set of ``task_id`` or (``task_id``, ``map_index``)
+            tuples that should not be cleared
         """
         if get_tis:
             warnings.warn(
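A hedged usage sketch of the two extended `DAG` APIs above; the dag id, task id, and run id are hypothetical:

```python
import pendulum

from airflow.models import DagBag
from airflow.utils.state import TaskInstanceState

dag = DagBag().get_dag("my_dag")  # hypothetical DAG with a mapped task "consumer"

# Mark only map indexes 0 and 1 of the mapped task as success; with
# map_indexes=None every expanded instance of the task would be set.
dag.set_task_instance_state(
    task_id="consumer",
    map_indexes=[0, 1],
    run_id="manual__2022-01-01T00:00:00+00:00",
    state=TaskInstanceState.SUCCESS,
)

# Clear a single expanded task instance via a (task_id, map_index) tuple.
start = pendulum.datetime(2022, 1, 1, tz="UTC")
dag.clear(
    task_ids=[("consumer", 0)],
    start_date=start,
    end_date=start.add(days=1),
)
```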
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 48d3a047fb183..9d135a47b8aa6 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -67,7 +67,6 @@
     inspect,
     or_,
     text,
-    tuple_,
 )
 from sqlalchemy.ext.associationproxy import association_proxy
 from sqlalchemy.ext.mutable import MutableDict
@@ -122,7 +121,7 @@
 from airflow.utils.platform import getuser
 from airflow.utils.retries import run_with_db_retries
 from airflow.utils.session import NEW_SESSION, create_session, provide_session
-from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import ExtendedJSON, UtcDateTime, tuple_in_condition, with_row_locks
 from airflow.utils.state import DagRunState, State, TaskInstanceState
 from airflow.utils.timeout import timeout
@@ -2540,20 +2539,10 @@ def filter_for_tis(tis: Iterable[Union["TaskInstance", TaskInstanceKey]]) -> Opt
                 TaskInstance.task_id == first_task_id,
             )
 
-        if settings.engine.dialect.name == 'mssql':
-            return or_(
-                and_(
-                    TaskInstance.dag_id == ti.dag_id,
-                    TaskInstance.task_id == ti.task_id,
-                    TaskInstance.run_id == ti.run_id,
-                    TaskInstance.map_index == ti.map_index,
-                )
-                for ti in tis
-            )
-        else:
-            return tuple_(
-                TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index
-            ).in_([ti.key.primary for ti in tis])
+        return tuple_in_condition(
+            (TaskInstance.dag_id, TaskInstance.task_id, TaskInstance.run_id, TaskInstance.map_index),
+            (ti.key.primary for ti in tis),
+        )
 
 
 # State of the task instance.
diff --git a/airflow/utils/sqlalchemy.py b/airflow/utils/sqlalchemy.py
index c240a9445690b..de4ad01e6901e 100644
--- a/airflow/utils/sqlalchemy.py
+++ b/airflow/utils/sqlalchemy.py
@@ -19,15 +19,18 @@
 import datetime
 import json
 import logging
-from typing import Any, Dict
+from typing import Any, Dict, Iterable, Tuple
 
 import pendulum
 from dateutil import relativedelta
-from sqlalchemy import event, nullsfirst
+from sqlalchemy import and_, event, false, nullsfirst, or_, tuple_
 from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm.session import Session
+from sqlalchemy.sql import ColumnElement
+from sqlalchemy.sql.expression import ColumnOperators
 from sqlalchemy.types import JSON, DateTime, Text, TypeDecorator, TypeEngine, UnicodeText
 
+from airflow import settings
 from airflow.configuration import conf
 
 log = logging.getLogger(__name__)
@@ -319,3 +322,23 @@ def is_lock_not_available_error(error: OperationalError):
     if db_err_code in ('55P03', 1205, 3572):
         return True
     return False
+
+
+def tuple_in_condition(
+    columns: Tuple[ColumnElement, ...],
+    collection: Iterable[Any],
+) -> ColumnOperators:
+    """Generates a tuple-in-collection operator to use in ``.filter()``.
+
+    For most SQL backends, this generates a simple ``(col1, col2, ...) IN (...)``
+    clause. This however does not work with MSSQL, where we need to expand it to
+    ``(c1 = v1a AND c2 = v2a) OR (c1 = v1b AND c2 = v2b) ...`` manually.
+
+    :meta private:
+    """
+    if settings.engine.dialect.name != "mssql":
+        return tuple_(*columns).in_(collection)
+    clauses = [and_(*(c == v for c, v in zip(columns, values))) for values in collection]
+    if not clauses:
+        return false()
+    return or_(*clauses)
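A minimal sketch of exercising the helper directly. Note the dialect branch is decided by the configured Airflow engine (`settings.engine`) when the expression is built, so this assumes an initialized Airflow environment; the (task_id, map_index) pairs and the `session` are illustrative:

```python
from airflow.models.taskinstance import TaskInstance
from airflow.utils.sqlalchemy import tuple_in_condition

cond = tuple_in_condition(
    (TaskInstance.task_id, TaskInstance.map_index),
    [("consumer", 0), ("consumer", 1)],  # illustrative pairs
)
# On most backends this renders roughly as:
#   (task_id, map_index) IN ((:p1, :p2), (:p3, :p4))
# and on MSSQL as:
#   (task_id = :p1 AND map_index = :p2) OR (task_id = :p3 AND map_index = :p4)
tis = session.query(TaskInstance).filter(cond)  # `session` assumed to be an open Session
```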
diff --git a/airflow/www/views.py b/airflow/www/views.py
index de672a0416d8e..ae0186e493222 100644
--- a/airflow/www/views.py
+++ b/airflow/www/views.py
@@ -95,6 +95,7 @@
     set_dag_run_state_to_failed,
     set_dag_run_state_to_queued,
     set_dag_run_state_to_success,
+    set_state,
 )
 from airflow.compat.functools import cached_property
 from airflow.configuration import AIRFLOW_CONFIG, conf
@@ -107,6 +108,7 @@
 from airflow.models.abstractoperator import AbstractOperator
 from airflow.models.dagcode import DagCode
 from airflow.models.dagrun import DagRun, DagRunType
+from airflow.models.operator import Operator
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers_manager import ProvidersManager
@@ -1958,10 +1960,11 @@ def trigger(self, session=None):
 
     def _clear_dag_tis(
         self,
-        dag,
+        dag: DAG,
         start_date,
         end_date,
         origin,
+        task_ids=None,
         recursive=False,
         confirmed=False,
         only_failed=False,
@@ -1970,6 +1973,7 @@ def _clear_dag_tis(
             count = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
+                task_ids=task_ids,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1982,6 +1986,7 @@ def _clear_dag_tis(
             tis = dag.clear(
                 start_date=start_date,
                 end_date=end_date,
+                task_ids=task_ids,
                 include_subdags=recursive,
                 include_parentdag=recursive,
                 only_failed=only_failed,
@@ -1990,24 +1995,19 @@ def _clear_dag_tis(
         except AirflowException as ex:
             return redirect_or_json(origin, msg=str(ex), status="error")
 
-        if not tis:
-            msg = "No task instances to clear"
-            return redirect_or_json(origin, msg, status="error")
-        elif request.headers.get('Accept') == 'application/json':
-            details = [str(t) for t in tis]
+        assert isinstance(tis, collections.abc.Iterable)
+        details = [str(t) for t in tis]
+        if not details:
+            return redirect_or_json(origin, "No task instances to clear", status="error")
+        elif request.headers.get('Accept') == 'application/json':
             return htmlsafe_json_dumps(details, separators=(',', ':'))
-        else:
-            details = "\n".join(str(t) for t in tis)
-
-            response = self.render_template(
-                'airflow/confirm.html',
-                endpoint=None,
-                message="Task instances you are about to clear:",
-                details=details,
-            )
-
-        return response
+        return self.render_template(
+            'airflow/confirm.html',
+            endpoint=None,
+            message="Task instances you are about to clear:",
+            details="\n".join(details),
+        )
 
     @expose('/clear', methods=['POST'])
     @auth.has_access(
@@ -2024,6 +2024,11 @@ def clear(self):
         origin = get_safe_url(request.form.get('origin'))
         dag = current_app.dag_bag.get_dag(dag_id)
 
+        if 'map_index' not in request.form:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = request.form.getlist('map_index', type=int)
+
         execution_date = request.form.get('execution_date')
         execution_date = timezone.parse(execution_date)
         confirmed = request.form.get('confirmed') == "true"
@@ -2042,11 +2047,17 @@ def clear(self):
         end_date = execution_date if not future else None
         start_date = execution_date if not past else None
 
+        if map_indexes is None:
+            task_ids: Union[List[str], List[Tuple[str, int]]] = [task_id]
+        else:
+            task_ids = [(task_id, map_index) for map_index in map_indexes]
+
         return self._clear_dag_tis(
             dag,
             start_date,
             end_date,
             origin,
+            task_ids=task_ids,
             recursive=recursive,
             confirmed=confirmed,
             only_failed=only_failed,
@@ -2279,26 +2290,28 @@ def dagrun_details(self, session=None):
 
     def _mark_task_instance_state(
         self,
-        dag_id,
-        task_id,
-        origin,
-        dag_run_id,
-        upstream,
-        downstream,
-        future,
-        past,
-        state,
+        *,
+        dag_id: str,
+        run_id: str,
+        task_id: str,
+        map_indexes: Optional[List[int]],
+        origin: str,
+        upstream: bool,
+        downstream: bool,
+        future: bool,
+        past: bool,
+        state: TaskInstanceState,
     ):
-        dag = current_app.dag_bag.get_dag(dag_id)
-        latest_execution_date = dag.get_latest_execution_date()
+        dag: DAG = current_app.dag_bag.get_dag(dag_id)
 
-        if not latest_execution_date:
-            flash(f"Cannot mark tasks as {state}, seem that dag {dag_id} has never run", "error")
+        if not run_id:
+            flash(f"Cannot mark tasks as {state}, seems that DAG {dag_id} has never run", "error")
             return redirect(origin)
 
         altered = dag.set_task_instance_state(
             task_id=task_id,
-            run_id=dag_run_id,
+            map_indexes=map_indexes,
+            run_id=run_id,
             state=state,
             upstream=upstream,
             downstream=downstream,
@@ -2326,6 +2339,11 @@ def confirm(self):
         state = args.get('state')
         origin = args.get('origin')
 
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
+
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
@@ -2357,10 +2375,13 @@ def confirm(self):
             msg = f"Cannot mark tasks as {state}, seems that dag {dag_id} has never run"
             return redirect_or_json(origin, msg, status='error')
 
-        from airflow.api.common.mark_tasks import set_state
+        if map_indexes is None:
+            tasks: Union[List[Operator], List[Tuple[Operator, int]]] = [task]
+        else:
+            tasks = [(task, map_index) for map_index in map_indexes]
 
         to_be_altered = set_state(
-            tasks=[task],
+            tasks=tasks,
             run_id=dag_run_id,
             upstream=upstream,
             downstream=downstream,
@@ -2398,24 +2419,30 @@ def failed(self):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        origin = get_safe_url(args.get('origin'))
-        dag_run_id = args.get('dag_run_id')
+        run_id = args.get('dag_run_id')
+
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
 
+        origin = get_safe_url(args.get('origin'))
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id,
-            task_id,
-            origin,
-            dag_run_id,
-            upstream,
-            downstream,
-            future,
-            past,
-            State.FAILED,
+            dag_id=dag_id,
+            run_id=run_id,
+            task_id=task_id,
+            map_indexes=map_indexes,
+            origin=origin,
+            upstream=upstream,
+            downstream=downstream,
+            future=future,
+            past=past,
+            state=TaskInstanceState.FAILED,
         )
 
     @expose('/success', methods=['POST'])
@@ -2431,24 +2458,30 @@ def success(self):
         args = request.form
         dag_id = args.get('dag_id')
         task_id = args.get('task_id')
-        origin = get_safe_url(args.get('origin'))
-        dag_run_id = args.get('dag_run_id')
+        run_id = args.get('dag_run_id')
+
+        if 'map_index' not in args:
+            map_indexes: Optional[List[int]] = None
+        else:
+            map_indexes = args.getlist('map_index', type=int)
 
+        origin = get_safe_url(args.get('origin'))
         upstream = to_boolean(args.get('upstream'))
         downstream = to_boolean(args.get('downstream'))
         future = to_boolean(args.get('future'))
         past = to_boolean(args.get('past'))
 
         return self._mark_task_instance_state(
-            dag_id,
-            task_id,
-            origin,
-            dag_run_id,
-            upstream,
-            downstream,
-            future,
-            past,
-            State.SUCCESS,
+            dag_id=dag_id,
+            run_id=run_id,
+            task_id=task_id,
+            map_indexes=map_indexes,
+            origin=origin,
+            upstream=upstream,
+            downstream=downstream,
+            future=future,
+            past=past,
+            state=TaskInstanceState.SUCCESS,
         )
 
     @expose('/dags/')
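All three endpoints read repeated `map_index` form fields the same way, via `getlist('map_index', type=int)`. A hedged sketch of a client POST carrying such a form against `/failed`; the host and ids are illustrative, and a real deployment would additionally need an authenticated session and CSRF token:

```python
import requests

# Repeated "map_index" keys arrive as a list on the Flask side.
requests.post(
    "http://localhost:8080/failed",
    data=[
        ("dag_id", "my_dag"),
        ("task_id", "consumer"),
        ("dag_run_id", "manual__2022-01-01T00:00:00+00:00"),
        ("map_index", "0"),
        ("map_index", "1"),
        ("confirmed", "true"),
    ],
)
```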
diff --git a/tests/api/common/test_mark_tasks.py b/tests/api/common/test_mark_tasks.py
index 3a3bcfc621a07..4c1d3c604b28c 100644
--- a/tests/api/common/test_mark_tasks.py
+++ b/tests/api/common/test_mark_tasks.py
@@ -60,6 +60,7 @@ def create_dags(cls, dagbag):
         cls.dag1 = dagbag.get_dag('miscellaneous_test_dag')
         cls.dag2 = dagbag.get_dag('example_subdag_operator')
         cls.dag3 = dagbag.get_dag('example_trigger_target_dag')
+        cls.dag4 = dagbag.get_dag('test_mapped_classic')
         cls.execution_dates = [days_ago(2), days_ago(1)]
         start_date3 = cls.dag3.start_date
         cls.dag3_execution_dates = [
@@ -105,6 +106,20 @@ def setup(self):
         for dr in drs:
             dr.dag = self.dag3
 
+        drs = _create_dagruns(
+            self.dag4,
+            [
+                _DagRunInfo(
+                    self.dag4.start_date,
+                    (self.dag4.start_date, self.dag4.start_date + timedelta(days=1)),
+                )
+            ],
+            state=State.SUCCESS,
+            run_type=DagRunType.MANUAL,
+        )
+        for dr in drs:
+            dr.dag = self.dag4
+
         yield
 
         clear_db_runs()
@@ -123,7 +138,7 @@ def snapshot_state(dag, execution_dates):
         )
 
     @provide_session
-    def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None):
+    def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=None, map_indexes=None):
         TI = models.TaskInstance
         DR = models.DagRun
 
@@ -140,13 +155,25 @@ def verify_state(self, dag, task_ids, execution_dates, state, old_tis, session=N
         for ti in tis:
             assert ti.operator == dag.get_task(ti.task_id).task_type
             if ti.task_id in task_ids and ti.execution_date in execution_dates:
-                assert ti.state == state
+                if map_indexes:
+                    if ti.map_index in map_indexes:
+                        assert ti.state == state
+                else:
+                    assert ti.state == state
                 if state in State.finished:
-                    assert ti.end_date is not None
+                    if map_indexes:
+                        if ti.map_index in map_indexes:
+                            assert ti.end_date is not None
+                    else:
+                        assert ti.end_date is not None
             else:
                 for old_ti in old_tis:
                     if old_ti.task_id == ti.task_id and old_ti.execution_date == ti.execution_date:
-                        assert ti.state == old_ti.state
+                        if map_indexes:
+                            if ti.map_index in map_indexes:
+                                assert ti.state == old_ti.state
+                        else:
+                            assert ti.state == old_ti.state
 
     def test_mark_tasks_now(self):
         # set one task to success but do not commit
@@ -409,6 +436,33 @@ def test_mark_tasks_subdag(self):
         # tested logic.
         self.verify_state(self.dag2, task_ids, [self.execution_dates[0]], State.SUCCESS, [])
 
+    def test_mark_mapped_task_instance_state(self):
+        # set mapped task instances to success
+        snapshot = TestMarkTasks.snapshot_state(self.dag4, self.execution_dates)
+        task = self.dag4.get_task("consumer_literal")
+        tasks = [(task, 0), (task, 1)]
+        map_indexes = [0, 1]
+        dr = DagRun.find(dag_id=self.dag4.dag_id, execution_date=self.execution_dates[0])[0]
+        altered = set_state(
+            tasks=tasks,
+            run_id=dr.run_id,
+            upstream=False,
+            downstream=False,
+            future=False,
+            past=False,
+            state=State.SUCCESS,
+            commit=True,
+        )
+        assert len(altered) == 2
+        self.verify_state(
+            self.dag4,
+            [task.task_id for task, _ in tasks],
+            [self.execution_dates[0]],
+            State.SUCCESS,
+            snapshot,
+            map_indexes=map_indexes,
+        )
+
 
 class TestMarkDAGRun:
     INITIAL_TASK_STATES = {
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index e18c6f425607b..6cd8ea660f7fc 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -49,6 +49,7 @@
 from airflow.models.param import DagParam, Param, ParamsDict
 from airflow.operators.bash import BashOperator
 from airflow.operators.empty import EmptyOperator
+from airflow.operators.python import PythonOperator
 from airflow.operators.subdag import SubDagOperator
 from airflow.security import permissions
 from airflow.templates import NativeEnvironment, SandboxedEnvironment
@@ -1411,6 +1412,84 @@ def test_clear_set_dagrun_state(self, dag_run_state):
         dagrun = dagruns[0]  # type: DagRun
         assert dagrun.state == dag_run_state
 
+    @parameterized.expand(
+        [
+            (State.QUEUED,),
+            (State.RUNNING,),
+        ]
+    )
+    def test_clear_set_dagrun_state_for_mapped_task(self, dag_run_state):
+        dag_id = 'test_clear_set_dagrun_state'
+        self._clean_up(dag_id)
+        task_id = 't1'
+
+        def consumer(value):
+            print(value)
+
+        dag = DAG(dag_id, start_date=DEFAULT_DATE, max_active_runs=1)
+        PythonOperator.partial(task_id=task_id, dag=dag, python_callable=consumer).expand(op_args=[1, 2, 4])
+
+        session = settings.Session()
+        dagrun_1 = dag.create_dagrun(
+            run_type=DagRunType.BACKFILL_JOB,
+            state=State.FAILED,
+            start_date=DEFAULT_DATE,
+            execution_date=DEFAULT_DATE,
+        )
+        session.merge(dagrun_1)
+
+        ti = (
+            session.query(TI)
+            .filter(TI.map_index == 0, TI.task_id == task_id, TI.dag_id == dag.dag_id)
+            .first()
+        )
+        ti2 = (
+            session.query(TI)
+            .filter(TI.map_index == 1, TI.task_id == task_id, TI.dag_id == dag.dag_id)
+            .first()
+        )
+        ti.state = State.SUCCESS
+        ti2.state = State.SUCCESS
+        ti.execution_date = DEFAULT_DATE
+        ti2.execution_date = DEFAULT_DATE
+        session.merge(ti)
+        session.merge(ti2)
+        session.flush()
+
+        dag.clear(
+            task_ids=[(task_id, 0)],
+            start_date=DEFAULT_DATE,
+            end_date=DEFAULT_DATE + datetime.timedelta(days=1),
+            dag_run_state=dag_run_state,
+            include_subdags=False,
+            include_parentdag=False,
+            session=session,
+        )
+
+        ti = (
+            session.query(TI)
+            .filter(TI.map_index == ti.map_index, TI.task_id == ti.task_id, TI.dag_id == ti.dag_id)
+            .first()
+        )
+        ti2 = (
+            session.query(TI)
+            .filter(TI.map_index == ti2.map_index, TI.task_id == ti2.task_id, TI.dag_id == ti2.dag_id)
+            .first()
+        )
+        assert ti.state is None  # cleared
+        assert ti2.state == State.SUCCESS  # not cleared
+        dagruns = (
+            session.query(
+                DagRun,
+            )
+            .filter(
+                DagRun.dag_id == dag_id,
+            )
+            .all()
+        )
+
+        assert len(dagruns) == 1
+        dagrun = dagruns[0]  # type: DagRun
+        assert dagrun.state == dag_run_state
+
     def _make_test_subdag(self, session):
         dag_id = 'test_subdag'
         self._clean_up(dag_id)
diff --git
 a/tests/www/views/test_views.py b/tests/www/views/test_views.py
index f4be0540c333d..c7900d64fd949 100644
--- a/tests/www/views/test_views.py
+++ b/tests/www/views/test_views.py
@@ -271,9 +271,10 @@ def get_task_instance(session, task):
 
     view._mark_task_instance_state(
         dag_id=dag.dag_id,
+        run_id=dagrun.run_id,
         task_id=task_1.task_id,
+        map_indexes=None,
         origin="",
-        dag_run_id=dagrun.run_id,
         upstream=False,
         downstream=False,
         future=False,
diff --git a/tests/www/views/test_views_tasks.py b/tests/www/views/test_views_tasks.py
index fce94fd5e43dd..ebed9ab05fe11 100644
--- a/tests/www/views/test_views_tasks.py
+++ b/tests/www/views/test_views_tasks.py
@@ -517,7 +517,7 @@ def test_dag_never_run(admin_client, url):
     )
     clear_db_runs()
     resp = admin_client.post(url, data=form, follow_redirects=True)
-    check_content_in_response(f"Cannot mark tasks as {url}, seem that dag {dag_id} has never run", resp)
+    check_content_in_response(f"Cannot mark tasks as {url}, seems that DAG {dag_id} has never run", resp)
 
 
 class _ForceHeartbeatCeleryExecutor(CeleryExecutor):