diff --git a/airflow/jobs/local_task_job.py b/airflow/jobs/local_task_job.py
index 9e68450cf2700..35d0422e7d4b1 100644
--- a/airflow/jobs/local_task_job.py
+++ b/airflow/jobs/local_task_job.py
@@ -16,19 +16,23 @@
 # specific language governing permissions and limitations
 # under the License.
 #
-
 import signal
 from typing import Optional
 
+from sqlalchemy.exc import OperationalError
+
 from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.jobs.base_job import BaseJob
+from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
+from airflow.sentry import Sentry
 from airflow.stats import Stats
 from airflow.task.task_runner import get_task_runner
 from airflow.utils import timezone
 from airflow.utils.net import get_hostname
 from airflow.utils.session import provide_session
+from airflow.utils.sqlalchemy import with_row_locks
 from airflow.utils.state import State
 
 
@@ -160,6 +164,9 @@ def handle_task_exit(self, return_code: int) -> None:
         if self.task_instance.state != State.SUCCESS:
             error = self.task_runner.deserialize_run_error()
             self.task_instance._run_finished_callback(error=error)  # pylint: disable=protected-access
+        if not self.task_instance.test_mode:
+            if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
+                self._run_mini_scheduler_on_child_tasks()
 
     def on_kill(self):
         self.task_runner.terminate()
@@ -206,3 +213,54 @@ def heartbeat_callback(self, session=None):
                 error = self.task_runner.deserialize_run_error() or "task marked as failed externally"
             ti._run_finished_callback(error=error)  # pylint: disable=protected-access
             self.terminating = True
+
+    @provide_session
+    @Sentry.enrich_errors
+    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
+        try:
+            # Re-select the row with a lock
+            dag_run = with_row_locks(
+                session.query(DagRun).filter_by(
+                    dag_id=self.dag_id,
+                    execution_date=self.task_instance.execution_date,
+                ),
+                session=session,
+            ).one()
+
+            # Get a partial dag with just the specific tasks we want to
+            # examine. In order for dep checks to work correctly, we
+            # include ourself (so TriggerRuleDep can check the state of the
+            # task we just executed)
+            task = self.task_instance.task
+
+            partial_dag = task.dag.partial_subset(
+                task.downstream_task_ids,
+                include_downstream=False,
+                include_upstream=False,
+                include_direct_upstream=True,
+            )
+
+            dag_run.dag = partial_dag
+            info = dag_run.task_instance_scheduling_decisions(session)
+
+            skippable_task_ids = {
+                task_id for task_id in partial_dag.task_ids if task_id not in task.downstream_task_ids
+            }
+
+            schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
+            for schedulable_ti in schedulable_tis:
+                if not hasattr(schedulable_ti, "task"):
+                    schedulable_ti.task = task.dag.get_task(schedulable_ti.task_id)
+
+            num = dag_run.schedule_tis(schedulable_tis)
+            self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
+
+            session.commit()
+        except OperationalError as e:
+            # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
+            self.log.info(
+                "Skipping mini scheduling run due to exception: %s",
+                e.statement,
+                exc_info=True,
+            )
+            session.rollback()
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 6e0c011d3dbbb..4fd72e7c44e71 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -35,7 +35,6 @@
 import pendulum
 from jinja2 import TemplateAssertionError, UndefinedError
 from sqlalchemy import Column, Float, Index, Integer, PickleType, String, and_, func, or_
-from sqlalchemy.exc import OperationalError
 from sqlalchemy.orm import reconstructor, relationship
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.elements import BooleanClauseList
@@ -70,7 +69,7 @@
 from airflow.utils.operator_helpers import context_to_airflow_vars
 from airflow.utils.platform import getuser
 from airflow.utils.session import provide_session
-from airflow.utils.sqlalchemy import UtcDateTime, with_row_locks
+from airflow.utils.sqlalchemy import UtcDateTime
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
 
