diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 421c7963d0d68..a346db10f0fc9 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -24,11 +24,15 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun +from airflow.models.xcom import XCom from airflow.utils import timezone from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.state import State from airflow.utils.types import DagRunType +XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso" +XCOM_RUN_ID = "trigger_run_id" + class TriggerDagRunLink(BaseOperatorLink): """ @@ -39,7 +43,13 @@ class TriggerDagRunLink(BaseOperatorLink): name = 'Triggered DAG' def get_link(self, operator, dttm): - query = {"dag_id": operator.trigger_dag_id, "execution_date": dttm.isoformat()} + # Fetch the correct execution date for the triggerED dag which is + # stored in xcom during execution of the triggerING task. + trigger_execution_date_iso = XCom.get_one( + execution_date=dttm, key=XCOM_EXECUTION_DATE_ISO, task_id=operator.task_id, dag_id=operator.dag_id + ) + + query = {"dag_id": operator.trigger_dag_id, "base_date": trigger_execution_date_iso} return build_airflow_url_with_query(query) @@ -140,6 +150,7 @@ def execute(self, context: Dict): execution_date=self.execution_date, replace_microseconds=False, ) + except DagRunAlreadyExists as e: if self.reset_dag_run: self.log.info("Clearing %s on %s", self.trigger_dag_id, self.execution_date) @@ -157,6 +168,12 @@ def execute(self, context: Dict): else: raise e + # Store the execution date from the dag run (either created or found above) to + # be used when creating the extra link on the webserver. + ti = context['task_instance'] + ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat()) + ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) + if self.wait_for_completion: # wait for dag to complete while True: diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 9ff87358d0db2..1934c4d4174b0 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -19,7 +19,7 @@ import pathlib import tempfile from datetime import datetime -from unittest import TestCase +from unittest import TestCase, mock import pytest @@ -76,6 +76,25 @@ def tearDown(self): pathlib.Path(self._tmpfile).unlink() + @mock.patch('airflow.operators.trigger_dagrun.build_airflow_url_with_query') + def assert_extra_link(self, triggering_exec_date, triggered_dag_run, triggering_task, mock_build_url): + """ + Asserts whether the correct extra links url will be created. + + Specifically it tests whether the correct dag id and date are passed to + the method which constructs the final url. + Note: We can't run that method to generate the url itself because the Flask app context + isn't available within the test logic, so it is mocked here. + """ + triggering_task.get_extra_links(triggering_exec_date, 'Triggered DAG') + assert mock_build_url.called + args, _ = mock_build_url.call_args + expected_args = { + 'dag_id': triggered_dag_run.dag_id, + 'base_date': triggered_dag_run.execution_date.isoformat(), + } + assert expected_args in args + def test_trigger_dagrun(self): """Test TriggerDagRunOperator.""" task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, dag=self.dag) @@ -84,7 +103,9 @@ def test_trigger_dagrun(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( @@ -114,8 +135,10 @@ def test_trigger_dagrun_with_execution_date(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == utc_now + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == utc_now + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_twice(self): """Test TriggerDagRunOperator with custom execution_date.""" @@ -140,12 +163,14 @@ def test_trigger_dagrun_twice(self): ) session.add(dag_run) session.commit() - task.execute(None) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == utc_now + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == utc_now + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_with_templated_execution_date(self): """Test TriggerDagRunOperator with templated execution_date.""" @@ -160,8 +185,10 @@ def test_trigger_dagrun_with_templated_execution_date(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == DEFAULT_DATE + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == DEFAULT_DATE + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_operator_conf(self): """Test passing conf to the triggered DagRun.""" @@ -288,7 +315,9 @@ def test_trigger_dagrun_triggering_itself(self): .all() ) assert len(dagruns) == 2 - assert dagruns[1].state == State.QUEUED + triggered_dag_run = dagruns[1] + assert triggered_dag_run.state == State.QUEUED + self.assert_extra_link(execution_date, triggered_dag_run, task) def test_trigger_dagrun_triggering_itself_with_execution_date(self): """Test TriggerDagRunOperator that triggers itself with execution date,