Run mini scheduler in LocalTaskJob during task exit (#16289)
Currently, the chances of tasks being killed by the LocalTaskJob heartbeat are high.

This is because, when the mini scheduler is enabled, we start running it immediately after marking a task successful/failed in taskinstance.py. Whenever the mini scheduling run takes long enough to meet the next job heartbeat, the heartbeat detects that the task has finished with no return code, because LocalTaskJob.handle_task_exit was not called after the task succeeded. Hence, the heartbeat concludes that the task was externally marked failed/successful and kills it.

This change resolves the issue by moving the mini scheduler into LocalTaskJob's handle_task_exit method, ensuring that the task will no longer be killed by the next heartbeat.
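To illustrate the race, here is a minimal self-contained sketch (all names such as `FakeJob` and `run_task_pre_fix` are invented; this is not Airflow's actual code): the task state is committed before the mini scheduler runs, so a heartbeat that fires in between sees a finished task with no return code and terminates it.

```python
import time


class FakeJob:
    """Stand-in for a LocalTaskJob; invented for this sketch."""

    return_code = None  # only set once handle_task_exit runs
    terminating = False


def heartbeat_callback(job, ti_state):
    # The DB says the task finished, but the job never collected a return
    # code, so the heartbeat assumes the state was set externally.
    if ti_state == "success" and job.return_code is None:
        job.terminating = True


def run_task_pre_fix(job):
    ti_state = "success"  # 1. the task process commits SUCCESS first
    time.sleep(0.1)  # 2. the mini scheduler runs and may be slow...
    heartbeat_callback(job, ti_state)  # 3. ...so a heartbeat can fire here
    job.return_code = 0  # 4. too late: the job is already terminating
    return job.terminating


print(run_task_pre_fix(FakeJob()))  # True -> the task got killed
```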
ephraimbuddy committed Jun 10, 2021
1 parent 59c6720 commit 408bd26
Showing 5 changed files with 185 additions and 174 deletions.
60 changes: 59 additions & 1 deletion airflow/jobs/local_task_job.py
@@ -16,19 +16,23 @@
# specific language governing permissions and limitations
# under the License.
#

import signal
from typing import Optional

from sqlalchemy.exc import OperationalError

from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.jobs.base_job import BaseJob
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.sentry import Sentry
from airflow.stats import Stats
from airflow.task.task_runner import get_task_runner
from airflow.utils import timezone
from airflow.utils.net import get_hostname
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import with_row_locks
from airflow.utils.state import State


@@ -160,6 +164,9 @@ def handle_task_exit(self, return_code: int) -> None:
        if self.task_instance.state != State.SUCCESS:
            error = self.task_runner.deserialize_run_error()
        self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
        if not self.task_instance.test_mode:
            if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
                self._run_mini_scheduler_on_child_tasks()

    def on_kill(self):
        self.task_runner.terminate()
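The new branch is gated on `[scheduler] schedule_after_task_execution`, which defaults to True. A quick way to inspect the flag, using the same `conf.getboolean` call as above (the override paths are standard Airflow config mechanics):

```python
from airflow.configuration import conf

# Defaults to True; set it to False in airflow.cfg under [scheduler], or
# export AIRFLOW__SCHEDULER__SCHEDULE_AFTER_TASK_EXECUTION=False, to skip
# the mini scheduler entirely.
print(conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True))
```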
@@ -206,3 +213,54 @@ def heartbeat_callback(self, session=None):
                error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
            ti._run_finished_callback(error=error)  # pylint: disable=protected-access
            self.terminating = True

    @provide_session
    @Sentry.enrich_errors
    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
        try:
            # Re-select the row with a lock
            dag_run = with_row_locks(
                session.query(DagRun).filter_by(
                    dag_id=self.dag_id,
                    execution_date=self.task_instance.execution_date,
                ),
                session=session,
            ).one()

            # Get a partial dag with just the specific tasks we want to
            # examine. In order for dep checks to work correctly, we
            # include ourself (so TriggerRuleDep can check the state of the
            # task we just executed)
            task = self.task_instance.task

            partial_dag = task.dag.partial_subset(
                task.downstream_task_ids,
                include_downstream=False,
                include_upstream=False,
                include_direct_upstream=True,
            )

            dag_run.dag = partial_dag
            info = dag_run.task_instance_scheduling_decisions(session)

            skippable_task_ids = {
                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
            }

            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
            for schedulable_ti in schedulable_tis:
                if not hasattr(schedulable_ti, "task"):
                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)

            num = dag_run.schedule_tis(schedulable_tis)
            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

            session.commit()
        except OperationalError as e:
            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
            self.log.info(
                "Skipping mini scheduling run due to exception: %s",
                e.statement,
                exc_info=True,
            )
            session.rollback()
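To make the subset logic concrete, here is a hedged illustration on a made-up three-task DAG (the DAG id and task ids are invented): with `include_direct_upstream=True`, the subset built from a task's downstream ids also pulls in their direct upstream tasks so TriggerRuleDep can evaluate trigger rules, and those extra upstreams are exactly the "skippable" ids filtered out before scheduling.

```python
from airflow.models.dag import DAG
from airflow.operators.dummy import DummyOperator
from airflow.utils import timezone

# Invented example DAG: a >> c and b >> c.
with DAG('partial_subset_example', start_date=timezone.datetime(2021, 1, 1)) as dag:
    a = DummyOperator(task_id='a')
    b = DummyOperator(task_id='b')
    c = DummyOperator(task_id='c')
    a >> c
    b >> c

# Build the subset the same way the mini scheduler does for task 'a':
subset = dag.partial_subset(
    a.downstream_task_ids,  # {'c'}
    include_downstream=False,
    include_upstream=False,
    include_direct_upstream=True,
)
# 'c' is kept, plus its direct upstreams 'a' and 'b'; 'a' and 'b' are then
# excluded from scheduling via skippable_task_ids.
print(sorted(subset.task_ids))  # ['a', 'b', 'c']
```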
60 changes: 2 additions & 58 deletions airflow/models/taskinstance.py
@@ -35,7 +35,6 @@
import pendulum
from jinja2 import TemplateAssertionError, UndefinedError
from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm.session import Session
from sqlalchemy.sql.elements import BooleanClauseList
@@ -70,7 +69,7 @@
from airflow.utils.operator_helpers import context_to_airflow_vars
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
from airflow.utils.sqlalchemy import UtcDateTime
from airflow.utils.state import State
from airflow.utils.timeout import timeout

@@ -1200,62 +1199,6 @@ def _run_raw_task(

        session.commit()

        if not test_mode:
            self._run_mini_scheduler_on_child_tasks(session)

    @provide_session
    @Sentry.enrich_errors
    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
        if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
            from airflow.models.dagrun import DagRun  # Avoid circular import

            try:
                # Re-select the row with a lock
                dag_run = with_row_locks(
                    session.query(DagRun).filter_by(
                        dag_id=self.dag_id,
                        execution_date=self.execution_date,
                    ),
                    session=session,
                ).one()

                # Get a partial dag with just the specific tasks we want to
                # examine. In order for dep checks to work correctly, we
                # include ourself (so TriggerRuleDep can check the state of the
                # task we just executed)
                partial_dag = self.task.dag.partial_subset(
                    self.task.downstream_task_ids,
                    include_downstream=False,
                    include_upstream=False,
                    include_direct_upstream=True,
                )

                dag_run.dag = partial_dag
                info = dag_run.task_instance_scheduling_decisions(session)

                skippable_task_ids = {
                    task_id
                    for task_id in partial_dag.task_ids
                    if task_id not in self.task.downstream_task_ids
                }

                schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
                for schedulable_ti in schedulable_tis:
                    if not hasattr(schedulable_ti, "task"):
                        schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id)

                num = dag_run.schedule_tis(schedulable_tis)
                self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)

                session.commit()
            except OperationalError as e:
                # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
                self.log.info(
                    f"Skipping mini scheduling run due to exception: {e.statement}",
                    exc_info=True,
                )
                session.rollback()
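Both the old and new homes of this method rely on Airflow's `@provide_session` decorator, which is why `session=None` is a safe default. A simplified sketch of the pattern (not Airflow's actual implementation; `FakeSession` is invented):

```python
import functools


class FakeSession:
    """Stand-in for a SQLAlchemy session, just for this sketch."""

    def close(self):
        pass


def provide_session(func):
    # Simplified sketch of airflow.utils.session.provide_session: inject a
    # session only when the caller did not pass one, and close it afterwards
    # (the real decorator also commits on success and rolls back on error).
    @functools.wraps(func)
    def wrapper(*args, session=None, **kwargs):
        if session is not None:
            return func(*args, session=session, **kwargs)
        session = FakeSession()
        try:
            return func(*args, session=session, **kwargs)
        finally:
            session.close()
    return wrapper


@provide_session
def do_work(session=None):
    return type(session).__name__


print(do_work())  # FakeSession
```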

    def _prepare_and_execute_task_with_callbacks(self, context, task):
        """Prepare Task for Execution"""
        from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1408,6 +1351,7 @@ def run(  # pylint: disable=too-many-arguments
            session=session,
        )
        if not res:
            self.log.info("CHECK AND CHANGE")
            return

        try:
4 changes: 1 addition & 3 deletions tests/cli/commands/test_task_command.py
@@ -71,8 +71,7 @@ def test_cli_list_tasks(self):
        args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree'])
        task_command.task_list(args)

    @mock.patch("airflow.models.taskinstance.TaskInstance._run_mini_scheduler_on_child_tasks")
    def test_test(self, mock_run_mini_scheduler):
    def test_test(self):
        """Test the `airflow test` command"""
        args = self.parser.parse_args(
            ["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01']
@@ -81,7 +80,6 @@ def test_test(self, mock_run_mini_scheduler):
        with redirect_stdout(io.StringIO()) as stdout:
            task_command.task_test(args)

        mock_run_mini_scheduler.assert_not_called()
        # Check that prints, and log messages, are shown
        assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()

132 changes: 123 additions & 9 deletions tests/jobs/test_local_task_job.py
@@ -33,7 +33,8 @@
from airflow.exceptions import AirflowException, AirflowFailException
from airflow.executors.sequential_executor import SequentialExecutor
from airflow.jobs.local_task_job import LocalTaskJob
from airflow.models.dag import DAG
from airflow.jobs.scheduler_job import SchedulerJob
from airflow.models.dag import DAG, DagModel
from airflow.models.dagbag import DagBag
from airflow.models.taskinstance import TaskInstance
from airflow.operators.dummy import DummyOperator
@@ -44,8 +45,9 @@
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.timeout import timeout
from tests.test_utils import db
from tests.test_utils.asserts import assert_queries_count
from tests.test_utils.db import clear_db_jobs, clear_db_runs
from tests.test_utils.config import conf_vars
from tests.test_utils.mock_executor import MockExecutor

DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -54,15 +56,25 @@

class TestLocalTaskJob(unittest.TestCase):
    def setUp(self):
        clear_db_jobs()
        clear_db_runs()
        db.clear_db_dags()
        db.clear_db_jobs()
        db.clear_db_runs()
        db.clear_db_task_fail()
        patcher = patch('airflow.jobs.base_job.sleep')
        self.addCleanup(patcher.stop)
        self.mock_base_job_sleep = patcher.start()

    def tearDown(self) -> None:
        clear_db_jobs()
        clear_db_runs()
        db.clear_db_dags()
        db.clear_db_jobs()
        db.clear_db_runs()
        db.clear_db_task_fail()

    def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
        for task_id, expected_state in ti_state_mapping.items():
            task_instance = dag_run.get_task_instance(task_id=task_id)
            task_instance.refresh_from_db()
            assert task_instance.state == expected_state, error_message

    def test_localtaskjob_essential_attr(self):
        """
@@ -563,20 +575,122 @@ def task_function(ti):
            if ti.state == State.RUNNING and ti.pid is not None:
                break
            time.sleep(0.2)
        assert ti.state == State.RUNNING
        assert ti.pid is not None
        assert ti.state == State.RUNNING
        os.kill(ti.pid, signal_type)
        process.join(timeout=10)
        assert failure_callback_called.value == 1
        assert task_terminated_externally.value == 1
        assert not process.is_alive()

    @parameterized.expand(
        [
            (
                {('scheduler', 'schedule_after_task_execution'): 'True'},
                {'A': 'B', 'B': 'C'},
                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
                {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
                {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
                "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
            ),
            (
                {('scheduler', 'schedule_after_task_execution'): 'False'},
                {'A': 'B', 'B': 'C'},
                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
                None,
                "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
            ),
            (
                {('scheduler', 'schedule_after_task_execution'): 'True'},
                {'A': 'B', 'C': 'B', 'D': 'C'},
                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
                None,
                "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
            ),
            (
                {('scheduler', 'schedule_after_task_execution'): 'True'},
                {'A': 'C', 'B': 'C'},
                {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
                {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
                None,
                "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
            ),
        ]
    )
    def test_fast_follow(
        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
    ):
        # pylint: disable=too-many-locals
        with conf_vars(conf):
            session = settings.Session()

            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)

            dag_model = DagModel(
                dag_id=dag.dag_id,
                next_dagrun=dag.start_date,
                is_active=True,
            )
            session.add(dag_model)
            session.flush()

            python_callable = lambda: True
            with dag:
                task_a = PythonOperator(task_id='A', python_callable=python_callable)
                task_b = PythonOperator(task_id='B', python_callable=python_callable)
                task_c = PythonOperator(task_id='C', python_callable=python_callable)
                if 'D' in init_state:
                    task_d = PythonOperator(task_id='D', python_callable=python_callable)
                for upstream, downstream in dependencies.items():
                    dag.set_dependency(upstream, downstream)

            scheduler_job = SchedulerJob(subdir=os.devnull)
            scheduler_job.dagbag.bag_dag(dag, root_dag=dag)

            dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)

            task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A'])

            task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B'])

            task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C'])

            if 'D' in init_state:
                task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D'])
                session.merge(task_instance_d)

            session.merge(task_instance_a)
            session.merge(task_instance_b)
            session.merge(task_instance_c)
            session.flush()

            job1 = LocalTaskJob(
                task_instance=task_instance_a, ignore_ti_state=True, executor=SequentialExecutor()
            )
            job1.task_runner = StandardTaskRunner(job1)

            job2 = LocalTaskJob(
                task_instance=task_instance_b, ignore_ti_state=True, executor=SequentialExecutor()
            )
            job2.task_runner = StandardTaskRunner(job2)

            settings.engine.dispose()
            job1.run()
            self.validate_ti_states(dag_run, first_run_state, error_message)
            if second_run_state:
                job2.run()
                self.validate_ti_states(dag_run, second_run_state, error_message)
            if scheduler_job.processor_agent:
                scheduler_job.processor_agent.end()


@pytest.fixture()
def clean_db_helper():
    yield
    clear_db_jobs()
    clear_db_runs()
    db.clear_db_jobs()
    db.clear_db_runs()


@pytest.mark.usefixtures("clean_db_helper")
