Task adoption for hybrid executors (apache#39531)
Sort the set of tasks that are up for adoption by the executor they're
configured to run on (if any) and send them to the appropriate executor
for adoption.
o-nikolas committed May 13, 2024
1 parent 8bc6c32 commit 3e229d8
Showing 4 changed files with 113 additions and 5 deletions.
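
In outline, the change groups orphaned task instances by the executor each one is configured to run on, then lets every executor adopt its own batch, handing back whatever could not be adopted so the scheduler can reset it. A minimal sketch of that flow follows; adopt_orphaned_tis and its load_executor parameter are illustrative stand-ins for this summary only, while the real logic lives in _executor_to_tis and adopt_or_reset_orphaned_tasks in the diff below.

from collections import defaultdict


def adopt_orphaned_tis(tis, load_executor):
    # Group task instances by their configured executor. A task's executor
    # field may be None (no per-task override); after this change,
    # load_executor(None) resolves to the default executor.
    exec_to_tis = defaultdict(list)
    for ti in tis:
        exec_to_tis[load_executor(ti.executor)].append(ti)

    # Each executor adopts only its own tasks; anything it cannot adopt is
    # handed back so the scheduler can reset it.
    to_reset = []
    for executor, batch in exec_to_tis.items():
        to_reset.extend(executor.try_adopt_task_instances(batch))
    return to_reset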
8 changes: 5 additions & 3 deletions airflow/executors/executor_loader.py
@@ -202,10 +202,10 @@ def lookup_executor_name_by_str(cls, executor_name_str: str) -> ExecutorName:
         elif executor_name := _module_to_executors.get(executor_name_str):
             return executor_name
         else:
-            raise AirflowException(f"Unknown executor being loaded: {executor_name}")
+            raise AirflowException(f"Unknown executor being loaded: {executor_name_str}")
 
     @classmethod
-    def load_executor(cls, executor_name: ExecutorName | str) -> BaseExecutor:
+    def load_executor(cls, executor_name: ExecutorName | str | None) -> BaseExecutor:
         """
         Load the executor.
@@ -217,7 +217,9 @@ def load_executor(cls, executor_name: ExecutorName | str) -> BaseExecutor:
         :return: an instance of executor class via executor_name
         """
-        if isinstance(executor_name, str):
+        if not executor_name:
+            _executor_name = cls.get_default_executor_name()
+        elif isinstance(executor_name, str):
             _executor_name = cls.lookup_executor_name_by_str(executor_name)
         else:
             _executor_name = executor_name
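Accepting None in load_executor is what lets the scheduler pass a task instance's executor field straight through: tasks with no per-task executor override resolve to the environment's default executor instead of raising an unknown-executor error.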
22 changes: 20 additions & 2 deletions airflow/jobs/scheduler_job_runner.py
@@ -24,7 +24,7 @@
 import sys
 import time
 import warnings
-from collections import Counter
+from collections import Counter, defaultdict
 from dataclasses import dataclass
 from datetime import timedelta
 from functools import lru_cache, partial
@@ -83,6 +83,7 @@
     from sqlalchemy.orm import Query, Session
 
     from airflow.dag_processing.manager import DagFileProcessorAgent
+    from airflow.executors.base_executor import BaseExecutor
     from airflow.models.taskinstance import TaskInstanceKey
     from airflow.utils.sqlalchemy import (
         CommitProhibitorGuard,
@@ -1651,7 +1652,11 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
             # Lock these rows, so that another scheduler can't try and adopt these too
             tis_to_adopt_or_reset = with_row_locks(query, of=TI, session=session, skip_locked=True)
             tis_to_adopt_or_reset = session.scalars(tis_to_adopt_or_reset).all()
-            to_reset = self.job.executor.try_adopt_task_instances(tis_to_adopt_or_reset)
+
+            to_reset: list[TaskInstance] = []
+            exec_to_tis = self._executor_to_tis(tis_to_adopt_or_reset)
+            for executor, tis in exec_to_tis.items():
+                to_reset.extend(executor.try_adopt_task_instances(tis))
 
             reset_tis_message = []
             for ti in to_reset:
@@ -1831,3 +1836,16 @@ def _orphan_unreferenced_datasets(self, session: Session = NEW_SESSION) -> None:
 
         updated_count = sum(self._set_orphaned(dataset) for dataset in orphaned_dataset_query)
         Stats.gauge("dataset.orphaned", updated_count)
+
+    def _executor_to_tis(self, tis: list[TaskInstance]) -> dict[BaseExecutor, list[TaskInstance]]:
+        """Organize TIs into lists per their respective executor."""
+        _executor_to_tis: defaultdict[BaseExecutor, list[TaskInstance]] = defaultdict(list)
+        executor: str | None
+        for ti in tis:
+            if ti.executor:
+                executor = str(ti.executor)
+            else:
+                executor = None
+            _executor_to_tis[ExecutorLoader.load_executor(executor)].append(ti)
+
+        return _executor_to_tis
19 changes: 19 additions & 0 deletions tests/executors/test_executor_loader.py
@@ -26,6 +26,7 @@
 from airflow.exceptions import AirflowConfigException
 from airflow.executors import executor_loader
 from airflow.executors.executor_loader import ConnectorSource, ExecutorLoader, ExecutorName
+from airflow.executors.local_executor import LocalExecutor
 from airflow.providers.celery.executors.celery_executor import CeleryExecutor
 from tests.test_utils.config import conf_vars

@@ -301,3 +302,21 @@ def test_validate_database_executor_compatibility_sqlite(self, monkeypatch, executor, expectation):
         monkeypatch.delenv("_AIRFLOW__SKIP_DATABASE_EXECUTOR_COMPATIBILITY_CHECK")
         with expectation:
             ExecutorLoader.validate_database_executor_compatibility(executor)
+
+    def test_load_executor(self):
+        ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
+        with conf_vars({("core", "executor"): "LocalExecutor"}):
+            ExecutorLoader.init_executors()
+            assert isinstance(ExecutorLoader.load_executor("LocalExecutor"), LocalExecutor)
+            assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor)
+            assert isinstance(ExecutorLoader.load_executor(None), LocalExecutor)
+
+    def test_load_executor_alias(self):
+        ExecutorLoader.block_use_of_hybrid_exec = mock.Mock()
+        with conf_vars({("core", "executor"): "local_exec:airflow.executors.local_executor.LocalExecutor"}):
+            ExecutorLoader.init_executors()
+            assert isinstance(ExecutorLoader.load_executor("local_exec"), LocalExecutor)
+            assert isinstance(
+                ExecutorLoader.load_executor("airflow.executors.local_executor.LocalExecutor"), LocalExecutor
+            )
+            assert isinstance(ExecutorLoader.load_executor(executor_loader._executor_names[0]), LocalExecutor)
69 changes: 69 additions & 0 deletions tests/jobs/test_scheduler_job.py
@@ -23,6 +23,7 @@
 import os
 from collections import deque
 from datetime import timedelta
+from importlib import reload
 from typing import Generator
 from unittest import mock
 from unittest.mock import MagicMock, PropertyMock, patch
@@ -165,6 +166,18 @@ def set_instance_attrs(self, dagbag) -> Generator:
         self.null_exec = None
         del self.dagbag
 
+    @pytest.fixture
+    def mock_executors(self):
+        default_executor = mock.MagicMock(slots_available=8, slots_occupied=0)
+        default_executor.name = MagicMock(alias="default_exec", module_path="default.exec.module.path")
+        second_executor = mock.MagicMock(slots_available=8, slots_occupied=0)
+        second_executor.name = MagicMock(alias="secondary_exec", module_path="secondary.exec.module.path")
+        with mock.patch("airflow.jobs.job.Job.executors", new_callable=PropertyMock) as executors_mock:
+            with mock.patch("airflow.jobs.job.Job.executor", new_callable=PropertyMock) as executor_mock:
+                executor_mock.return_value = default_executor
+                executors_mock.return_value = [default_executor, second_executor]
+                yield [default_executor, second_executor]
+
     @pytest.mark.parametrize(
         "configs",
         [
@@ -1740,6 +1753,62 @@ def test_adopt_or_reset_orphaned_tasks(self, dag_maker):
         ti2 = dr2.get_task_instance(task_id=op1.task_id, session=session)
         assert ti2.state == State.QUEUED, "Tasks run by Backfill Jobs should not be reset"
 
+    def test_adopt_or_reset_orphaned_tasks_multiple_executors(self, dag_maker, mock_executors):
+        """Test that, with multiple executors configured, tasks are sorted correctly and handed off
+        to the correct executor for adoption."""
+        session = settings.Session()
+        with dag_maker("test_execute_helper_reset_orphaned_tasks_multiple_executors"):
+            op1 = EmptyOperator(task_id="op1")
+            op2 = EmptyOperator(task_id="op2", executor="default_exec")
+            op3 = EmptyOperator(task_id="op3", executor="secondary_exec")
+
+        dr = dag_maker.create_dagrun()
+        scheduler_job = Job()
+        session.add(scheduler_job)
+        session.commit()
+        ti1 = dr.get_task_instance(task_id=op1.task_id, session=session)
+        ti2 = dr.get_task_instance(task_id=op2.task_id, session=session)
+        ti3 = dr.get_task_instance(task_id=op3.task_id, session=session)
+        tis = [ti1, ti2, ti3]
+        for ti in tis:
+            ti.state = State.QUEUED
+            ti.queued_by_job_id = scheduler_job.id
+        session.commit()
+
+        with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock:
+            # Reload the scheduler_job_runner module so that it loads a fresh executor_loader module
+            # containing the mocked load_executor method.
+            from airflow.jobs import scheduler_job_runner
+
+            reload(scheduler_job_runner)
+
+            processor = mock.MagicMock()
+
+            new_scheduler_job = Job()
+            self.job_runner = SchedulerJobRunner(job=new_scheduler_job, num_runs=0)
+            self.job_runner.processor_agent = processor
+            # The executors are mocked, so they cannot be loaded/imported. Mock load_executor to
+            # return the correct object for the given input executor name.
+            loader_mock.side_effect = lambda *x: {
+                ("default_exec",): mock_executors[0],
+                (None,): mock_executors[0],
+                ("secondary_exec",): mock_executors[1],
+            }[x]
+
+            self.job_runner.adopt_or_reset_orphaned_tasks()
+
+            # The default executor is called for ti1 (no explicit executor override, so the default
+            # applies) and ti2 (explicitly marked for execution by the default executor).
+            try:
+                mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti1, ti2])
+            except AssertionError:
+                # The order of the TIs passed to try_adopt_task_instances is not deterministic, so
+                # check the other order before letting the AssertionError fail the test.
+                mock_executors[0].try_adopt_task_instances.assert_called_once_with([ti2, ti1])
+
+            # The second executor is called for ti3.
+            mock_executors[1].try_adopt_task_instances.assert_called_once_with([ti3])
+
     def test_fail_stuck_queued_tasks(self, dag_maker, session):
         with dag_maker("test_fail_stuck_queued_tasks"):
             op1 = EmptyOperator(task_id="op1")
