Fix mini scheduler expansion of mapped task #27506

Merged · 3 commits · Nov 9, 2022
59 changes: 1 addition & 58 deletions airflow/jobs/local_task_job.py
@@ -18,25 +18,20 @@
 from __future__ import annotations
 
 import signal
-from typing import TYPE_CHECKING
 
 import psutil
-from sqlalchemy.exc import OperationalError
 
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
 from airflow.listeners.events import register_task_instance_state_events
 from airflow.listeners.listener import get_listener_manager
-from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
-from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State

@@ -165,7 +160,7 @@ def handle_task_exit(self, return_code: int) -> None:
 
         if not self.task_instance.test_mode:
             if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
-                self._run_mini_scheduler_on_child_tasks()
+                self.task_instance.schedule_downstream_tasks()
 
     def on_kill(self):
         self.task_runner.terminate()
@@ -230,58 +225,6 @@ def heartbeat_callback(self, session=None):
             self.terminating = True
         self._state_change_checks += 1
 
-    @provide_session
-    @Sentry.enrich_errors
-    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
-        try:
-            # Re-select the row with a lock
-            dag_run = with_row_locks(
-                session.query(DagRun).filter_by(
-                    dag_id=self.dag_id,
-                    run_id=self.task_instance.run_id,
-                ),
-                session=session,
-            ).one()
-
-            task = self.task_instance.task
-            if TYPE_CHECKING:
-                assert task.dag
-
-            # Get a partial DAG with just the specific tasks we want to examine.
-            # In order for dep checks to work correctly, we include ourself (so
-            # TriggerRuleDep can check the state of the task we just executed).
-            partial_dag = task.dag.partial_subset(
-                task.downstream_task_ids,
-                include_downstream=True,
-                include_upstream=False,
-                include_direct_upstream=True,
-            )
-
-            dag_run.dag = partial_dag
-            info = dag_run.task_instance_scheduling_decisions(session)
-
-            skippable_task_ids = {
-                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
-            }
-
-            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
-            for schedulable_ti in schedulable_tis:
-                if not hasattr(schedulable_ti, "task"):
-                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
-
-            num = dag_run.schedule_tis(schedulable_tis)
-            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
-            session.commit()
-        except OperationalError as e:
-            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
-            self.log.info(
-                "Skipping mini scheduling run due to exception: %s",
-                e.statement,
-                exc_info=True,
-            )
-            session.rollback()
-
     @staticmethod
     def _enable_task_listeners():
         """
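Note: the mini scheduler invoked from handle_task_exit above is gated by the scheduler-level schedule_after_task_execution option, read with fallback=True, i.e. enabled by default. A minimal airflow.cfg sketch making the setting explicit (the value shown is just the default; this PR does not change it):

    [scheduler]
    # Run the "mini scheduler" in the task supervisor right after a task
    # finishes, so eligible downstream tasks of the same dag run start sooner.
    schedule_after_task_execution = True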
30 changes: 20 additions & 10 deletions airflow/models/mappedoperator.py
@@ -620,13 +620,18 @@ def expand_mapped_task(self, run_id: str, *, session: Session) -> tuple[Sequence
         try:
             total_length = self._get_specified_expand_input().get_total_map_length(run_id, session=session)
         except NotFullyPopulated as e:
-            self.log.info(
-                "Cannot expand %r for run %s; missing upstream values: %s",
-                self,
-                run_id,
-                sorted(e.missing),
-            )
             total_length = None
+            # Partial DAGs come from the mini scheduler. It's possible that
+            # the upstream tasks are not yet done, but a partial DAG does not
+            # contain the upstreams of our upstreams, so we ignore this
+            # exception there.
+            if not self.dag or not self.dag.partial:
+                self.log.error(
+                    "Cannot expand %r for run %s; missing upstream values: %s",
+                    self,
+                    run_id,
+                    sorted(e.missing),
+                )
 
         state: TaskInstanceState | None = None
         unmapped_ti: TaskInstance | None = (
@@ -647,10 +652,15 @@
             # The unmapped task instance still exists and is unfinished, i.e. we
             # haven't tried to run it before.
             if total_length is None:
-                # If the map length cannot be calculated (due to unavailable
-                # upstream sources), fail the unmapped task.
-                unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
-                indexes_to_map: Iterable[int] = ()
+                if self.dag and self.dag.partial:
+                    # If the DAG is partial, it's likely that the upstream tasks
+                    # are not done yet, so we do nothing.
+                    indexes_to_map: Iterable[int] = ()
+                else:
+                    # If the map length cannot be calculated (due to unavailable
+                    # upstream sources), fail the unmapped task.
+                    unmapped_ti.state = TaskInstanceState.UPSTREAM_FAILED
+                    indexes_to_map = ()
             elif total_length < 1:
                 # If the upstream maps this to a zero-length value, simply mark
                 # the unmapped task instance as SKIPPED (if needed).
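A distilled sketch of the new guard above, using hypothetical names (resolve_unmapped, dag_is_partial, Outcome are illustrative, not Airflow API), to show the decision table this hunk implements:

    from enum import Enum

    class Outcome(Enum):
        DO_NOTHING = "leave the unmapped task instance alone"
        UPSTREAM_FAILED = "fail the unmapped task instance"
        EXPAND = "create mapped task instances"

    def resolve_unmapped(total_length, dag_is_partial):
        # total_length is None when NotFullyPopulated was raised, i.e. the
        # upstream values needed to compute the map length are unavailable.
        if total_length is None:
            # In a partial DAG (mini scheduler) missing upstreams are expected,
            # so defer to the real scheduler instead of failing the task.
            return Outcome.DO_NOTHING if dag_is_partial else Outcome.UPSTREAM_FAILED
        return Outcome.EXPAND

    assert resolve_unmapped(None, dag_is_partial=True) is Outcome.DO_NOTHING
    assert resolve_unmapped(None, dag_is_partial=False) is Outcome.UPSTREAM_FAILED
    assert resolve_unmapped(3, dag_is_partial=True) is Outcome.EXPAND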
61 changes: 61 additions & 0 deletions airflow/models/taskinstance.py
@@ -2459,6 +2459,67 @@ def ti_selector_condition(cls, vals: Collection[str | tuple[str, int]]) -> Colum
             return filters[0]
         return or_(*filters)
 
+    @Sentry.enrich_errors
+    @provide_session
+    def schedule_downstream_tasks(self, session=None):
+        """
+        The mini-scheduler for scheduling downstream tasks of this task instance.
+        :meta private:
+        """
+        from sqlalchemy.exc import OperationalError
+
+        from airflow.models import DagRun
+
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    run_id=self.run_id,
+                ),
+                session=session,
+            ).one()
+
+            task = self.task
+            if TYPE_CHECKING:
+                assert task.dag
+
+            # Get a partial DAG with just the specific tasks we want to examine.
+            # In order for dep checks to work correctly, we include ourself (so
+            # TriggerRuleDep can check the state of the task we just executed).
+            partial_dag = task.dag.partial_subset(
+                task.downstream_task_ids,
+                include_downstream=True,
+                include_upstream=False,
+                include_direct_upstream=True,
+            )
+
+            dag_run.dag = partial_dag
+            info = dag_run.task_instance_scheduling_decisions(session)
+
+            skippable_task_ids = {
+                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+            }
+
+            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+            for schedulable_ti in schedulable_tis:
+                if not hasattr(schedulable_ti, "task"):
+                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+            num = dag_run.schedule_tis(schedulable_tis, session=session)
+            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+            session.flush()
+
+        except OperationalError as e:
+            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+            self.log.info(
+                "Skipping mini scheduling run due to exception: %s",
+                e.statement,
+                exc_info=True,
+            )
+            session.rollback()
+
 
 # State of the task instance.
 # Stores string version of the task state.
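To make the skippable_task_ids filter in schedule_downstream_tasks concrete, here is a small self-contained sketch; the graph and task ids (a, b, c, d) are hypothetical, not from this PR. With a >> b, c >> b and b >> d, the partial subset built around the downstreams of a keeps c (direct upstream of b, needed by TriggerRuleDep) and d (downstream of b), but only the direct downstream b may be scheduled by the mini scheduler:

    # Hypothetical graph: a >> b, c >> b, b >> d; task "a" has just finished.
    downstream_task_ids = {"b"}  # direct downstreams of "a"
    partial_dag_task_ids = {"a", "b", "c", "d"}  # what partial_subset() keeps

    # Tasks retained only for dependency checking must not be scheduled here.
    skippable_task_ids = {
        task_id for task_id in partial_dag_task_ids if task_id not in downstream_task_ids
    }
    assert skippable_task_ids == {"a", "c", "d"}  # only "b" is schedulable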
1 change: 0 additions & 1 deletion tests/jobs/test_local_task_job.py
@@ -739,7 +739,6 @@ def test_mini_scheduler_works_with_wait_for_upstream(self, caplog, get_test_dag)
         ti2_l.refresh_from_db()
         assert ti2_k.state == State.SUCCESS
         assert ti2_l.state == State.NONE
-        assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text
 
         failed_deps = list(ti2_l.get_failed_dep_statuses())
         assert len(failed_deps) == 1
83 changes: 83 additions & 0 deletions tests/models/test_taskinstance.py
@@ -3613,3 +3613,86 @@ def get_extra_env():
 
     echo_task = dag.get_task("echo")
     assert "get_extra_env" in echo_task.upstream_task_ids
+
+
+def test_mapped_task_does_not_error_in_mini_scheduler_if_upstreams_are_not_done(dag_maker, caplog, session):
+    """
+    Test that when the mini scheduler schedules the downstream tasks of a task instance
+    and one of them is a mapped task whose own upstreams are not yet done, the mapped
+    task is not marked as `upstream_failed`.
+    """
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dag_run = dag_maker.create_dagrun()
+    first_ti = dag_run.get_task_instance(task_id="first_task")
+    second_ti = dag_run.get_task_instance(task_id="second_task")
+    first_ti.state = State.SUCCESS
+    second_ti.state = State.RUNNING
+    session.merge(first_ti)
+    session.merge(second_ti)
+    session.commit()
+    first_ti.schedule_downstream_tasks(session=session)
+    middle_ti = dag_run.get_task_instance(task_id="middle_task")
+    assert middle_ti.state != State.UPSTREAM_FAILED
+    assert "0 downstream tasks scheduled from follow-on schedule" in caplog.text
+
+
+def test_mapped_task_expands_in_mini_scheduler_if_upstreams_are_done(dag_maker, caplog, session):
+    """Test that the mini scheduler expands a mapped task once its upstreams are done."""
+    with dag_maker() as dag:
+
+        @dag.task
+        def second_task():
+            return [0, 1, 2]
+
+        @dag.task
+        def first_task():
+            print(2)
+
+        @dag.task
+        def middle_task(id):
+            return id
+
+        middle = middle_task.expand(id=second_task())
+
+        @dag.task
+        def last_task():
+            print(3)
+
+        [first_task(), middle] >> last_task()
+
+    dr = dag_maker.create_dagrun()
+
+    first_ti = dr.get_task_instance(task_id="first_task")
+    first_ti.state = State.SUCCESS
+    session.merge(first_ti)
+    session.commit()
+    second_task = dag.get_task("second_task")
+    second_ti = dr.get_task_instance(task_id="second_task")
+    second_ti.refresh_from_task(second_task)
+    second_ti.run()
+    second_ti.schedule_downstream_tasks(session=session)
+    for i in range(3):
+        middle_ti = dr.get_task_instance(task_id="middle_task", map_index=i)
+        assert middle_ti.state == State.SCHEDULED
+    assert "3 downstream tasks scheduled from follow-on schedule" in caplog.text