Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: skip_when_already_exists TriggerDagRunOperator #39173

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
16 changes: 15 additions & 1 deletion airflow/operators/trigger_dagrun.py
Expand Up @@ -28,7 +28,13 @@

from airflow.api.common.trigger_dag import trigger_dag
from airflow.configuration import conf
from airflow.exceptions import AirflowException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning
from airflow.exceptions import (
AirflowException,
AirflowSkipException,
DagNotFound,
DagRunAlreadyExists,
RemovedInAirflow3Warning,
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
Expand Down Expand Up @@ -90,6 +96,7 @@ class TriggerDagRunOperator(BaseOperator):
(default: 60)
:param allowed_states: List of allowed states, default is ``['success']``.
:param failed_states: List of failed or dis-allowed states, default is ``None``.
:param skip_when_already_exists: Set to true to mark the task as SKIPPED if a dag_run already exists
:param deferrable: If waiting for completion, whether or not to defer the task until done,
default is ``False``.
:param execution_date: Deprecated parameter; same as ``logical_date``.
Expand All @@ -101,6 +108,7 @@ class TriggerDagRunOperator(BaseOperator):
"logical_date",
"conf",
"wait_for_completion",
"skip_when_already_exists",
)
template_fields_renderers = {"conf": "py"}
ui_color = "#ffefeb"
Expand All @@ -118,6 +126,7 @@ def __init__(
poke_interval: int = 60,
allowed_states: list[str] | None = None,
failed_states: list[str] | None = None,
skip_when_already_exists: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
execution_date: str | datetime.datetime | None = None,
**kwargs,
Expand All @@ -137,6 +146,7 @@ def __init__(
self.failed_states = [DagRunState(s) for s in failed_states]
else:
self.failed_states = [DagRunState.FAILED]
self.skip_when_already_exists = skip_when_already_exists
self._defer = deferrable

if execution_date is not None:
Expand Down Expand Up @@ -196,6 +206,10 @@ def execute(self, context: Context):
dag_run = e.dag_run
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
else:
if self.skip_when_already_exists:
raise AirflowSkipException(
"Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists"
)
raise e
if dag_run is None:
raise RuntimeError("The dag_run should be set here!")
Expand Down
19 changes: 18 additions & 1 deletion tests/operators/test_trigger_dagrun.py
Expand Up @@ -35,7 +35,7 @@
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.types import DagRunType

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -322,6 +322,23 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail(self, trigger_run_id, trig
with pytest.raises(DagRunAlreadyExists):
task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True)

def test_trigger_dagrun_with_skip_when_already_exists(self):
"""Test TriggerDagRunOperator with skip_when_already_exists."""
execution_date = DEFAULT_DATE
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="dummy_run_id",
execution_date=None,
reset_dag_run=False,
skip_when_already_exists=True,
dag=self.dag,
)
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
assert task.get_task_instances()[0].state == TaskInstanceState.SUCCESS
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
assert task.get_task_instances()[0].state == TaskInstanceState.SKIPPED

@pytest.mark.parametrize(
"trigger_run_id, trigger_logical_date, expected_dagruns_count",
[
Expand Down