@@ -1200,62 +1199,6 @@ def _run_raw_task(
 
         session.commit()
 
-        if not test_mode:
-            self._run_mini_scheduler_on_child_tasks(session)
-
-    @provide_session
-    @Sentry.enrich_errors
-    def _run_mini_scheduler_on_child_tasks(self, session=None) -> None:
-        if conf.getboolean('scheduler', 'schedule_after_task_execution', fallback=True):
-            from airflow.models.dagrun import DagRun  # Avoid circular import
-
-            try:
-                # Re-select the row with a lock
-                dag_run = with_row_locks(
-                    session.query(DagRun).filter_by(
-                        dag_id=self.dag_id,
-                        execution_date=self.execution_date,
-                    ),
-                    session=session,
-                ).one()
-
-                # Get a partial dag with just the specific tasks we want to
-                # examine. In order for dep checks to work correctly, we
-                # include ourself (so TriggerRuleDep can check the state of the
-                # task we just executed)
-                partial_dag = self.task.dag.partial_subset(
-                    self.task.downstream_task_ids,
-                    include_downstream=False,
-                    include_upstream=False,
-                    include_direct_upstream=True,
-                )
-
-                dag_run.dag = partial_dag
-                info = dag_run.task_instance_scheduling_decisions(session)
-
-                skippable_task_ids = {
-                    task_id
-                    for task_id in partial_dag.task_ids
-                    if task_id not in self.task.downstream_task_ids
-                }
-
-                schedulable_tis = [ti for ti in info.schedulable_tis if ti.task_id not in skippable_task_ids]
-                for schedulable_ti in schedulable_tis:
-                    if not hasattr(schedulable_ti, "task"):
-                        schedulable_ti.task = self.task.dag.get_task(schedulable_ti.task_id)
-
-                num = dag_run.schedule_tis(schedulable_tis)
-                self.log.info("%d downstream tasks scheduled from follow-on schedule check", num)
-
-                session.commit()
-            except OperationalError as e:
-                # Any kind of DB error here is _non fatal_ as this block is just an optimisation.
-                self.log.info(
-                    f"Skipping mini scheduling run due to exception: {e.statement}",
-                    exc_info=True,
-                )
-                session.rollback()
-
     def _prepare_and_execute_task_with_callbacks(self, context, task):
         """Prepare Task for Execution"""
         from airflow.models.renderedtifields import RenderedTaskInstanceFields
@@ -1408,6 +1351,7 @@ def run(  # pylint: disable=too-many-arguments
             session=session,
         )
         if not res:
+            self.log.info("CHECK AND CHANGE")
             return
 
         try:
diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py
index f50ddbc5e4f17..2b93e6d5a96d6 100644
--- a/tests/cli/commands/test_task_command.py
+++ b/tests/cli/commands/test_task_command.py
@@ -71,8 +71,7 @@ def test_cli_list_tasks(self):
         args = self.parser.parse_args(['tasks', 'list', 'example_bash_operator', '--tree'])
         task_command.task_list(args)
 
-    @mock.patch("airflow.models.taskinstance.TaskInstance._run_mini_scheduler_on_child_tasks")
-    def test_test(self, mock_run_mini_scheduler):
+    def test_test(self):
         """Test the `airflow test` command"""
         args = self.parser.parse_args(
             ["tasks", "test", "example_python_operator", 'print_the_context', '2018-01-01']
@@ -81,7 +80,6 @@ def test_test(self, mock_run_mini_scheduler):
         with redirect_stdout(io.StringIO()) as stdout:
             task_command.task_test(args)
 
-        mock_run_mini_scheduler.assert_not_called()
 
         # Check that prints, and log messages, are shown
         assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()
diff --git a/tests/jobs/test_local_task_job.py b/tests/jobs/test_local_task_job.py
index 9047f8aa18938..6639775156853 100644
--- a/tests/jobs/test_local_task_job.py
+++ b/tests/jobs/test_local_task_job.py
@@ -33,7 +33,8 @@
 from airflow.exceptions import AirflowException, AirflowFailException
 from airflow.executors.sequential_executor import SequentialExecutor
 from airflow.jobs.local_task_job import LocalTaskJob
-from airflow.models.dag import DAG
+from airflow.jobs.scheduler_job import SchedulerJob
+from airflow.models.dag import DAG, DagModel
 from airflow.models.dagbag import DagBag
 from airflow.models.taskinstance import TaskInstance
 from airflow.operators.dummy import DummyOperator
@@ -44,8 +45,9 @@
 from airflow.utils.session import create_session
 from airflow.utils.state import State
 from airflow.utils.timeout import timeout
+from tests.test_utils import db
 from tests.test_utils.asserts import assert_queries_count
-from tests.test_utils.db import clear_db_jobs, clear_db_runs
+from tests.test_utils.config import conf_vars
 from tests.test_utils.mock_executor import MockExecutor
 
 DEFAULT_DATE = timezone.datetime(2016, 1, 1)
@@ -54,15 +56,25 @@
 
 class TestLocalTaskJob(unittest.TestCase):
     def setUp(self):
-        clear_db_jobs()
-        clear_db_runs()
+        db.clear_db_dags()
+        db.clear_db_jobs()
+        db.clear_db_runs()
+        db.clear_db_task_fail()
         patcher = patch('airflow.jobs.base_job.sleep')
         self.addCleanup(patcher.stop)
         self.mock_base_job_sleep = patcher.start()
 
     def tearDown(self) -> None:
-        clear_db_jobs()
-        clear_db_runs()
+        db.clear_db_dags()
+        db.clear_db_jobs()
+        db.clear_db_runs()
+        db.clear_db_task_fail()
+
+    def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
+        for task_id, expected_state in ti_state_mapping.items():
+            task_instance = dag_run.get_task_instance(task_id=task_id)
+            task_instance.refresh_from_db()
+            assert task_instance.state == expected_state, error_message
 
     def test_localtaskjob_essential_attr(self):
@@ -563,20 +575,122 @@ def task_function(ti):
             if ti.state == State.RUNNING and ti.pid is not None:
                 break
             time.sleep(0.2)
-        assert ti.state == State.RUNNING
         assert ti.pid is not None
+        assert ti.state == State.RUNNING
         os.kill(ti.pid, signal_type)
         process.join(timeout=10)
         assert failure_callback_called.value == 1
         assert task_terminated_externally.value == 1
         assert not process.is_alive()
 
+    @parameterized.expand(
+        [
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'True'},
+                {'A': 'B', 'B': 'C'},
+                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
+                "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
+            ),
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'False'},
+                {'A': 'B', 'B': 'C'},
+                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
+                None,
+                "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
+            ),
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'True'},
+                {'A': 'B', 'C': 'B', 'D': 'C'},
+                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
+                None,
+                "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
+            ),
+            (
+                {('scheduler', 'schedule_after_task_execution'): 'True'},
+                {'A': 'C', 'B': 'C'},
+                {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
+                {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
+                None,
+                "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
+            ),
+        ]
+    )
+    def test_fast_follow(
+        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
+    ):
+        # pylint: disable=too-many-locals
+        with conf_vars(conf):
+            session = settings.Session()
+
+            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
+
+            dag_model = DagModel(
+                dag_id=dag.dag_id,
+                next_dagrun=dag.start_date,
+                is_active=True,
+            )
+            session.add(dag_model)
+            session.flush()
+
+            python_callable = lambda: True
+            with dag:
+                task_a = PythonOperator(task_id='A', python_callable=python_callable)
+                task_b = PythonOperator(task_id='B', python_callable=python_callable)
+                task_c = PythonOperator(task_id='C', python_callable=python_callable)
+                if 'D' in init_state:
+                    task_d = PythonOperator(task_id='D', python_callable=python_callable)
+                for upstream, downstream in dependencies.items():
+                    dag.set_dependency(upstream, downstream)
+
+            scheduler_job = SchedulerJob(subdir=os.devnull)
+            scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
+
+            dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
+
+            task_instance_a = TaskInstance(task_a, dag_run.execution_date, init_state['A'])
+
+            task_instance_b = TaskInstance(task_b, dag_run.execution_date, init_state['B'])
+
+            task_instance_c = TaskInstance(task_c, dag_run.execution_date, init_state['C'])
+
+            if 'D' in init_state:
+                task_instance_d = TaskInstance(task_d, dag_run.execution_date, init_state['D'])
+                session.merge(task_instance_d)
+
+            session.merge(task_instance_a)
+            session.merge(task_instance_b)
+            session.merge(task_instance_c)
+            session.flush()
+
+            job1 = LocalTaskJob(
+                task_instance=task_instance_a, ignore_ti_state=True, executor=SequentialExecutor()
+            )
+            job1.task_runner = StandardTaskRunner(job1)
+
+            job2 = LocalTaskJob(
+                task_instance=task_instance_b, ignore_ti_state=True, executor=SequentialExecutor()
+            )
+            job2.task_runner = StandardTaskRunner(job2)
+
+            settings.engine.dispose()
+            job1.run()
+            self.validate_ti_states(dag_run, first_run_state, error_message)
+            if second_run_state:
+                job2.run()
+                self.validate_ti_states(dag_run, second_run_state, error_message)
+            if scheduler_job.processor_agent:
+                scheduler_job.processor_agent.end()
+
 
 @pytest.fixture()
 def clean_db_helper():
     yield
