From b77fb10312b9af4bd9c319ab08ed8a3e015b1670 Mon Sep 17 00:00:00 2001 From: Niko Date: Thu, 9 Dec 2021 05:46:59 -0800 Subject: [PATCH] Fix TriggerDagRunOperator extra link (#19410) The extra link provided by the operator was previously using the execution date of the triggering dag, not the triggered dag. Store the execution date of the triggered dag in xcom so that it can be read back later within the webserver when the link is being created. (cherry picked from commit 820e836c4a2e45239279d4d71e1db9434022fec5) --- airflow/operators/trigger_dagrun.py | 19 +++++++++- tests/operators/test_trigger_dagrun.py | 49 ++++++++++++++++++++------ 2 files changed, 57 insertions(+), 11 deletions(-) diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 421c7963d0d68..a346db10f0fc9 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -24,11 +24,15 @@ from airflow.api.common.trigger_dag import trigger_dag from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists from airflow.models import BaseOperator, BaseOperatorLink, DagBag, DagModel, DagRun +from airflow.models.xcom import XCom from airflow.utils import timezone from airflow.utils.helpers import build_airflow_url_with_query from airflow.utils.state import State from airflow.utils.types import DagRunType +XCOM_EXECUTION_DATE_ISO = "trigger_execution_date_iso" +XCOM_RUN_ID = "trigger_run_id" + class TriggerDagRunLink(BaseOperatorLink): """ @@ -39,7 +43,13 @@ class TriggerDagRunLink(BaseOperatorLink): name = 'Triggered DAG' def get_link(self, operator, dttm): - query = {"dag_id": operator.trigger_dag_id, "execution_date": dttm.isoformat()} + # Fetch the correct execution date for the triggerED dag which is + # stored in xcom during execution of the triggerING task. + trigger_execution_date_iso = XCom.get_one( + execution_date=dttm, key=XCOM_EXECUTION_DATE_ISO, task_id=operator.task_id, dag_id=operator.dag_id + ) + + query = {"dag_id": operator.trigger_dag_id, "base_date": trigger_execution_date_iso} return build_airflow_url_with_query(query) @@ -140,6 +150,7 @@ def execute(self, context: Dict): execution_date=self.execution_date, replace_microseconds=False, ) + except DagRunAlreadyExists as e: if self.reset_dag_run: self.log.info("Clearing %s on %s", self.trigger_dag_id, self.execution_date) @@ -157,6 +168,12 @@ def execute(self, context: Dict): else: raise e + # Store the execution date from the dag run (either created or found above) to + # be used when creating the extra link on the webserver. + ti = context['task_instance'] + ti.xcom_push(key=XCOM_EXECUTION_DATE_ISO, value=dag_run.execution_date.isoformat()) + ti.xcom_push(key=XCOM_RUN_ID, value=dag_run.run_id) + if self.wait_for_completion: # wait for dag to complete while True: diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 9ff87358d0db2..1934c4d4174b0 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -19,7 +19,7 @@ import pathlib import tempfile from datetime import datetime -from unittest import TestCase +from unittest import TestCase, mock import pytest @@ -76,6 +76,25 @@ def tearDown(self): pathlib.Path(self._tmpfile).unlink() + @mock.patch('airflow.operators.trigger_dagrun.build_airflow_url_with_query') + def assert_extra_link(self, triggering_exec_date, triggered_dag_run, triggering_task, mock_build_url): + """ + Asserts whether the correct extra links url will be created. + + Specifically it tests whether the correct dag id and date are passed to + the method which constructs the final url. + Note: We can't run that method to generate the url itself because the Flask app context + isn't available within the test logic, so it is mocked here. + """ + triggering_task.get_extra_links(triggering_exec_date, 'Triggered DAG') + assert mock_build_url.called + args, _ = mock_build_url.call_args + expected_args = { + 'dag_id': triggered_dag_run.dag_id, + 'base_date': triggered_dag_run.execution_date.isoformat(), + } + assert expected_args in args + def test_trigger_dagrun(self): """Test TriggerDagRunOperator.""" task = TriggerDagRunOperator(task_id="test_task", trigger_dag_id=TRIGGERED_DAG_ID, dag=self.dag) @@ -84,7 +103,9 @@ def test_trigger_dagrun(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( @@ -114,8 +135,10 @@ def test_trigger_dagrun_with_execution_date(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == utc_now + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == utc_now + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_twice(self): """Test TriggerDagRunOperator with custom execution_date.""" @@ -140,12 +163,14 @@ def test_trigger_dagrun_twice(self): ) session.add(dag_run) session.commit() - task.execute(None) + task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == utc_now + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == utc_now + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_with_templated_execution_date(self): """Test TriggerDagRunOperator with templated execution_date.""" @@ -160,8 +185,10 @@ def test_trigger_dagrun_with_templated_execution_date(self): with create_session() as session: dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() assert len(dagruns) == 1 - assert dagruns[0].external_trigger - assert dagruns[0].execution_date == DEFAULT_DATE + triggered_dag_run = dagruns[0] + assert triggered_dag_run.external_trigger + assert triggered_dag_run.execution_date == DEFAULT_DATE + self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) def test_trigger_dagrun_operator_conf(self): """Test passing conf to the triggered DagRun.""" @@ -288,7 +315,9 @@ def test_trigger_dagrun_triggering_itself(self): .all() ) assert len(dagruns) == 2 - assert dagruns[1].state == State.QUEUED + triggered_dag_run = dagruns[1] + assert triggered_dag_run.state == State.QUEUED + self.assert_extra_link(execution_date, triggered_dag_run, task) def test_trigger_dagrun_triggering_itself_with_execution_date(self): """Test TriggerDagRunOperator that triggers itself with execution date,