From 408bd26c22913af93d05aa70abc3c66c52cd4588 Mon Sep 17 00:00:00 2001
From: Ephraim Anierobi
Date: Thu, 10 Jun 2021 14:29:30 +0100
Subject: [PATCH] Run mini scheduler in LocalTaskJob during task exit (#16289)

Currently, the chances of a task being killed by the LocalTaskJob
heartbeat are high. After a task is marked successful/failed in
taskinstance.py, and the mini scheduler is enabled, the mini scheduler
starts running. Whenever that mini scheduling pass takes long enough to
meet the next job heartbeat, the heartbeat sees that the task has
finished but has no return code, because LocalTaskJob.handle_task_exit
was not called after the task succeeded. The heartbeat therefore
concludes that the task was marked failed/successful externally and
kills it.

This change resolves that by moving the mini scheduler into
LocalTaskJob.handle_task_exit, ensuring the task is no longer killed by
the next heartbeat.
---
 airflow/jobs/local_task_job.py          |  60 ++++++++++-
 airflow/models/taskinstance.py          |  60 +----------
 tests/cli/commands/test_task_command.py |   4 +-
 tests/jobs/test_local_task_job.py       | 132 ++++++++++++++++++++++--
 tests/models/test_taskinstance.py       | 103 ------------------
 5 files changed, 185 insertions(+), 174 deletions(-)

diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 9e68450cf2700..35d0422e7d4b1 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -16,19 +16,23 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
 import signal
 from typing import Optional
 
+from sqlalchemy.exc import OperationalError
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
+from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
+from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State
 
 
@@ -160,6 +164,9 @@ def handle_task_exit(self, return_code: int) -> None:
         if self.task_instance.state != State.SUCCESS:
             error = self.task_runner.deserialize_run_error()
         self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
+        if not self.task_instance.test_mode:
+            if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
+                self._run_mini_scheduler_on_child_tasks()
 
     def on_kill(self):
         self.task_runner.terminate()
@@ -206,3 +213,54 @@ def heartbeat_callback(self, session=None):
                 error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
             ti._run_finished_callback(error=error)  # pylint: disable=protected-access
             self.terminating = True
+
+    @provide_session
+    @Sentry.enrich_errors
+    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    execution_date=self.task_instance.execution_date,
+                ),
+                session=session,
+            ).one()
+
+            # Get a partial dag with just the specific tasks we want to
+            # examine.
In order for dep checks to work correctly, we + # include ourself (so TriggerRuleDep can check the state of the + # task we just executed) + task = self.task_instance.task + + partial_dag = task.dag.partial_subset( + task.downstream_task_ids, + include_downstream=False, + include_upstream=False, + include_direct_upstream=True, + ) + + dag_run.dag = partial_dag + info = dag_run.task_instance_scheduling_decisions(session) + + skippable_task_ids = { + task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids + } + + schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] + for schedulable_ti in schedulable_tis: + if not hasattr(schedulable_ti, "task"): + schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id) + + num = dag_run.schedule_tis(schedulable_tis) + self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) + + session.commit() + except OperationalError as e: + # Any kind of DB error here is _non fatal_ as this block is just an optimisation. + self.log.info( + "Skipping mini scheduling run due to exception: %s", + e.statement, + exc_info=True, + ) + session.rollback() diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 6e0c011d3dbbb..4fd72e7c44e71 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -35,7 +35,6 @@ import pendulum from jinja2 import TemplateAssertionError, UndefinedError from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_ -from sqlalchemy.exc import OperationalError from sqlalchemy.orm import reconstructor, relationship from sqlalchemy.orm.session import Session from sqlalchemy.sql.elements import BooleanClauseList @@ -70,7 +69,7 @@ from airflow.utils.operator_helpers import context_to_airflow_vars from airflow.utils.platform import getuser from airflow.utils.session import provide_session -from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks +from airflow.utils.sqlalchemy import UtcDateTime from airflow.utils.state import State from airflow.utils.timeout import timeout @@ -1200,62 +1199,6 @@ def _run_raw_task( session.commit() - if not test_mode: - self._run_mini_scheduler_on_child_tasks(session) - - @provide_session - @Sentry.enrich_errors - def _run_mini_scheduler_on_child_tasks(self, session=None) -> None: - if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True): - from airflow.models.dagrun import DagRun # Avoid circular import - - try: - # Re-select the row with a lock - dag_run = with_row_locks( - session.query(DagRun).filter_by( - dag_id=self.dag_id, - execution_date=self.execution_date, - ), - session=session, - ).one() - - # Get a partial dag with just the specific tasks we want to - # examine. 
In order for dep checks to work correctly, we - # include ourself (so TriggerRuleDep can check the state of the - # task we just executed) - partial_dag = self.task.dag.partial_subset( - self.task.downstream_task_ids, - include_downstream=False, - include_upstream=False, - include_direct_upstream=True, - ) - - dag_run.dag = partial_dag - info = dag_run.task_instance_scheduling_decisions(session) - - skippable_task_ids = { - task_id - for task_id in partial_dag.task_ids - if task_id not in self.task.downstream_task_ids - } - - schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids] - for schedulable_ti in schedulable_tis: - if not hasattr(schedulable_ti, "task"): - schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id) - - num = dag_run.schedule_tis(schedulable_tis) - self.log.info("%d downstream tasks scheduled from follow-on schedule check", num) - - session.commit() - except OperationalError as e: - # Any kind of DB error here is _non fatal_ as this block is just an optimisation. - self.log.info( - f"Skipping mini scheduling run due to exception: {e.statement}", - exc_info=True, - ) - session.rollback() - def _prepare_and_execute_task_with_callbacks(self, context, task): """Prepare Task for Execution""" from airflow.models.renderedtifields import RenderedTaskInstanceFields @@ -1408,6 +1351,7 @@ def run( # pylint: disable=too-many-arguments session=session, ) if not res: + self.log.info("CHECK AND CHANGE") return try: diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index f50ddbc5e4f17..2b93e6d5a96d6 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -71,8 +71,7 @@ def test_cli_list_tasks(self): args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree']) task_command.task_list(args) - @mock.patch("airflow.models.taskinstance.TaskInstance._run_mini_scheduler_on_child_tasks") - def test_test(self, mock_run_mini_scheduler): + def test_test(self): """Test the `airflow test` command""" args = self.parser.parse_args( ["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01'] @@ -81,7 +80,6 @@ def test_test(self, mock_run_mini_scheduler): with redirect_stdout(io.StringIO()) as stdout: task_command.task_test(args) - mock_run_mini_scheduler.assert_not_called() # Check that prints, and log messages, are shown assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue() diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py index 9047f8aa18938..6639775156853 100644 --- a/tests/jobs/test_local_task_job.py +++ b/tests/jobs/test_local_task_job.py @@ -33,7 +33,8 @@ from airflow.exceptions import AirflowException, AirflowFailException from airflow.executors.sequential_executor import SequentialExecutor from airflow.jobs.local_task_job import LocalTaskJob -from airflow.models.dag import DAG +from airflow.jobs.scheduler_job import SchedulerJob +from airflow.models.dag import DAG, DagModel from airflow.models.dagbag import DagBag from airflow.models.taskinstance import TaskInstance from airflow.operators.dummy import DummyOperator @@ -44,8 +45,9 @@ from airflow.utils.session import create_session from airflow.utils.state import State from airflow.utils.timeout import timeout +from tests.test_utils import db from tests.test_utils.asserts import assert_queries_count -from tests.test_utils.db import clear_db_jobs, clear_db_runs +from tests.test_utils.config import 
conf_vars from tests.test_utils.mock_executor import MockExecutor DEFAULT_DATE = timezone.datetime(2016, 1, 1) @@ -54,15 +56,25 @@ class TestLocalTaskJob(unittest.TestCase): def setUp(self): - clear_db_jobs() - clear_db_runs() + db.clear_db_dags() + db.clear_db_jobs() + db.clear_db_runs() + db.clear_db_task_fail() patcher = patch('airflow.jobs.base_job.sleep') self.addCleanup(patcher.stop) self.mock_base_job_sleep = patcher.start() def tearDown(self) -> None: - clear_db_jobs() - clear_db_runs() + db.clear_db_dags() + db.clear_db_jobs() + db.clear_db_runs() + db.clear_db_task_fail() + + def validate_ti_states(self, dag_run, ti_state_mapping, error_message): + for task_id, expected_state in ti_state_mapping.items(): + task_instance = dag_run.get_task_instance(task_id=task_id) + task_instance.refresh_from_db() + assert task_instance.state == expected_state, error_message def test_localtaskjob_essential_attr(self): """ @@ -563,20 +575,122 @@ def task_function(ti): if ti.state == State.RUNNING and ti.pid is not None: break time.sleep(0.2) - assert ti.state == State.RUNNING assert ti.pid is not None + assert ti.state == State.RUNNING os.kill(ti.pid, signal_type) process.join(timeout=10) assert failure_callback_called.value == 1 assert task_terminated_externally.value == 1 assert not process.is_alive() + @parameterized.expand( + [ + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'B', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED}, + "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.", + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'False'}, + {'A': 'B', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE}, + None, + "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.", + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'B', 'C': 'B', 'D': 'C'}, + {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, + None, + "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.", + ), + ( + {('scheduler', 'schedule_after_task_execution'): 'True'}, + {'A': 'C', 'B': 'C'}, + {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE}, + {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED}, + None, + "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.", + ), + ] + ) + def test_fast_follow( + self, conf, dependencies, init_state, first_run_state, second_run_state, error_message + ): + # pylint: disable=too-many-locals + with conf_vars(conf): + session = settings.Session() + + dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE) + + dag_model = DagModel( + dag_id=dag.dag_id, + next_dagrun=dag.start_date, + is_active=True, + ) + session.add(dag_model) + session.flush() + + python_callable = lambda: True + with dag: + task_a = PythonOperator(task_id='A', python_callable=python_callable) + task_b = PythonOperator(task_id='B', python_callable=python_callable) + task_c = PythonOperator(task_id='C', python_callable=python_callable) + if 'D' in init_state: + task_d = PythonOperator(task_id='D', python_callable=python_callable) + for upstream, downstream in dependencies.items(): + dag.set_dependency(upstream, downstream) + + 
scheduler_job = SchedulerJob(subdir=os.devnull) + scheduler_job.dagbag.bag_dag(dag, root_dag=dag) + + dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING) + + task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A']) + + task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B']) + + task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C']) + + if 'D' in init_state: + task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D']) + session.merge(task_instance_d) + + session.merge(task_instance_a) + session.merge(task_instance_b) + session.merge(task_instance_c) + session.flush() + + job1 = LocalTaskJob( + task_instance=task_instance_a, ignore_ti_state=True, executor=SequentialExecutor() + ) + job1.task_runner = StandardTaskRunner(job1) + + job2 = LocalTaskJob( + task_instance=task_instance_b, ignore_ti_state=True, executor=SequentialExecutor() + ) + job2.task_runner = StandardTaskRunner(job2) + + settings.engine.dispose() + job1.run() + self.validate_ti_states(dag_run, first_run_state, error_message) + if second_run_state: + job2.run() + self.validate_ti_states(dag_run, second_run_state, error_message) + if scheduler_job.processor_agent: + scheduler_job.processor_agent.end() + @pytest.fixture() def clean_db_helper(): yield - clear_db_jobs() - clear_db_runs() + db.clear_db_jobs() + db.clear_db_runs() @pytest.mark.usefixtures("clean_db_helper") diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index cd8d301cf3f79..9448ae291d9bc 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -38,10 +38,8 @@ AirflowSensorTimeout, AirflowSkipException, ) -from airflow.jobs.scheduler_job import SchedulerJob from airflow.models import ( DAG, - DagModel, DagRun, Pool, RenderedTaskInstanceFields, @@ -1916,107 +1914,6 @@ def test_get_rendered_k8s_spec(self): with create_session() as session: session.query(RenderedTaskInstanceFields).delete() - def validate_ti_states(self, dag_run, ti_state_mapping, error_message): - for task_id, expected_state in ti_state_mapping.items(): - task_instance = dag_run.get_task_instance(task_id=task_id) - assert task_instance.state == expected_state, error_message - - @parameterized.expand( - [ - ( - {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'B', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED}, - "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. 
Same for B and C.", - ), - ( - {('scheduler', 'schedule_after_task_execution'): 'False'}, - {'A': 'B', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE}, - None, - "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.", - ), - ( - {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'B', 'C': 'B', 'D': 'C'}, - {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, - {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE}, - None, - "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.", - ), - ( - {('scheduler', 'schedule_after_task_execution'): 'True'}, - {'A': 'C', 'B': 'C'}, - {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE}, - {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED}, - None, - "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.", - ), - ] - ) - def test_fast_follow( - self, conf, dependencies, init_state, first_run_state, second_run_state, error_message - ): - with conf_vars(conf): - session = settings.Session() - - dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE) - - dag_model = DagModel( - dag_id=dag.dag_id, - next_dagrun=dag.start_date, - is_active=True, - ) - session.add(dag_model) - session.flush() - - python_callable = lambda: True - with dag: - task_a = PythonOperator(task_id='A', python_callable=python_callable) - task_b = PythonOperator(task_id='B', python_callable=python_callable) - task_c = PythonOperator(task_id='C', python_callable=python_callable) - if 'D' in init_state: - task_d = PythonOperator(task_id='D', python_callable=python_callable) - for upstream, downstream in dependencies.items(): - dag.set_dependency(upstream, downstream) - - scheduler_job = SchedulerJob(subdir=os.devnull) - scheduler_job.dagbag.bag_dag(dag, root_dag=dag) - - dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING) - - task_instance_a = dag_run.get_task_instance(task_id=task_a.task_id) - task_instance_a.task = task_a - task_instance_a.set_state(init_state['A']) - - task_instance_b = dag_run.get_task_instance(task_id=task_b.task_id) - task_instance_b.task = task_b - task_instance_b.set_state(init_state['B']) - - task_instance_c = dag_run.get_task_instance(task_id=task_c.task_id) - task_instance_c.task = task_c - task_instance_c.set_state(init_state['C']) - - if 'D' in init_state: - task_instance_d = dag_run.get_task_instance(task_id=task_d.task_id) - task_instance_d.task = task_d - task_instance_d.state = init_state['D'] - - session.commit() - task_instance_a.run() - - self.validate_ti_states(dag_run, first_run_state, error_message) - - if second_run_state: - scheduler_job._critical_section_execute_task_instances(session=session) - task_instance_b.run() - self.validate_ti_states(dag_run, second_run_state, error_message) - if scheduler_job.processor_agent: - scheduler_job.processor_agent.end() - def test_set_state_up_for_retry(self): dag = DAG('dag', start_date=DEFAULT_DATE) op1 = DummyOperator(task_id='op_1', owner='test', dag=dag)
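The core of the change is the fast-follow ("mini scheduler") pass, which now runs from LocalTaskJob.handle_task_exit, i.e. only once the task runner's return code is already known; that ordering is what closes the heartbeat race described in the commit message. Below is a condensed, illustrative sketch of that pass. It mirrors _run_mini_scheduler_on_child_tasks from the diff but leaves out the row lock and the OperationalError fallback; the helper name fast_follow_sketch and the use of TaskInstance.get_dagrun to fetch the run are choices made here for brevity, not code from the patch.

    # Condensed sketch of the fast-follow pass (illustrative; the real logic
    # lives in LocalTaskJob._run_mini_scheduler_on_child_tasks in the diff).
    def fast_follow_sketch(task_instance, session):
        task = task_instance.task

        # Look only at the just-finished task's direct downstream tasks, plus
        # their direct upstreams so TriggerRuleDep has enough context.
        partial_dag = task.dag.partial_subset(
            task.downstream_task_ids,
            include_downstream=False,
            include_upstream=False,
            include_direct_upstream=True,
        )

        dag_run = task_instance.get_dagrun(session=session)
        dag_run.dag = partial_dag
        info = dag_run.task_instance_scheduling_decisions(session)

        # Tasks pulled in only as upstream context must not be scheduled here;
        # only genuine downstream tasks of the finished task are eligible.
        skippable = {
            task_id for task_id in partial_dag.task_ids
            if task_id not in task.downstream_task_ids
        }
        schedulable = [ti for ti in info.schedulable_tis if ti.task_id not in skippable]
        for ti in schedulable:
            if not hasattr(ti, "task"):
                ti.task = task.dag.get_task(ti.task_id)
        return dag_run.schedule_tis(schedulable)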
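The new test_fast_follow cases in tests/jobs/test_local_task_job.py drive a linear A -> B -> C chain of PythonOperators. A minimal stand-alone DAG of the same shape is sketched below; the dag_id and start_date are arbitrary illustration values. With schedule_after_task_execution left at its default of True, the LocalTaskJob that just finished A schedules B itself instead of waiting for the next scheduler loop.

    # Minimal DAG mirroring the A -> B -> C chain exercised by test_fast_follow.
    from datetime import datetime

    from airflow.models.dag import DAG
    from airflow.operators.python import PythonOperator

    with DAG("fast_follow_example", start_date=datetime(2021, 1, 1), schedule_interval=None) as dag:
        noop = lambda: True  # same trivial callable the tests use
        task_a = PythonOperator(task_id="A", python_callable=noop)
        task_b = PythonOperator(task_id="B", python_callable=noop)
        task_c = PythonOperator(task_id="C", python_callable=noop)
        task_a >> task_b >> task_c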
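The pass stays guarded by conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True), so it can be switched off exactly as the tests do with the conf_vars helper; outside of tests the equivalent is setting schedule_after_task_execution = False under [scheduler] in airflow.cfg. A minimal usage sketch:

    # Turning fast follow off for a block of test code, using the same
    # conf_vars helper the new tests import.
    from tests.test_utils.config import conf_vars

    with conf_vars({("scheduler", "schedule_after_task_execution"): "False"}):
        ...  # any LocalTaskJob.handle_task_exit in here skips the mini scheduler pass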