/
mark_tasks.py
586 lines (508 loc) · 20.9 KB
/
mark_tasks.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Marks tasks APIs."""
from datetime import datetime
from typing import TYPE_CHECKING, Collection, Iterable, Iterator, List, NamedTuple, Optional, Tuple, Union
from sqlalchemy import or_
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm.session import Session as SASession
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.operator import Operator
from airflow.models.taskinstance import TaskInstance
from airflow.operators.subdag import SubDagOperator
from airflow.utils import timezone
from airflow.utils.helpers import exactly_one
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.sqlalchemy import tuple_in_condition
from airflow.utils.state import DagRunState, State, TaskInstanceState
from airflow.utils.types import DagRunType
class _DagRunInfo(NamedTuple):
    """Pairs a DAG run's logical date with its (start, end) data interval."""

    logical_date: datetime
    data_interval: Tuple[datetime, datetime]
def _create_dagruns(
    dag: DAG,
    infos: Iterable[_DagRunInfo],
    state: DagRunState,
    run_type: DagRunType,
) -> Iterable[DagRun]:
    """Ensure a DAG run exists for every supplied data interval.

    Existing runs matching the requested logical dates are reused; any
    missing ones are created in the given state.

    :param dag: The DAG to create runs for.
    :param infos: List of logical dates and data intervals to evaluate.
    :param state: The state to set each newly created dag run to.
    :param run_type: The prefix will be used to construct dag run id: ``{run_id_prefix}__{execution_date}``.
    :return: Newly created and existing dag runs for the execution dates supplied.
    """
    wanted_dates = [info.logical_date for info in infos]
    # Look up runs that already exist so we do not create duplicates.
    runs_by_date = {run.logical_date: run for run in DagRun.find(dag_id=dag.dag_id, execution_date=wanted_dates)}
    for info in infos:
        if info.logical_date in runs_by_date:
            continue
        runs_by_date[info.logical_date] = dag.create_dagrun(
            execution_date=info.logical_date,
            data_interval=info.data_interval,
            start_date=timezone.utcnow(),
            external_trigger=False,
            state=state,
            run_type=run_type,
        )
    return runs_by_date.values()
