Fix mini scheduler expansion of mapped task (#27506)

We have a case where the mini scheduler tries to expand a mapped task even though the mapped task's upstream tasks are not yet done.

The mini scheduler extracts a partial subset of the DAG, and in the process some upstream tasks are dropped.
If one of the remaining tasks happens to be a mapped task, its expansion fails because the upstream output it needs for the expansion is missing. When the expansion fails, the task is marked as `upstream_failed`, which in turn marks its own downstream tasks as `upstream_failed`.

The fix is to ignore this error and not mark the mapped task as `upstream_failed` when the expansion fails and the DAG is a partial subset.
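In essence, `MappedOperator.expand_mapped_task` now checks whether the DAG it is working on is a partial subset before treating a failed expansion as an upstream failure. A condensed sketch of the resulting control flow (all names are taken from the diff below; this is an illustration of the guard, not the verbatim method body):

```python
# Condensed sketch of the guard added in airflow/models/mappedoperator.py.
# `self` is the MappedOperator being expanded for a given run_id/session.
try:
    total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
except NotFullyPopulated as e:
    if not self.dag or not self.dag.partial:
        # Full DAG: the upstream values really should exist, so report the problem.
        self.log.error(
            "Cannot expand %r for run %s; missing upstream values: %s", self, run_id, sorted(e.missing)
        )
    # Partial subset built by the mini scheduler: the upstream may simply not be
    # done yet, so treat the map length as unknown instead of raising an error.
    total_length = None

# ... later, when deciding what to do with the still-unmapped task instance:
if total_length is None:
    if self.dag and self.dag.partial:
        # Mini-scheduler pass: leave the unmapped task instance untouched.
        indexes_to_map: Iterable[int] = ()
    else:
        # Regular scheduler pass: the expansion genuinely cannot happen, so
        # mark the unmapped task instance as upstream_failed.
        unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
        indexes_to_map = ()
```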

Co-authored-by: Ash Berlin-Taylor <ash_github@firemirror.com>
ephraimbuddy and ashb committed Nov 9, 2022
1 parent 47a2b9e commit ed92e5d
Showing 5 changed files with 165 additions and 69 deletions.
59 changes: 1 addition & 58 deletions airflow/jobs/local_task_job.py
@@ -18,25 +18,20 @@
from __future__ import annotations

import signal
from typing import TYPE_CHECKING

import psutil
from sqlalchemy.exc import OperationalError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.base_job import BaseJob
from airflow.listeners.events import register_task_instance_state_events
from airflow.listeners.listener import get_listener_manager
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.sentry import Sentry
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State


@@ -165,7 +160,7 @@ def handle_task_exit(self, return_code: int) -> None:

if not self.task_instance.test_mode:
if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
self._run_mini_scheduler_on_child_tasks()
self.task_instance.schedule_downstream_tasks()

def on_kill(self):
self.task_runner.terminate()
@@ -230,58 +225,6 @@ def heartbeat_callback(self, session=None):
self.terminating = True
self._state_change_checks += 1

@provide_session
@Sentry.enrich_errors
def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
try:
# Re-select the row with a lock
dag_run = with_row_locks(
session.query(DagRun).filter_by(
dag_id=self.dag_id,
run_id=self.task_instance.run_id,
),
session=session,
).one()

task = self.task_instance.task
if TYPE_CHECKING:
assert task.dag

# Get a partial DAG with just the specific tasks we want to examine.
# In order for dep checks to work correctly, we include ourself (so
# TriggerRuleDep can check the state of the task we just executed).
partial_dag = task.dag.partial_subset(
task.downstream_task_ids,
include_downstream=True,
include_upstream=False,
include_direct_upstream=True,
)

dag_run.dag = partial_dag
info = dag_run.task_instance_scheduling_decisions(session)

skippable_task_ids = {
task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
}

schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
for schedulable_ti in schedulable_tis:
if not hasattr(schedulable_ti, "task"):
schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

num = dag_run.schedule_tis(schedulable_tis)
self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

session.commit()
except OperationalError as e:
# Any kind of DB error here is _non fatal_ as this block is just an optimisation.
self.log.info(
"Skipping mini scheduling run due to exception: %s",
e.statement,
exc_info=True,
)
session.rollback()

@staticmethod
def _enable_task_listeners():
"""
30 changes: 20 additions & 10 deletions airflow/models/mappedoperator.py
@@ -620,13 +620,18 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
try:
total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
except NotFullyPopulated as e:
self.log.info(
"Cannot expand %r for run %s; missing upstream values: %s",
self,
run_id,
sorted(e.missing),
)
total_length = None
# partial dags comes from the mini scheduler. It's
# possible that the upstream tasks are not yet done,
# but we don't have upstream of upstreams in partial dags,
# so we ignore this exception.
if not self.dag or not self.dag.partial:
self.log.error(
"Cannot expand %r for run %s; missing upstream values: %s",
self,
run_id,
sorted(e.missing),
)

state: TaskInstanceState | None = None
unmapped_ti: TaskInstance | None = (
@@ -647,10 +652,15 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
# The unmapped task instance still exists and is unfinished, i.e. we
# haven't tried to run it before.
if total_length is None:
# If the map length cannot be calculated (due to unavailable
# upstream sources), fail the unmapped task.
unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
indexes_to_map: Iterable[int] = ()
if self.dag and self.dag.partial:
# If the DAG is partial, it's likely that the upstream tasks
# are not done yet, so we do nothing
indexes_to_map: Iterable[int] = ()
else:
# If the map length cannot be calculated (due to unavailable
# upstream sources), fail the unmapped task.
unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
indexes_to_map = ()
elif total_length < 1:
# If the upstream maps this to a zero-length value, simply mark
# the unmapped task instance as SKIPPED (if needed).
61 changes: 61 additions & 0 deletions airflow/models/taskinstance.py
@@ -2459,6 +2459,67 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum
return filters[0]
return or_(*filters)

@Sentry.enrich_errors
@provide_session
def schedule_downstream_tasks(self, session=None):
"""
The mini-scheduler for scheduling downstream tasks of this task instance
:meta private:
"""
from sqlalchemy.exc import OperationalError

from airflow.models import DagRun

try:
# Re-select the row with a lock
dag_run = with_row_locks(
session.query(DagRun).filter_by(
dag_id=self.dag_id,
run_id=self.run_id,
),
session=session,
).one()

task = self.task
if TYPE_CHECKING:
assert task.dag

# Get a partial DAG with just the specific tasks we want to examine.
# In order for dep checks to work correctly, we include ourself (so
# TriggerRuleDep can check the state of the task we just executed).
partial_dag = task.dag.partial_subset(
task.downstream_task_ids,
include_downstream=True,
include_upstream=False,
include_direct_upstream=True,
)

dag_run.dag = partial_dag
info = dag_run.task_instance_scheduling_decisions(session)

skippable_task_ids = {
task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
}

schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
for schedulable_ti in schedulable_tis:
if not hasattr(schedulable_ti, "task"):
schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

num = dag_run.schedule_tis(schedulable_tis, session=session)
self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

session.flush()

except OperationalError as e:
# Any kind of DB error here is _non fatal_ as this block is just an optimisation.
self.log.info(
"Skipping mini scheduling run due to exception: %s",
e.statement,
exc_info=True,
)
session.rollback()


# State of the task instance.
# Stores string version of the task state.
1 change: 0 additions & 1 deletion tests/jobs/test_local_task_job.py
@@ -739,7 +739,6 @@ def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag)
ti2_l.refresh_from_db()
assert ti2_k.state == State.SUCCESS
assert ti2_l.state == State.NONE
assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text

failed_deps = list(ti2_l.get_failed_dep_statuses())
assert len(failed_deps) == 1
83 changes: 83 additions & 0 deletions tests/models/test_taskinstance.py
@@ -3613,3 +3613,86 @@ def get_extra_env():

echo_task = dag.get_task("echo")
assert "get_extra_env" in echo_task.upstream_task_ids


def test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker, caplog, session):
"""
This tests that when scheduling child tasks of a task and there's a mapped downstream task,
if the mapped downstream task has upstreams that are not yet done, the mapped downstream task is
not marked as `upstream_failed`
"""
with dag_maker() as dag:

@dag.task
def second_task():
return [0, 1, 2]

@dag.task
def first_task():
print(2)

@dag.task
def middle_task(id):
return id

middle = middle_task.expand(id=second_task())

@dag.task
def last_task():
print(3)

[first_task(), middle] >> last_task()

dag_run = dag_maker.create_dagrun()
first_ti = dag_run.get_task_instance(task_id="first_task")
second_ti = dag_run.get_task_instance(task_id="second_task")
first_ti.state = State.SUCCESS
second_ti.state = State.RUNNING
session.merge(first_ti)
session.merge(second_ti)
session.commit()
first_ti.schedule_downstream_tasks(session=session)
middle_ti = dag_run.get_task_instance(task_id="middle_task")
assert middle_ti.state != State.UPSTREAM_FAILED
assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text


def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session):
"""Test that mini scheduler expands mapped task"""
with dag_maker() as dag:

@dag.task
def second_task():
return [0, 1, 2]

@dag.task
def first_task():
print(2)

@dag.task
def middle_task(id):
return id

middle = middle_task.expand(id=second_task())

@dag.task
def last_task():
print(3)

[first_task(), middle] >> last_task()

dr = dag_maker.create_dagrun()

first_ti = dr.get_task_instance(task_id="first_task")
first_ti.state = State.SUCCESS
session.merge(first_ti)
session.commit()
second_task = dag.get_task("second_task")
second_ti = dr.get_task_instance(task_id="second_task")
second_ti.refresh_from_task(second_task)
second_ti.run()
second_ti.schedule_downstream_tasks(session=session)
for i in range(3):
middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
assert middle_ti.state == State.SCHEDULED
assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text
