diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py
index b78463b6cf5e6..be2701b118c34 100644
--- a/airflow/models/dagbag.py
+++ b/airflow/models/dagbag.py
@@ -190,6 +190,8 @@ def get_dag(self, dag_id, session: Session = None):
             # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
             # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
             # 3. if (2) is yes, fetch the Serialized DAG.
+            # 4. if (2) returns None (i.e. Serialized DAG is deleted), remove dag from dagbag
+            # if it exists and return None.
             min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
             if (
                 dag_id in self.dags_last_fetched
@@ -199,7 +201,15 @@ def get_dag(self, dag_id, session: Session = None):
                     dag_id=dag_id,
                     session=session,
                 )
-                if sd_last_updated_datetime and sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
+                if not sd_last_updated_datetime:
+                    self.log.warning("Serialized DAG %s no longer exists", dag_id)
+                    # Use pop() rather than del so a partially populated cache cannot raise KeyError.
+                    self.dags.pop(dag_id, None)
+                    self.dags_last_fetched.pop(dag_id, None)
+                    self.dags_hash.pop(dag_id, None)
+                    return None
+
+                if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
                     self._add_dag_from_db(dag_id=dag_id, session=session)
 
             return self.dags.get(dag_id)
diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py
index 81448b6e64b8d..4e8ebc4f7e3ae 100644
--- a/airflow/models/serialized_dag.py
+++ b/airflow/models/serialized_dag.py
@@ -261,7 +261,7 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None):
 
     @classmethod
     @provide_session
-    def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> datetime:
+    def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Optional[datetime]:
         """
         Get the date when the Serialized DAG associated to DAG was last updated
         in serialized_dag table
@@ -295,7 +295,7 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str:
         :param session: ORM Session
         :type session: Session
         :return: DAG Hash
-        :rtype: str
+        :rtype: str | None
         """
         return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar()
 
diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py
index 359cd5c11b18f..0c52c49c99247 100644
--- a/tests/models/test_dagbag.py
+++ b/tests/models/test_dagbag.py
@@ -19,7 +19,7 @@
 import shutil
 import textwrap
 import unittest
-from datetime import datetime, timezone
+from datetime import datetime, timedelta, timezone
 from tempfile import NamedTemporaryFile, mkdtemp
 from unittest import mock
 from unittest.mock import patch
@@ -33,6 +33,7 @@
 from airflow.exceptions import SerializationError
 from airflow.models import DagBag, DagModel
 from airflow.models.serialized_dag import SerializedDagModel
+from airflow.serialization.serialized_objects import SerializedDAG
 from airflow.utils.dates import timezone as tz
 from airflow.utils.session import create_session
 from airflow.www.security import ApplessAirflowSecurityManager
@@ -311,6 +312,34 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
         assert dag_id == dag.dag_id
         assert 2 == dagbag.process_file_calls
 
+    def test_dag_removed_if_serialized_dag_is_removed(self):
+        """
+        Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed),
+        remove dags from the DagBag
+        """
+        from airflow.operators.dummy import DummyOperator
+
+        dag = models.DAG(
+            dag_id="test_dag_removed_if_serialized_dag_is_removed",
+            schedule_interval=None,
+            start_date=tz.datetime(2021, 10, 12),
+        )
+
+        with dag:
+            DummyOperator(task_id="task_1")
+
+        dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True)
+        dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
+        dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
+        dagbag.dags_hash = {dag.dag_id: mock.ANY}
+
+        assert SerializedDagModel.has_dag(dag.dag_id) is False
+
+        assert dagbag.get_dag(dag.dag_id) is None
+        assert dag.dag_id not in dagbag.dags
+        assert dag.dag_id not in dagbag.dags_last_fetched
+        assert dag.dag_id not in dagbag.dags_hash
+
     def process_dag(self, create_dag):
         """
         Helper method to process a file generated from the input create_dag function.