diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 9caa8bb4bda7f..982aa31fd50b7 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -47,7 +47,6 @@ from airflow.utils import cli as cli_utils from airflow.utils.cli import ( get_dag, - get_dag_by_deserialization, get_dag_by_file_location, get_dag_by_pickle, get_dags, @@ -364,14 +363,7 @@ def task_run(args, dag=None): print(f'Loading pickle id: {args.pickle}') dag = get_dag_by_pickle(args.pickle) elif not dag: - if args.local: - try: - dag = get_dag_by_deserialization(args.dag_id) - except AirflowException: - print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file') - dag = get_dag(args.subdir, args.dag_id) - else: - dag = get_dag(args.subdir, args.dag_id) + dag = get_dag(args.subdir, args.dag_id, include_examples=False) else: # Use DAG from parameter pass diff --git a/airflow/task/task_runner/standard_task_runner.py b/airflow/task/task_runner/standard_task_runner.py index 3c13a28df4aaf..27fd11b1b1ff1 100644 --- a/airflow/task/task_runner/standard_task_runner.py +++ b/airflow/task/task_runner/standard_task_runner.py @@ -36,6 +36,7 @@ class StandardTaskRunner(BaseTaskRunner): def __init__(self, local_task_job): super().__init__(local_task_job) self._rc = None + self.dag = local_task_job.task_instance.task.dag def start(self): if CAN_FORK and not self.run_as_user: @@ -64,7 +65,6 @@ def _start_by_fork(self): from airflow import settings from airflow.cli.cli_parser import get_parser from airflow.sentry import Sentry - from airflow.utils.cli import get_dag # Force a new SQLAlchemy session. We can't share open DB handles # between process. The cli code will re-create this as part of its @@ -92,10 +92,8 @@ def _start_by_fork(self): dag_id=self._task_instance.dag_id, task_id=self._task_instance.task_id, ): - # parse dag file since `airflow tasks run --local` does not parse dag file - dag = get_dag(args.subdir, args.dag_id) - args.func(args, dag=dag) - return_code = 0 + args.func(args, dag=self.dag) + return_code = 0 except Exception as exc: return_code = 1 diff --git a/airflow/utils/cli.py b/airflow/utils/cli.py index 87313f46f5561..522bf963e29f0 100644 --- a/airflow/utils/cli.py +++ b/airflow/utils/cli.py @@ -33,6 +33,7 @@ from typing import TYPE_CHECKING, Callable, TypeVar, cast from airflow import settings +from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.utils import cli_action_loggers from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler @@ -205,7 +206,9 @@ def _search_for_dag_file(val: str | None) -> str | None: return None -def get_dag(subdir: str | None, dag_id: str) -> DAG: +def get_dag( + subdir: str | None, dag_id: str, include_examples=conf.getboolean('core', 'LOAD_EXAMPLES') +) -> DAG: """ Returns DAG of a given dag_id @@ -216,11 +219,11 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG: from airflow.models import DagBag first_path = process_subdir(subdir) - dagbag = DagBag(first_path) + dagbag = DagBag(first_path, include_examples=include_examples) if dag_id not in dagbag.dags: fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path) - dagbag = DagBag(dag_folder=fallback_path) + dagbag = DagBag(dag_folder=fallback_path, include_examples=include_examples) if dag_id not in dagbag.dags: raise AirflowException( f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse." @@ -228,16 +231,6 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG: return dagbag.dags[dag_id] -def get_dag_by_deserialization(dag_id: str) -> DAG: - from airflow.models.serialized_dag import SerializedDagModel - - dag_model = SerializedDagModel.get(dag_id) - if dag_model is None: - raise AirflowException(f"Serialized DAG: {dag_id} could not be found") - - return dag_model.dag - - def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False): """Returns DAG(s) matching a given regex or dag_id""" from airflow.models import DagBag diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 03b9259f8db28..802140755f1fe 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -159,38 +159,6 @@ def test_test_filters_secrets(self, capsys): task_command.task_test(args) assert capsys.readouterr().out.endswith(f"{not_password}\n") - @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization") - @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") - def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserialization): - """ - Test using serialized dag for local task_run - """ - task_id = self.dag.task_ids[0] - args = [ - 'tasks', - 'run', - '--ignore-all-dependencies', - '--local', - self.dag_id, - task_id, - self.run_id, - ] - mock_get_dag_by_deserialization.return_value = SerializedDagModel.get(self.dag_id).dag - - task_command.task_run(self.parser.parse_args(args)) - mock_local_job.assert_called_once_with( - task_instance=mock.ANY, - mark_success=False, - ignore_all_deps=True, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pickle_id=None, - pool=None, - external_executor_id=None, - ) - mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id) - def test_cli_test_different_path(self, session): """ When thedag processor has a different dags folder @@ -265,38 +233,6 @@ def test_cli_test_different_path(self, session): # verify that the file was in different location when run assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix() - @mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization") - @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") - def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization): - """ - Fallback to parse dag_file when serialized dag does not exist in the db - """ - task_id = self.dag.task_ids[0] - args = [ - 'tasks', - 'run', - '--ignore-all-dependencies', - '--local', - self.dag_id, - task_id, - self.run_id, - ] - mock_get_dag_by_deserialization.side_effect = mock.Mock(side_effect=AirflowException('Not found')) - - task_command.task_run(self.parser.parse_args(args)) - mock_local_job.assert_called_once_with( - task_instance=mock.ANY, - mark_success=False, - ignore_all_deps=True, - ignore_depends_on_past=False, - ignore_task_deps=False, - ignore_ti_state=False, - pickle_id=None, - pool=None, - external_executor_id=None, - ) - mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id) - @mock.patch("airflow.cli.commands.task_command.LocalTaskJob") def test_run_with_existing_dag_run_id(self, mock_local_job): """