Skip to content

Commit

Permalink
Make execution_date_or_run_id optional in tasks test command (#26114)
Browse files Browse the repository at this point in the history
  • Loading branch information
dstandish committed Sep 7, 2022
1 parent 5e24323 commit 243b3d7
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 36 deletions.
7 changes: 6 additions & 1 deletion airflow/cli/cli_parser.py
Expand Up @@ -183,6 +183,11 @@ def string_lower_type(val):
ARG_EXECUTION_DATE_OR_RUN_ID = Arg(
('execution_date_or_run_id',), help="The execution_date of the DAG or run_id of the DAGRun"
)
ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL = Arg(
('execution_date_or_run_id',),
nargs='?',
help="The execution_date of the DAG or run_id of the DAGRun (optional)",
)
ARG_TASK_REGEX = Arg(
("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)"
)
Expand Down Expand Up @@ -1296,7 +1301,7 @@ class GroupCommand(NamedTuple):
args=(
ARG_DAG_ID,
ARG_TASK_ID,
ARG_EXECUTION_DATE_OR_RUN_ID,
ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL,
ARG_SUBDIR,
ARG_DRY_RUN,
ARG_TASK_PARAMS,
Expand Down
68 changes: 38 additions & 30 deletions airflow/cli/commands/task_command.py
Expand Up @@ -22,7 +22,7 @@
import logging
import os
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress
from typing import Dict, Generator, List, Optional, Tuple, Union

from pendulum.parsing.exceptions import ParserError
Expand Down Expand Up @@ -75,8 +75,8 @@ def _generate_temporary_run_id() -> str:
def _get_dag_run(
*,
dag: DAG,
exec_date_or_run_id: str,
create_if_necessary: CreateIfNecessary,
exec_date_or_run_id: Optional[str] = None,
session: Session,
) -> Tuple[DagRun, bool]:
"""Try to retrieve a DAG run from a string representing either a run ID or logical date.
Expand All @@ -92,33 +92,35 @@ def _get_dag_run(
the logical date; otherwise use it as a run ID and set the logical date
to the current time.
"""
dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
if dag_run:
return dag_run, False

try:
execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id)
except (ParserError, TypeError):
execution_date = None

try:
dag_run = (
session.query(DagRun)
.filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
.one()
)
except NoResultFound:
if not create_if_necessary:
raise DagRunNotFound(
f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found"
) from None
else:
return dag_run, False
if not exec_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
execution_date: Optional[datetime.datetime] = None
if exec_date_or_run_id:
dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session)
if dag_run:
return dag_run, False
with suppress(ParserError, TypeError):
execution_date = timezone.parse(exec_date_or_run_id)
try:
dag_run = (
session.query(DagRun)
.filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date)
.one()
)
except NoResultFound:
if not create_if_necessary:
raise DagRunNotFound(
f"DagRun for {dag.dag_id} with run_id or execution_date "
f"of {exec_date_or_run_id!r} not found"
) from None
else:
return dag_run, False

if execution_date is not None:
dag_run_execution_date = execution_date
else:
dag_run_execution_date = timezone.utcnow()

if create_if_necessary == "memory":
dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date)
return dag_run, True
Expand All @@ -136,14 +138,16 @@ def _get_dag_run(
@provide_session
def _get_ti(
task: BaseOperator,
exec_date_or_run_id: str,
map_index: int,
*,
exec_date_or_run_id: Optional[str] = None,
pool: Optional[str] = None,
create_if_necessary: CreateIfNecessary = False,
session: Session = NEW_SESSION,
) -> Tuple[TaskInstance, bool]:
"""Get the task instance through DagRun.run_id, if that fails, get the TI the old way"""
if not exec_date_or_run_id and not create_if_necessary:
raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.")
if task.is_mapped:
if map_index < 0:
raise RuntimeError("No map_index passed to mapped task")
Expand Down Expand Up @@ -370,7 +374,7 @@ def task_run(args, dag=None):
# Use DAG from parameter
pass
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool)
ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool)
ti.init_run_context(raw=args.raw)

