Skip to content

Commit

Permalink
Fix backfill occasional deadlocking (#26161)
Browse files Browse the repository at this point in the history
  • Loading branch information
ephraimbuddy committed Sep 6, 2022
1 parent 5b216e9 commit 6931fbf
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 15 deletions.
22 changes: 8 additions & 14 deletions airflow/jobs/backfill_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,7 +331,7 @@ def _get_dag_run(self, dagrun_info: DagRunInfo, dag: DAG, session: Session = Non
return run

@provide_session
def _task_instances_for_dag_run(self, dag_run, session=None):
def _task_instances_for_dag_run(self, dag, dag_run, session=None):
"""
Returns a map of task instance key to task instance object for the tasks to
run in the given dag run.
Expand All @@ -351,18 +351,19 @@ def _task_instances_for_dag_run(self, dag_run, session=None):
dag_run.refresh_from_db()
make_transient(dag_run)

dag_run.dag = dag
info = dag_run.task_instance_scheduling_decisions(session=session)
schedulable_tis = info.schedulable_tis
try:
for ti in dag_run.get_task_instances():
# all tasks part of the backfill are scheduled to run
if ti.state == State.NONE:
ti.set_state(TaskInstanceState.SCHEDULED, session=session)
for ti in dag_run.get_task_instances(session=session):
if ti in schedulable_tis:
ti.set_state(TaskInstanceState.SCHEDULED)
if ti.state != TaskInstanceState.REMOVED:
tasks_to_run[ti.key] = ti
session.commit()
except Exception:
session.rollback()
raise

return tasks_to_run

def _log_progress(self, ti_status):
Expand Down Expand Up @@ -441,13 +442,6 @@ def _per_task_process(key, ti: TaskInstance, session=None):
ti_status.running.pop(key)
return

# guard against externally modified task instances or
# in case max concurrency has been reached at task runtime
elif ti.state == State.NONE:
self.log.warning(
"FIXME: Task instance %s state was set to None externally. This should not happen", ti
)
ti.set_state(TaskInstanceState.SCHEDULED, session=session)
if self.rerun_failed_tasks:
# Rerun failed tasks or upstreamed failed tasks
if ti.state in (TaskInstanceState.FAILED, TaskInstanceState.UPSTREAM_FAILED):
Expand Down Expand Up @@ -729,7 +723,7 @@ def _execute_dagruns(self, dagrun_infos, ti_status, executor, pickle_id, start_d
for dagrun_info in dagrun_infos:
for dag in self._get_dag_with_subdags():
dag_run = self._get_dag_run(dagrun_info, dag, session=session)
tis_map = self._task_instances_for_dag_run(dag_run, session=session)
tis_map = self._task_instances_for_dag_run(dag, dag_run, session=session)
if dag_run is None:
continue

Expand Down
31 changes: 30 additions & 1 deletion tests/jobs/test_backfill_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def test_backfill_override_conf(self, dag_maker):
wraps=job._task_instances_for_dag_run,
) as wrapped_task_instances_for_dag_run:
job.run()
dr = wrapped_task_instances_for_dag_run.call_args_list[0][0][0]
dr = wrapped_task_instances_for_dag_run.call_args_list[0][0][1]
assert dr.conf == {"a": 1}

def test_backfill_skip_active_scheduled_dagrun(self, dag_maker, caplog):
Expand Down Expand Up @@ -1783,3 +1783,32 @@ def test_start_date_set_for_resetted_dagruns(self, dag_maker, session, caplog):
(dr,) = DagRun.find(dag_id=dag.dag_id, execution_date=DEFAULT_DATE, session=session)
assert dr.start_date
assert f'Failed to record duration of {dr}' not in caplog.text

def test_task_instances_are_not_set_to_scheduled_when_dagrun_reset(self, dag_maker, session):
    """
    Check which task instance states come back after the dag has been cleared.

    Builds a three-task linear dag, creates three dag runs for it, clears the
    dag, and then asks a ``BackfillJob`` for the task instances to run in each
    dag run. For every dag run the returned task instances must include at
    least one in ``SCHEDULED`` state and at least one still in ``State.NONE``,
    i.e. not every task instance was moved to scheduled.
    """
    one_day = datetime.timedelta(days=1)

    with dag_maker() as dag:
        first = EmptyOperator(task_id='task1')
        second = EmptyOperator(task_id='task2')
        third = EmptyOperator(task_id='task3')
        first >> second >> third

    # One dag run per day for the three days following DEFAULT_DATE.
    for day_offset in (1, 2, 3):
        run_date = DEFAULT_DATE + datetime.timedelta(days=day_offset)
        dag_maker.create_dagrun(run_id=f'test_dagrun_{day_offset}', execution_date=run_date)

    dag.clear()

    backfill = BackfillJob(
        dag=dag,
        start_date=DEFAULT_DATE + one_day,
        end_date=DEFAULT_DATE + datetime.timedelta(days=4),
        executor=MockExecutor(),
        donot_pickle=True,
    )

    for dag_run in DagRun.find(dag_id=dag.dag_id, session=session):
        ti_by_key = backfill._task_instances_for_dag_run(dag, dag_run, session=session)
        observed_states = [ti.state for ti in ti_by_key.values()]
        # Both states must be present: some task instances were scheduled,
        # while others were left untouched.
        assert TaskInstanceState.SCHEDULED in observed_states
        assert State.NONE in observed_states

0 comments on commit 6931fbf

Please sign in to comment.