@provide_session
def set_state(
    *,
    tasks: Union[Collection[Operator], Collection[Tuple[Operator, int]]],
    run_id: Optional[str] = None,
    execution_date: Optional[datetime] = None,
    upstream: bool = False,
    downstream: bool = False,
    future: bool = False,
    past: bool = False,
    state: TaskInstanceState = TaskInstanceState.SUCCESS,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the state of a task instance and if needed its relatives. Can set state
    for future tasks (calculated from run_id) and retroactively
    for past tasks. Will verify integrity of past dag runs in order to create
    tasks that did not exist. It will not create dag runs that are missing
    on the schedule (but it will as for subdag dag runs if needed).

    :param tasks: the iterable of tasks or (task, map_index) tuples from which to work.
        ``task.dag`` needs to be set on every task
    :param run_id: the run_id of the dagrun to start looking from
    :param execution_date: the execution date from which to start looking (deprecated)
    :param upstream: Mark all parents (upstream tasks)
    :param downstream: Mark all siblings (downstream tasks) of task_id, including SubDags
    :param future: Mark all future tasks on the interval of the dag up until
        last execution date.
    :param past: Retroactively mark all tasks starting from start_date of the DAG
    :param state: State to which the tasks need to be set
    :param commit: Commit tasks to be altered to the database
    :param session: database session
    :return: list of tasks that have been created and updated
    :raises ValueError: if neither/both of run_id and execution_date are given,
        if execution_date is naive, or if the tasks span multiple DAGs / no DAG.
    """
    if not tasks:
        return []
    # Exactly one of run_id / execution_date selects the starting dag run.
    if not exactly_one(execution_date, run_id):
        raise ValueError("Exactly one of dag_run_id and execution_date must be set")
    if execution_date and not timezone.is_localized(execution_date):
        raise ValueError(f"Received non-localized date {execution_date}")
    # All tasks must belong to the same DAG; tuples carry (task, map_index).
    task_dags = {task[0].dag if isinstance(task, tuple) else task.dag for task in tasks}
    if len(task_dags) > 1:
        raise ValueError(f"Received tasks from multiple DAGs: {task_dags}")
    dag = next(iter(task_dags))
    if dag is None:
        raise ValueError("Received tasks with no DAG")
    # Deprecated execution_date path: resolve it to a run_id first.
    if execution_date:
        run_id = dag.get_dagrun(execution_date=execution_date).run_id
    if not run_id:
        raise ValueError("Received tasks with no run_id")
    dag_run_ids = get_run_ids(dag, run_id, future, past)
    task_id_map_index_list = list(find_task_relatives(tasks, downstream, upstream))
    task_ids = [task_id for task_id, _ in task_id_map_index_list]
    # check if task_id_map_index_list contains map_index of None
    # if it contains None, there was no map_index supplied for the task;
    # collapse the list to plain task ids so the query filters on task_id only.
    for _, index in task_id_map_index_list:
        if index is None:
            task_id_map_index_list = [task_id for task_id, _ in task_id_map_index_list]
            break
    # Verifying integrity here also creates task instances that were missing.
    confirmed_infos = list(_iter_existing_dag_run_infos(dag, dag_run_ids))
    confirmed_dates = [info.logical_date for info in confirmed_infos]
    # Subdag runs are created on demand and their dag_ids collected for querying.
    sub_dag_run_ids = list(
        _iter_subdag_run_ids(dag, session, DagRunState(state), task_ids, commit, confirmed_infos),
    )
    # now look for the task instances that are affected
    qry_dag = get_all_dag_task_query(dag, session, state, task_id_map_index_list, confirmed_dates)
    if commit:
        # Lock the rows while updating their state.
        tis_altered = qry_dag.with_for_update().all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
            tis_altered += qry_sub_dag.with_for_update().all()
        for task_instance in tis_altered:
            task_instance.set_state(state)
    else:
        # Dry run: report which task instances would be altered without writing.
        tis_altered = qry_dag.all()
        if sub_dag_run_ids:
            qry_sub_dag = all_subdag_tasks_query(sub_dag_run_ids, session, state, confirmed_dates)
            tis_altered += qry_sub_dag.all()
    return tis_altered
def all_subdag_tasks_query(
    sub_dag_run_ids: List[str],
    session: SASession,
    state: TaskInstanceState,
    confirmed_dates: Iterable[datetime],
):
    """Build a query selecting every task instance of the given subdags.

    Only rows restricted to ``confirmed_dates`` whose state is NULL or
    differs from ``state`` are selected (already-correct rows are skipped).
    """
    not_already_in_state = or_(TaskInstance.state.is_(None), TaskInstance.state != state)
    return (
        session.query(TaskInstance)
        .filter(TaskInstance.dag_id.in_(sub_dag_run_ids))
        .filter(TaskInstance.execution_date.in_(confirmed_dates))
        .filter(not_already_in_state)
    )
def get_all_dag_task_query(
    dag: DAG,
    session: SASession,
    state: TaskInstanceState,
    task_ids: Union[List[str], List[Tuple[str, int]]],
    confirmed_dates: Iterable[datetime],
):
    """Get all tasks of the main dag that will be affected by a state change.

    ``task_ids`` is either a list of plain task ids or a list of
    ``(task_id, map_index)`` tuples; the id filter adapts to whichever
    shape the first element has.
    """
    base_query = (
        session.query(TaskInstance)
        .join(TaskInstance.dag_run)
        .filter(
            TaskInstance.dag_id == dag.dag_id,
            DagRun.execution_date.in_(confirmed_dates),
        )
    )
    if isinstance(task_ids[0], str):
        id_filter = TaskInstance.task_id.in_(task_ids)
    else:
        id_filter = tuple_in_condition((TaskInstance.task_id, TaskInstance.map_index), task_ids)
    # Skip task instances already in the target state; eagerly load dag_run.
    return (
        base_query.filter(id_filter)
        .filter(or_(TaskInstance.state.is_(None), TaskInstance.state != state))
        .options(contains_eager(TaskInstance.dag_run))
    )
def _iter_subdag_run_ids(
    dag: DAG,
    session: SASession,
    state: DagRunState,
    task_ids: List[str],
    commit: bool,
    confirmed_infos: Iterable[_DagRunInfo],
) -> Iterator[str]:
    """Walk subdag operators depth-first, create their dag runs, and yield their dag_ids.

    We only work within the scope of the subdag. A subdag does not propagate to
    its parent DAG, but a parent propagates to its subdags.
    """
    pending = [dag]
    while pending:
        current = pending.pop()
        for task_id in task_ids:
            if not current.has_task(task_id):
                continue
            task = current.get_task(task_id)
            if not (isinstance(task, SubDagOperator) or task.task_type == "SubDagOperator"):
                continue
            # this works as a kind of integrity check
            # it creates missing dag runs for subdag operators,
            # maybe this should be moved to dagrun.verify_integrity
            if TYPE_CHECKING:
                assert task.subdag
            created_runs = _create_dagruns(
                task.subdag,
                infos=confirmed_infos,
                state=DagRunState.RUNNING,
                run_type=DagRunType.BACKFILL_JOB,
            )
            verify_dagruns(created_runs, commit, state, session, task)
            # Recurse into nested subdags via the explicit stack.
            pending.append(task.subdag)
            yield task.subdag.dag_id
def verify_dagruns(
    dag_runs: Iterable[DagRun],
    commit: bool,
    state: DagRunState,
    session: SASession,
    current_task: Operator,
):
    """Run integrity verification on each dag run, optionally persisting a new state.

    :param dag_runs: dag runs to verify
    :param commit: whether dag runs state should be updated
    :param state: state of the dag_run to set if commit is True
    :param session: session to use
    :param current_task: subdag operator whose subdag the runs belong to
    :return: None
    """
    for run in dag_runs:
        # verify_integrity needs run.dag populated to create missing task instances.
        run.dag = current_task.subdag
        run.verify_integrity()
        if not commit:
            continue
        run.state = state
        session.merge(run)
def _iter_existing_dag_run_infos(dag: DAG, run_ids: List[str]) -> Iterator[_DagRunInfo]:
    """Yield a _DagRunInfo for every existing run of ``dag`` in ``run_ids``, verifying each first."""
    existing_runs = DagRun.find(dag_id=dag.dag_id, run_id=run_ids)
    for run in existing_runs:
        run.dag = dag
        run.verify_integrity()
        yield _DagRunInfo(run.logical_date, dag.get_run_data_interval(run))
def find_task_relatives(tasks, downstream, upstream):
    """Yield (task_id, map_index) pairs for each task and, optionally, its relatives.

    ``tasks`` items are either task objects (map_index None) or
    ``(task, map_index)`` tuples; relatives inherit the task's map_index.
    """
    for entry in tasks:
        task, map_index = entry if isinstance(entry, tuple) else (entry, None)
        yield task.task_id, map_index
        if downstream:
            for relative in task.get_flat_relatives(upstream=False):
                yield relative.task_id, map_index
        if upstream:
            for relative in task.get_flat_relatives(upstream=True):
                yield relative.task_id, map_index
@provide_session
def get_execution_dates(
    dag: DAG, execution_date: datetime, future: bool, past: bool, *, session: SASession = NEW_SESSION
) -> List[datetime]:
    """Return the execution dates of DAG runs to consider.

    :param dag: the DAG whose runs are inspected
    :param execution_date: the anchor execution date
    :param future: include runs after ``execution_date`` up to the latest run
    :param past: include runs back to the DAG's start_date
    :param session: database session
    :return: execution dates in the computed range
    :raises ValueError: if the DAG has never run (no latest execution date)
    """
    latest_execution_date = dag.get_latest_execution_date(session=session)
    if latest_execution_date is None:
        # Fixed: previous message ("Received non-localized date ...") described a
        # different error condition; use the same wording as get_run_ids().
        raise ValueError(f"DagRun for {dag.dag_id} not found")
    execution_date = timezone.coerce_datetime(execution_date)
    # determine date range of dag runs and tasks to consider
    end_date = latest_execution_date if future else execution_date
    if past:
        start_date = dag.start_date or execution_date
    else:
        start_date = execution_date
    if not dag.timetable.can_run:
        # If the DAG never schedules, need to look at existing DagRun if the user wants future or
        # past runs.
        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date)
        dates = sorted({d.execution_date for d in dag_runs})
    elif not dag.timetable.periodic:
        dates = [start_date]
    else:
        dates = [
            info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
        ]
    return dates
@provide_session
def get_run_ids(dag: DAG, run_id: str, future: bool, past: bool, session: SASession = NEW_SESSION):
    """Return run_ids of DAG runs to consider, anchored at ``run_id``.

    :param dag: the DAG whose runs are inspected
    :param run_id: run id of the anchor dag run
    :param future: include runs after the anchor up to the latest run
    :param past: include runs back to the first run of the DAG
    :param session: database session
    :return: list of run_ids in the computed range
    :raises ValueError: if the DAG has no runs, or ``run_id`` matches no run
    """
    last_dagrun = dag.get_last_dagrun(include_externally_triggered=True)
    current_dagrun = dag.get_dagrun(run_id=run_id)
    first_dagrun = (
        session.query(DagRun)
        .filter(DagRun.dag_id == dag.dag_id)
        .order_by(DagRun.execution_date.asc())
        .first()
    )
    if last_dagrun is None:
        raise ValueError(f'DagRun for {dag.dag_id} not found')
    if current_dagrun is None:
        # Fixed: an unknown run_id previously crashed below with AttributeError
        # on None.logical_date; raise a descriptive error instead.
        raise ValueError(f'DagRun with run_id {run_id} not found for {dag.dag_id}')
    # determine run_id range of dag runs and tasks to consider
    end_date = last_dagrun.logical_date if future else current_dagrun.logical_date
    start_date = current_dagrun.logical_date if not past else first_dagrun.logical_date
    if not dag.timetable.can_run:
        # If the DAG never schedules, need to look at existing DagRun if the user wants future or
        # past runs.
        dag_runs = dag.get_dagruns_between(start_date=start_date, end_date=end_date)
        run_ids = sorted({d.run_id for d in dag_runs})
    elif not dag.timetable.periodic:
        run_ids = [run_id]
    else:
        dates = [
            info.logical_date for info in dag.iter_dagrun_infos_between(start_date, end_date, align=False)
        ]
        run_ids = [dr.run_id for dr in DagRun.find(dag_id=dag.dag_id, execution_date=dates)]
    return run_ids
def _set_dag_run_state(dag_id: str, run_id: str, state: DagRunState, session: SASession = NEW_SESSION):
    """
    Set a dag run's state in the DB, adjusting its start/end timestamps.

    :param dag_id: dag_id of target dag run
    :param run_id: run id of target dag run
    :param state: target state
    :param session: database session
    """
    dag_run = session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == run_id).one()
    dag_run.state = state
    now = timezone.utcnow()
    if state == State.RUNNING:
        # A (re)started run begins now and has no end yet.
        dag_run.start_date = now
        dag_run.end_date = None
    else:
        # Any terminal-ish state closes the run at the current time.
        dag_run.end_date = now
    session.merge(dag_run)
@provide_session
def set_dag_run_state_to_success(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Mark a dag run, and all task instances belonging to it, as SUCCESS.

    Exactly one of ``execution_date`` and ``run_id`` must be given; when that
    is not the case, or ``dag`` is falsy, an empty list is returned without
    touching the database.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    :raises ValueError: if ``execution_date`` is naive, or matches no dag run
    """
    if not exactly_one(execution_date, run_id) or not dag:
        return []
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id
    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')
    # Mark the dag run itself to success.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.SUCCESS, session)
    # Attach the DAG to every task so set_state() can resolve task.dag.
    for task in dag.tasks:
        task.dag = dag
    return set_state(tasks=dag.tasks, run_id=run_id, state=State.SUCCESS, commit=commit, session=session)
@provide_session
def set_dag_run_state_to_failed(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Set the dag run for a specific execution date or run_id and its running task instances
    to failed.

    Exactly one of ``execution_date`` and ``run_id`` must be given; otherwise
    (or when ``dag`` is falsy) an empty list is returned without any change.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking (deprecated)
    :param run_id: the DAG run_id to start looking from
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: If commit is true, list of tasks that have been updated,
        otherwise list of tasks that will be updated
    :raises ValueError: if ``execution_date`` is naive or matches no dag run
    """
    if not exactly_one(execution_date, run_id):
        return []
    if not dag:
        return []
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id
    if not run_id:
        raise ValueError(f'Invalid dag_run_id: {run_id}')
    # Mark the dag run to failed.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, DagRunState.FAILED, session)
    # Mark only RUNNING task instances.
    task_ids = [task.task_id for task in dag.tasks]
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.task_id.in_(task_ids),
        TaskInstance.state.in_(State.running),
    )
    # Collect the tasks whose instances are currently running; these are
    # handed to set_state() below to be marked FAILED.
    task_ids_of_running_tis = [task_instance.task_id for task_instance in tis]
    tasks = []
    for task in dag.tasks:
        if task.task_id not in task_ids_of_running_tis:
            continue
        task.dag = dag
        tasks.append(task)
    # Mark non-finished tasks as SKIPPED.
    # NOTE(review): rows with a NULL state appear excluded by SQL NOT IN
    # semantics here (NULL NOT IN (...) is not true) — confirm that is intended.
    tis = session.query(TaskInstance).filter(
        TaskInstance.dag_id == dag.dag_id,
        TaskInstance.run_id == run_id,
        TaskInstance.state.not_in(State.finished),
        TaskInstance.state.not_in(State.running),
    )
    tis = [ti for ti in tis]
    if commit:
        for ti in tis:
            ti.set_state(State.SKIPPED)
    # Skipped instances plus the running ones handled by set_state().
    return tis + set_state(tasks=tasks, run_id=run_id, state=State.FAILED, commit=commit, session=session)
def __set_dag_run_state_to_running_or_queued(
    *,
    new_state: DagRunState,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """
    Move a dag run to ``new_state`` (RUNNING or QUEUED) without altering tasks.

    :param dag: the DAG of which to alter state
    :param execution_date: the execution date from which to start looking
    :param run_id: the id of the DagRun
    :param commit: commit DAG and tasks to be altered to the database
    :param session: database session
    :return: always an empty list, to keep the return type consistent with
        the other ``set_dag_run_state_to_*`` functions
    """
    res: List[TaskInstance] = []
    # Require exactly one of execution_date / run_id, and a DAG to operate on.
    if (execution_date is None) == (run_id is None):
        return res
    if not dag:
        return res
    if execution_date:
        if not timezone.is_localized(execution_date):
            raise ValueError(f"Received non-localized date {execution_date}")
        dag_run = dag.get_dagrun(execution_date=execution_date)
        if not dag_run:
            raise ValueError(f'DagRun with execution_date: {execution_date} not found')
        run_id = dag_run.run_id
    if not run_id:
        raise ValueError(f'DagRun with run_id: {run_id} not found')
    # Persist the new dag run state only when committing.
    if commit:
        _set_dag_run_state(dag.dag_id, run_id, new_state, session)
    return res
@provide_session
def set_dag_run_state_to_running(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """Set the dag run identified by execution_date or run_id to RUNNING; returns an empty list."""
    return __set_dag_run_state_to_running_or_queued(
        new_state=DagRunState.RUNNING,
        dag=dag,
        execution_date=execution_date,
        run_id=run_id,
        commit=commit,
        session=session,
    )
@provide_session
def set_dag_run_state_to_queued(
    *,
    dag: DAG,
    execution_date: Optional[datetime] = None,
    run_id: Optional[str] = None,
    commit: bool = False,
    session: SASession = NEW_SESSION,
) -> List[TaskInstance]:
    """Set the dag run identified by execution_date or run_id to QUEUED; returns an empty list."""
    return __set_dag_run_state_to_running_or_queued(
        new_state=DagRunState.QUEUED,
        dag=dag,
        execution_date=execution_date,
        run_id=run_id,
        commit=commit,
        session=session,
    )