Skip to content

Commit

Permalink
review_1
Browse files Browse the repository at this point in the history
  • Loading branch information
raphaelauv committed May 5, 2024
1 parent 21bad35 commit ce3387a
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 8 deletions.
20 changes: 13 additions & 7 deletions airflow/operators/trigger_dagrun.py
Expand Up @@ -28,7 +28,13 @@

from airflow.api.common.trigger_dag import trigger_dag
from airflow.configuration import conf
from airflow.exceptions import AirflowException, AirflowSkipException, DagNotFound, DagRunAlreadyExists, RemovedInAirflow3Warning
from airflow.exceptions import (
AirflowException,
AirflowSkipException,
DagNotFound,
DagRunAlreadyExists,
RemovedInAirflow3Warning,
)
from airflow.models.baseoperator import BaseOperator
from airflow.models.baseoperatorlink import BaseOperatorLink
from airflow.models.dag import DagModel
Expand Down Expand Up @@ -90,7 +96,7 @@ class TriggerDagRunOperator(BaseOperator):
(default: 60)
:param allowed_states: List of allowed states, default is ``['success']``.
:param failed_states: List of failed or dis-allowed states, default is ``None``.
:param soft_fail: Set to true to mark the task as SKIPPED on DagRunAlreadyExists
:param skip_when_already_exists: Set to true to mark the task as SKIPPED if a dag_run already exists
:param deferrable: If waiting for completion, whether or not to defer the task until done,
default is ``False``.
:param execution_date: Deprecated parameter; same as ``logical_date``.
Expand All @@ -102,7 +108,7 @@ class TriggerDagRunOperator(BaseOperator):
"logical_date",
"conf",
"wait_for_completion",
"soft_fail",
"skip_when_already_exists",
)
template_fields_renderers = {"conf": "py"}
ui_color = "#ffefeb"
Expand All @@ -120,7 +126,7 @@ def __init__(
poke_interval: int = 60,
allowed_states: list[str] | None = None,
failed_states: list[str] | None = None,
soft_fail: bool = False,
skip_when_already_exists: bool = False,
deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
execution_date: str | datetime.datetime | None = None,
**kwargs,
Expand All @@ -140,7 +146,7 @@ def __init__(
self.failed_states = [DagRunState(s) for s in failed_states]
else:
self.failed_states = [DagRunState.FAILED]
self.soft_fail = soft_fail
self.skip_when_already_exists = skip_when_already_exists
self._defer = deferrable

if execution_date is not None:
Expand Down Expand Up @@ -200,9 +206,9 @@ def execute(self, context: Context):
dag_run = e.dag_run
dag.clear(start_date=dag_run.logical_date, end_date=dag_run.logical_date)
else:
if self.soft_fail:
if self.skip_when_already_exists:
raise AirflowSkipException(
"Skipping due to soft_fail is set to True and DagRunAlreadyExists"
"Skipping due to skip_when_already_exists is set to True and DagRunAlreadyExists"
)
raise e
if dag_run is None:
Expand Down
19 changes: 18 additions & 1 deletion tests/operators/test_trigger_dagrun.py
Expand Up @@ -35,7 +35,7 @@
from airflow.triggers.external_task import DagStateTrigger
from airflow.utils import timezone
from airflow.utils.session import create_session
from airflow.utils.state import State
from airflow.utils.state import State, TaskInstanceState
from airflow.utils.types import DagRunType

pytestmark = pytest.mark.db_test
Expand Down Expand Up @@ -322,6 +322,23 @@ def test_trigger_dagrun_with_reset_dag_run_false_fail(self, trigger_run_id, trig
with pytest.raises(DagRunAlreadyExists):
task.run(start_date=logical_date, end_date=logical_date, ignore_ti_state=True)

def test_trigger_dagrun_with_skip_when_already_exists(self):
"""Test TriggerDagRunOperator with skip_when_already_exists."""
execution_date = DEFAULT_DATE
task = TriggerDagRunOperator(
task_id="test_task",
trigger_dag_id=TRIGGERED_DAG_ID,
trigger_run_id="dummy_run_id",
execution_date=None,
reset_dag_run=False,
skip_when_already_exists=True,
dag=self.dag,
)
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
assert task.get_task_instances()[0].state == TaskInstanceState.SUCCESS
task.run(start_date=execution_date, end_date=execution_date, ignore_ti_state=True)
assert task.get_task_instances()[0].state == TaskInstanceState.SKIPPED

@pytest.mark.parametrize(
"trigger_run_id, trigger_logical_date, expected_dagruns_count",
[
Expand Down

0 comments on commit ce3387a

Please sign in to comment.