hostname = get_hostname()
Expand Down Expand Up @@ -398,7 +402,7 @@ def task_failed_deps(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)

dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS)
failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context))
Expand All @@ -421,7 +425,7 @@ def task_state(args):
"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index)
ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id)
print(ti.current_state())


Expand Down Expand Up @@ -544,7 +548,9 @@ def task_test(args, dag=None):
if task.params:
task.params.validate()

ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db")
ti, dr_created = _get_ti(
task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db"
)

try:
with redirect_stdout(RedactedIO()):
Expand Down Expand Up @@ -574,7 +580,9 @@ def task_render(args):
"""Renders and displays templated fields for a given task"""
dag = get_dag(args.subdir, args.dag_id)
task = dag.get_task(task_id=args.task_id)
ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory")
ti, _ = _get_ti(
task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory"
)
ti.render_templates()
for attr in task.__class__.template_fields:
print(
Expand Down
4 changes: 2 additions & 2 deletions docs/apache-airflow/tutorial/fundamentals.rst
Expand Up @@ -326,7 +326,7 @@ its data interval.

.. code-block:: bash
# command layout: command subcommand dag_id task_id date
# command layout: command subcommand [dag_id] [task_id] [(optional) date]
# testing print_date
airflow tasks test tutorial print_date 2015-06-01
Expand All @@ -350,7 +350,7 @@ their log to stdout (on screen), does not bother with dependencies, and
does not communicate state (running, success, failed, ...) to the database.
It simply allows testing a single task instance.

The same applies to ``airflow dags test [dag_id] [logical_date]``, but on a DAG
The same applies to ``airflow dags test``, but on a DAG
level. It performs a single DAG run of the given DAG id. While it does take task
dependencies into account, no state is registered in the database. It is
convenient for locally testing a full run of your DAG, given that e.g. if one of
Expand Down
23 changes: 20 additions & 3 deletions tests/cli/commands/test_task_command.py
Expand Up @@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import io
import json
import logging
Expand All @@ -26,6 +27,7 @@
from pathlib import Path
from unittest import mock

import pendulum
import pytest
from parameterized import parameterized

Expand Down Expand Up @@ -103,6 +105,21 @@ def test_test(self):
# Check that prints, and log messages, are shown
assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue()

@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
@mock.patch('airflow.utils.timezone.utcnow')
def test_test_no_execution_date(self, mock_utcnow):
"""Test the `airflow test` command"""
now = pendulum.now('UTC')
mock_utcnow.return_value = now
ds = now.strftime("%Y%m%d")
args = self.parser.parse_args(["tasks", "test", "example_python_operator", 'print_the_context'])

with redirect_stdout(io.StringIO()) as stdout:
task_command.task_test(args)

# Check that prints, and log messages, are shown
assert f"'example_python_operator__print_the_context__{ds}'" in stdout.getvalue()

@pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning")
def test_test_with_existing_dag_run(self, caplog):
"""Test the `airflow test` command"""
Expand Down Expand Up @@ -255,9 +272,9 @@ def test_cli_test_with_params(self):
'test',
'example_passing_params_via_test_command',
'run_this',
DEFAULT_DATE.isoformat(),
'--task-params',
'{"foo":"bar"}',
DEFAULT_DATE.isoformat(),
]
)
)
Expand All @@ -268,9 +285,9 @@ def test_cli_test_with_params(self):
'test',
'example_passing_params_via_test_command',
'also_run_this',
DEFAULT_DATE.isoformat(),
'--task-params',
'{"foo":"bar"}',
DEFAULT_DATE.isoformat(),
]
)
)
Expand All @@ -284,9 +301,9 @@ def test_cli_test_with_env_vars(self):
'test',
'example_passing_params_via_test_command',
'env_var_test_task',
DEFAULT_DATE.isoformat(),
'--env-vars',
'{"foo":"bar"}',
DEFAULT_DATE.isoformat(),
]
)
)
Expand Down

0 comments on commit 243b3d7

Please sign in to comment.