Skip to content

Commit

Permalink
Remove DAG parsing from StandardTaskRunner
Browse files Browse the repository at this point in the history
This makes the starting of StandardTaskRunner faster as the parsing of DAG will now be done once at task_run.
Also removed parsing of example dags when running a task
  • Loading branch information
ephraimbuddy committed Sep 29, 2022
1 parent bec80af commit c04efac
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 91 deletions.
10 changes: 1 addition & 9 deletions airflow/cli/commands/task_command.py
Expand Up @@ -47,7 +47,6 @@
from airflow.utils import cli as cli_utils
from airflow.utils.cli import (
get_dag,
get_dag_by_deserialization,
get_dag_by_file_location,
get_dag_by_pickle,
get_dags,
Expand Down Expand Up @@ -364,14 +363,7 @@ def task_run(args, dag=None):
print(f'Loading pickle id: {args.pickle}')
dag = get_dag_by_pickle(args.pickle)
elif not dag:
if args.local:
try:
dag = get_dag_by_deserialization(args.dag_id)
except AirflowException:
print(f'DAG {args.dag_id} does not exist in the database, trying to parse the dag_file')
dag = get_dag(args.subdir, args.dag_id)
else:
dag = get_dag(args.subdir, args.dag_id)
dag = get_dag(args.subdir, args.dag_id, include_examples=False)
else:
# Use DAG from parameter
pass
Expand Down
8 changes: 3 additions & 5 deletions airflow/task/task_runner/standard_task_runner.py
Expand Up @@ -36,6 +36,7 @@ class StandardTaskRunner(BaseTaskRunner):
def __init__(self, local_task_job):
super().__init__(local_task_job)
self._rc = None
self.dag = local_task_job.task_instance.task.dag

def start(self):
if CAN_FORK and not self.run_as_user:
Expand Down Expand Up @@ -64,7 +65,6 @@ def _start_by_fork(self):
from airflow import settings
from airflow.cli.cli_parser import get_parser
from airflow.sentry import Sentry
from airflow.utils.cli import get_dag

# Force a new SQLAlchemy session. We can't share open DB handles
# between process. The cli code will re-create this as part of its
Expand Down Expand Up @@ -92,10 +92,8 @@ def _start_by_fork(self):
dag_id=self._task_instance.dag_id,
task_id=self._task_instance.task_id,
):
# parse dag file since `airflow tasks run --local` does not parse dag file
dag = get_dag(args.subdir, args.dag_id)
args.func(args, dag=dag)
return_code = 0
args.func(args, dag=self.dag)
return_code = 0
except Exception as exc:
return_code = 1

Expand Down
19 changes: 6 additions & 13 deletions airflow/utils/cli.py
Expand Up @@ -33,6 +33,7 @@
from typing import TYPE_CHECKING, Callable, TypeVar, cast

from airflow import settings
from airflow.configuration import conf
from airflow.exceptions import AirflowException
from airflow.utils import cli_action_loggers
from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler
Expand Down Expand Up @@ -205,7 +206,9 @@ def _search_for_dag_file(val: str | None) -> str | None:
return None


def get_dag(subdir: str | None, dag_id: str) -> DAG:
def get_dag(
subdir: str | None, dag_id: str, include_examples=conf.getboolean('core', 'LOAD_EXAMPLES')
) -> DAG:
"""
Returns DAG of a given dag_id
Expand All @@ -216,28 +219,18 @@ def get_dag(subdir: str | None, dag_id: str) -> DAG:
from airflow.models import DagBag

first_path = process_subdir(subdir)
dagbag = DagBag(first_path)
dagbag = DagBag(first_path, include_examples=include_examples)
if dag_id not in dagbag.dags:
fallback_path = _search_for_dag_file(subdir) or settings.DAGS_FOLDER
logger.warning("Dag %r not found in path %s; trying path %s", dag_id, first_path, fallback_path)
dagbag = DagBag(dag_folder=fallback_path)
dagbag = DagBag(dag_folder=fallback_path, include_examples=include_examples)
if dag_id not in dagbag.dags:
raise AirflowException(
f"Dag {dag_id!r} could not be found; either it does not exist or it failed to parse."
)
return dagbag.dags[dag_id]


def get_dag_by_deserialization(dag_id: str) -> DAG:
from airflow.models.serialized_dag import SerializedDagModel

dag_model = SerializedDagModel.get(dag_id)
if dag_model is None:
raise AirflowException(f"Serialized DAG: {dag_id} could not be found")

return dag_model.dag


def get_dags(subdir: str | None, dag_id: str, use_regex: bool = False):
"""Returns DAG(s) matching a given regex or dag_id"""
from airflow.models import DagBag
Expand Down
64 changes: 0 additions & 64 deletions tests/cli/commands/test_task_command.py
Expand Up @@ -159,38 +159,6 @@ def test_test_filters_secrets(self, capsys):
task_command.task_test(args)
assert capsys.readouterr().out.endswith(f"{not_password}\n")

@mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_get_serialized_dag(self, mock_local_job, mock_get_dag_by_deserialization):
"""
Test using serialized dag for local task_run
"""
task_id = self.dag.task_ids[0]
args = [
'tasks',
'run',
'--ignore-all-dependencies',
'--local',
self.dag_id,
task_id,
self.run_id,
]
mock_get_dag_by_deserialization.return_value = SerializedDagModel.get(self.dag_id).dag

task_command.task_run(self.parser.parse_args(args))
mock_local_job.assert_called_once_with(
task_instance=mock.ANY,
mark_success=False,
ignore_all_deps=True,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pickle_id=None,
pool=None,
external_executor_id=None,
)
mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)

def test_cli_test_different_path(self, session):
"""
When thedag processor has a different dags folder
Expand Down Expand Up @@ -265,38 +233,6 @@ def test_cli_test_different_path(self, session):
# verify that the file was in different location when run
assert ti.xcom_pull(ti.task_id) == new_file_path.as_posix()

@mock.patch("airflow.cli.commands.task_command.get_dag_by_deserialization")
@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_get_serialized_dag_fallback(self, mock_local_job, mock_get_dag_by_deserialization):
"""
Fallback to parse dag_file when serialized dag does not exist in the db
"""
task_id = self.dag.task_ids[0]
args = [
'tasks',
'run',
'--ignore-all-dependencies',
'--local',
self.dag_id,
task_id,
self.run_id,
]
mock_get_dag_by_deserialization.side_effect = mock.Mock(side_effect=AirflowException('Not found'))

task_command.task_run(self.parser.parse_args(args))
mock_local_job.assert_called_once_with(
task_instance=mock.ANY,
mark_success=False,
ignore_all_deps=True,
ignore_depends_on_past=False,
ignore_task_deps=False,
ignore_ti_state=False,
pickle_id=None,
pool=None,
external_executor_id=None,
)
mock_get_dag_by_deserialization.assert_called_once_with(self.dag_id)

@mock.patch("airflow.cli.commands.task_command.LocalTaskJob")
def test_run_with_existing_dag_run_id(self, mock_local_job):
"""
Expand Down

0 comments on commit c04efac

Please sign in to comment.