-    clear_db_jobs()
-    clear_db_runs()
+    db.clear_db_jobs()
+    db.clear_db_runs()
 
 
 @pytest.mark.usefixtures("clean_db_helper")
diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py
index cd8d301cf3f79..9448ae291d9bc 100644
--- a/tests/models/test_taskinstance.py
+++ b/tests/models/test_taskinstance.py
@@ -38,10 +38,8 @@
     AirflowSensorTimeout,
     AirflowSkipException,
 )
-from airflow.jobs.scheduler_job import SchedulerJob
 from airflow.models import (
     DAG,
-    DagModel,
     DagRun,
     Pool,
     RenderedTaskInstanceFields,
@@ -1916,107 +1914,6 @@ def test_get_rendered_k8s_spec(self):
         with create_session() as session:
             session.query(RenderedTaskInstanceFields).delete()
 
-    def validate_ti_states(self, dag_run, ti_state_mapping, error_message):
-        for task_id, expected_state in ti_state_mapping.items():
-            task_instance = dag_run.get_task_instance(task_id=task_id)
-            assert task_instance.state == expected_state, error_message
-
-    @parameterized.expand(
-        [
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'True'},
-                {'A': 'B', 'B': 'C'},
-                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.SCHEDULED, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.SUCCESS, 'C': State.SCHEDULED},
-                "A -> B -> C, with fast-follow ON when A runs, B should be QUEUED. Same for B and C.",
-            ),
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'False'},
-                {'A': 'B', 'B': 'C'},
-                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE},
-                None,
-                "A -> B -> C, with fast-follow OFF, when A runs, B shouldn't be QUEUED.",
-            ),
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'True'},
-                {'A': 'B', 'C': 'B', 'D': 'C'},
-                {'A': State.QUEUED, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
-                {'A': State.SUCCESS, 'B': State.NONE, 'C': State.NONE, 'D': State.NONE},
-                None,
-                "D -> C -> B & A -> B, when A runs but C isn't QUEUED yet, B shouldn't be QUEUED.",
-            ),
-            (
-                {('scheduler', 'schedule_after_task_execution'): 'True'},
-                {'A': 'C', 'B': 'C'},
-                {'A': State.QUEUED, 'B': State.FAILED, 'C': State.NONE},
-                {'A': State.SUCCESS, 'B': State.FAILED, 'C': State.UPSTREAM_FAILED},
-                None,
-                "A -> C & B -> C, when A is QUEUED but B has FAILED, C is marked UPSTREAM_FAILED.",
-            ),
-        ]
-    )
-    def test_fast_follow(
-        self, conf, dependencies, init_state, first_run_state, second_run_state, error_message
-    ):
-        with conf_vars(conf):
-            session = settings.Session()
-
-            dag = DAG('test_dagrun_fast_follow', start_date=DEFAULT_DATE)
-
-            dag_model = DagModel(
-                dag_id=dag.dag_id,
-                next_dagrun=dag.start_date,
-                is_active=True,
-            )
-            session.add(dag_model)
-            session.flush()
-
-            python_callable = lambda: True
-            with dag:
-                task_a = PythonOperator(task_id='A', python_callable=python_callable)
-                task_b = PythonOperator(task_id='B', python_callable=python_callable)
-                task_c = PythonOperator(task_id='C', python_callable=python_callable)
-                if 'D' in init_state:
-                    task_d = PythonOperator(task_id='D', python_callable=python_callable)
-                for upstream, downstream in dependencies.items():
-                    dag.set_dependency(upstream, downstream)
-
-            scheduler_job = SchedulerJob(subdir=os.devnull)
-            scheduler_job.dagbag.bag_dag(dag, root_dag=dag)
-
-            dag_run = dag.create_dagrun(run_id='test_dagrun_fast_follow', state=State.RUNNING)
-
-            task_instance_a = dag_run.get_task_instance(task_id=task_a.task_id)
-            task_instance_a.task = task_a
-            task_instance_a.set_state(init_state['A'])
-
-            task_instance_b = dag_run.get_task_instance(task_id=task_b.task_id)
-            task_instance_b.task = task_b
-            task_instance_b.set_state(init_state['B'])
-
-            task_instance_c = dag_run.get_task_instance(task_id=task_c.task_id)
-            task_instance_c.task = task_c
-            task_instance_c.set_state(init_state['C'])
-
-            if 'D' in init_state:
-                task_instance_d = dag_run.get_task_instance(task_id=task_d.task_id)
-                task_instance_d.task = task_d
-                task_instance_d.state = init_state['D']
-
-            session.commit()
-            task_instance_a.run()
-
-            self.validate_ti_states(dag_run, first_run_state, error_message)
-
-            if second_run_state:
-                scheduler_job._critical_section_execute_task_instances(session=session)
-                task_instance_b.run()
-                self.validate_ti_states(dag_run, second_run_state, error_message)
-            if scheduler_job.processor_agent:
-                scheduler_job.processor_agent.end()
-
     def test_set_state_up_for_retry(self):
         dag = DAG('dag', start_date=DEFAULT_DATE)
         op1 = DummyOperator(task_id='op_1', owner='test', dag=dag)