diff --git a/airflow/cli/cli_parser.py b/airflow/cli/cli_parser.py index 231bbed490ec5..21031bd1c21a4 100644 --- a/airflow/cli/cli_parser.py +++ b/airflow/cli/cli_parser.py @@ -183,6 +183,11 @@ def string_lower_type(val): ARG_EXECUTION_DATE_OR_RUN_ID = Arg( ('execution_date_or_run_id',), help="The execution_date of the DAG or run_id of the DAGRun" ) +ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL = Arg( + ('execution_date_or_run_id',), + nargs='?', + help="The execution_date of the DAG or run_id of the DAGRun (optional)", +) ARG_TASK_REGEX = Arg( ("-t", "--task-regex"), help="The regex to filter specific task_ids to backfill (optional)" ) @@ -1296,7 +1301,7 @@ class GroupCommand(NamedTuple): args=( ARG_DAG_ID, ARG_TASK_ID, - ARG_EXECUTION_DATE_OR_RUN_ID, + ARG_EXECUTION_DATE_OR_RUN_ID_OPTIONAL, ARG_SUBDIR, ARG_DRY_RUN, ARG_TASK_PARAMS, diff --git a/airflow/cli/commands/task_command.py b/airflow/cli/commands/task_command.py index 694594d68dcaf..f8916f0466416 100644 --- a/airflow/cli/commands/task_command.py +++ b/airflow/cli/commands/task_command.py @@ -22,7 +22,7 @@ import logging import os import textwrap -from contextlib import contextmanager, redirect_stderr, redirect_stdout +from contextlib import contextmanager, redirect_stderr, redirect_stdout, suppress from typing import Dict, Generator, List, Optional, Tuple, Union from pendulum.parsing.exceptions import ParserError @@ -75,8 +75,8 @@ def _generate_temporary_run_id() -> str: def _get_dag_run( *, dag: DAG, - exec_date_or_run_id: str, create_if_necessary: CreateIfNecessary, + exec_date_or_run_id: Optional[str] = None, session: Session, ) -> Tuple[DagRun, bool]: """Try to retrieve a DAG run from a string representing either a run ID or logical date. @@ -92,33 +92,35 @@ def _get_dag_run( the logical date; otherwise use it as a run ID and set the logical date to the current time. """ - dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session) - if dag_run: - return dag_run, False - - try: - execution_date: Optional[datetime.datetime] = timezone.parse(exec_date_or_run_id) - except (ParserError, TypeError): - execution_date = None - - try: - dag_run = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) - .one() - ) - except NoResultFound: - if not create_if_necessary: - raise DagRunNotFound( - f"DagRun for {dag.dag_id} with run_id or execution_date of {exec_date_or_run_id!r} not found" - ) from None - else: - return dag_run, False + if not exec_date_or_run_id and not create_if_necessary: + raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") + execution_date: Optional[datetime.datetime] = None + if exec_date_or_run_id: + dag_run = dag.get_dagrun(run_id=exec_date_or_run_id, session=session) + if dag_run: + return dag_run, False + with suppress(ParserError, TypeError): + execution_date = timezone.parse(exec_date_or_run_id) + try: + dag_run = ( + session.query(DagRun) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date == execution_date) + .one() + ) + except NoResultFound: + if not create_if_necessary: + raise DagRunNotFound( + f"DagRun for {dag.dag_id} with run_id or execution_date " + f"of {exec_date_or_run_id!r} not found" + ) from None + else: + return dag_run, False if execution_date is not None: dag_run_execution_date = execution_date else: dag_run_execution_date = timezone.utcnow() + if create_if_necessary == "memory": dag_run = DagRun(dag.dag_id, run_id=exec_date_or_run_id, execution_date=dag_run_execution_date) return dag_run, True @@ -136,14 +138,16 @@ def _get_dag_run( @provide_session def _get_ti( task: BaseOperator, - exec_date_or_run_id: str, map_index: int, *, + exec_date_or_run_id: Optional[str] = None, pool: Optional[str] = None, create_if_necessary: CreateIfNecessary = False, session: Session = NEW_SESSION, ) -> Tuple[TaskInstance, bool]: """Get the task instance through DagRun.run_id, if that fails, get the TI the old way""" + if not exec_date_or_run_id and not create_if_necessary: + raise ValueError("Must provide `exec_date_or_run_id` if not `create_if_necessary`.") if task.is_mapped: if map_index < 0: raise RuntimeError("No map_index passed to mapped task") @@ -370,7 +374,7 @@ def task_run(args, dag=None): # Use DAG from parameter pass task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, pool=args.pool) + ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, pool=args.pool) ti.init_run_context(raw=args.raw) hostname = get_hostname() @@ -398,7 +402,7 @@ def task_failed_deps(args): """ dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index) + ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id) dep_context = DepContext(deps=SCHEDULER_QUEUED_DEPS) failed_deps = list(ti.get_failed_dep_statuses(dep_context=dep_context)) @@ -421,7 +425,7 @@ def task_state(args): """ dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index) + ti, _ = _get_ti(task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id) print(ti.current_state()) @@ -544,7 +548,9 @@ def task_test(args, dag=None): if task.params: task.params.validate() - ti, dr_created = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="db") + ti, dr_created = _get_ti( + task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="db" + ) try: with redirect_stdout(RedactedIO()): @@ -574,7 +580,9 @@ def task_render(args): """Renders and displays templated fields for a given task""" dag = get_dag(args.subdir, args.dag_id) task = dag.get_task(task_id=args.task_id) - ti, _ = _get_ti(task, args.execution_date_or_run_id, args.map_index, create_if_necessary="memory") + ti, _ = _get_ti( + task, args.map_index, exec_date_or_run_id=args.execution_date_or_run_id, create_if_necessary="memory" + ) ti.render_templates() for attr in task.__class__.template_fields: print( diff --git a/docs/apache-airflow/tutorial/fundamentals.rst b/docs/apache-airflow/tutorial/fundamentals.rst index 351c215912a94..d2071a0682273 100644 --- a/docs/apache-airflow/tutorial/fundamentals.rst +++ b/docs/apache-airflow/tutorial/fundamentals.rst @@ -326,7 +326,7 @@ its data interval. .. code-block:: bash - # command layout: command subcommand dag_id task_id date + # command layout: command subcommand [dag_id] [task_id] [(optional) date] # testing print_date airflow tasks test tutorial print_date 2015-06-01 @@ -350,7 +350,7 @@ their log to stdout (on screen), does not bother with dependencies, and does not communicate state (running, success, failed, ...) to the database. It simply allows testing a single task instance. -The same applies to ``airflow dags test [dag_id] [logical_date]``, but on a DAG +The same applies to ``airflow dags test``, but on a DAG level. It performs a single DAG run of the given DAG id. While it does take task dependencies into account, no state is registered in the database. It is convenient for locally testing a full run of your DAG, given that e.g. if one of diff --git a/tests/cli/commands/test_task_command.py b/tests/cli/commands/test_task_command.py index 8476d7f3e9abf..f3defcaeb892b 100644 --- a/tests/cli/commands/test_task_command.py +++ b/tests/cli/commands/test_task_command.py @@ -15,6 +15,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import io import json import logging @@ -26,6 +27,7 @@ from pathlib import Path from unittest import mock +import pendulum import pytest from parameterized import parameterized @@ -103,6 +105,21 @@ def test_test(self): # Check that prints, and log messages, are shown assert "'example_python_operator__print_the_context__20180101'" in stdout.getvalue() + @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") + @mock.patch('airflow.utils.timezone.utcnow') + def test_test_no_execution_date(self, mock_utcnow): + """Test the `airflow test` command""" + now = pendulum.now('UTC') + mock_utcnow.return_value = now + ds = now.strftime("%Y%m%d") + args = self.parser.parse_args(["tasks", "test", "example_python_operator", 'print_the_context']) + + with redirect_stdout(io.StringIO()) as stdout: + task_command.task_test(args) + + # Check that prints, and log messages, are shown + assert f"'example_python_operator__print_the_context__{ds}'" in stdout.getvalue() + @pytest.mark.filterwarnings("ignore::airflow.utils.context.AirflowContextDeprecationWarning") def test_test_with_existing_dag_run(self, caplog): """Test the `airflow test` command""" @@ -255,9 +272,9 @@ def test_cli_test_with_params(self): 'test', 'example_passing_params_via_test_command', 'run_this', + DEFAULT_DATE.isoformat(), '--task-params', '{"foo":"bar"}', - DEFAULT_DATE.isoformat(), ] ) ) @@ -268,9 +285,9 @@ def test_cli_test_with_params(self): 'test', 'example_passing_params_via_test_command', 'also_run_this', + DEFAULT_DATE.isoformat(), '--task-params', '{"foo":"bar"}', - DEFAULT_DATE.isoformat(), ] ) ) @@ -284,9 +301,9 @@ def test_cli_test_with_env_vars(self): 'test', 'example_passing_params_via_test_command', 'env_var_test_task', + DEFAULT_DATE.isoformat(), '--env-vars', '{"foo":"bar"}', - DEFAULT_DATE.isoformat(), ] ) )