diff --git a/airflow/operators/trigger_dagrun.py b/airflow/operators/trigger_dagrun.py index 3306dde8a6bae..526b04c485bf7 100644 --- a/airflow/operators/trigger_dagrun.py +++ b/airflow/operators/trigger_dagrun.py @@ -107,12 +107,12 @@ def __init__( self.allowed_states = allowed_states or [State.SUCCESS] self.failed_states = failed_states or [State.FAILED] - if not isinstance(execution_date, (str, datetime.datetime, type(None))): + if execution_date is not None and not isinstance(execution_date, (str, datetime.datetime)): raise TypeError( f"Expected str or datetime.datetime type for execution_date.Got {type(execution_date)}" ) - self.execution_date: Optional[datetime.datetime] = execution_date # type: ignore + self.execution_date = execution_date try: json.dumps(self.conf) @@ -121,30 +121,28 @@ def __init__( def execute(self, context: Context): if isinstance(self.execution_date, datetime.datetime): - execution_date = self.execution_date + parsed_execution_date = self.execution_date elif isinstance(self.execution_date, str): - execution_date = timezone.parse(self.execution_date) - self.execution_date = execution_date + parsed_execution_date = timezone.parse(self.execution_date) else: - execution_date = timezone.utcnow() + parsed_execution_date = timezone.utcnow() if self.trigger_run_id: run_id = self.trigger_run_id else: - run_id = DagRun.generate_run_id(DagRunType.MANUAL, execution_date) - + run_id = DagRun.generate_run_id(DagRunType.MANUAL, parsed_execution_date) try: dag_run = trigger_dag( dag_id=self.trigger_dag_id, run_id=run_id, conf=self.conf, - execution_date=self.execution_date, + execution_date=parsed_execution_date, replace_microseconds=False, ) except DagRunAlreadyExists as e: if self.reset_dag_run: - self.log.info("Clearing %s on %s", self.trigger_dag_id, self.execution_date) + self.log.info("Clearing %s on %s", self.trigger_dag_id, parsed_execution_date) # Get target dag object and call clear() @@ -154,7 +152,7 @@ def execute(self, context: Context): dag_bag = DagBag(dag_folder=dag_model.fileloc, read_dags_from_db=True) dag = dag_bag.get_dag(self.trigger_dag_id) - dag.clear(start_date=self.execution_date, end_date=self.execution_date) + dag.clear(start_date=parsed_execution_date, end_date=parsed_execution_date) dag_run = DagRun.find(dag_id=dag.dag_id, run_id=run_id)[0] else: raise e diff --git a/tests/operators/test_trigger_dagrun.py b/tests/operators/test_trigger_dagrun.py index 1934c4d4174b0..180781eed6109 100644 --- a/tests/operators/test_trigger_dagrun.py +++ b/tests/operators/test_trigger_dagrun.py @@ -30,6 +30,7 @@ from airflow.utils import timezone from airflow.utils.session import create_session from airflow.utils.state import State +from airflow.utils.types import DagRunType DEFAULT_DATE = datetime(2019, 1, 1, tzinfo=timezone.utc) TEST_DAG_ID = "testdag" @@ -101,11 +102,10 @@ def test_trigger_dagrun(self): task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: - dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() - assert len(dagruns) == 1 - triggered_dag_run = dagruns[0] - assert triggered_dag_run.external_trigger - self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.external_trigger + assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, dagrun.execution_date) + self.assert_extra_link(DEFAULT_DATE, dagrun, task) def test_trigger_dagrun_custom_run_id(self): task = TriggerDagRunOperator( @@ -123,22 +123,21 @@ def test_trigger_dagrun_custom_run_id(self): def test_trigger_dagrun_with_execution_date(self): """Test TriggerDagRunOperator with custom execution_date.""" - utc_now = timezone.utcnow() + custom_execution_date = timezone.datetime(2021, 1, 2, 3, 4, 5) task = TriggerDagRunOperator( task_id="test_trigger_dagrun_with_execution_date", trigger_dag_id=TRIGGERED_DAG_ID, - execution_date=utc_now, + execution_date=custom_execution_date, dag=self.dag, ) task.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) with create_session() as session: - dagruns = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).all() - assert len(dagruns) == 1 - triggered_dag_run = dagruns[0] - assert triggered_dag_run.external_trigger - assert triggered_dag_run.execution_date == utc_now - self.assert_extra_link(DEFAULT_DATE, triggered_dag_run, task) + dagrun = session.query(DagRun).filter(DagRun.dag_id == TRIGGERED_DAG_ID).one() + assert dagrun.external_trigger + assert dagrun.execution_date == custom_execution_date + assert dagrun.run_id == DagRun.generate_run_id(DagRunType.MANUAL, custom_execution_date) + self.assert_extra_link(DEFAULT_DATE, dagrun, task) def test_trigger_dagrun_twice(self): """Test TriggerDagRunOperator with custom execution_date."""