From df68d5c06f336aac875d54354dc3658bbfe4b5fb Mon Sep 17 00:00:00 2001 From: Kaxil Naik Date: Fri, 11 Jun 2021 21:29:56 +0100 Subject: [PATCH] Don't show stale Serialized DAGs if they are deleted in DB (#16368) If `DagBag.get_dag()` is called currently, it will return the DAG even if the DAG does not exist in `serialized_dag` table. This PR changes that logic to remove the dag from local cache too when `DagBag.get_dag()` is called. This happens after `min_serialized_dag_fetch_secs`. (cherry picked from commit e3b3c1fd1cf61b5d1bbe7aef11ddc85b9a7aa171) (cherry picked from commit 9d14b1d6d213b736a552a86cc346e7d38ab9e287) (cherry picked from commit c95f6d96f5bc04852ffb874df0f98a8a7a3834ed) --- airflow/models/dagbag.py | 11 ++++++++++- airflow/models/serialized_dag.py | 4 ++-- tests/models/test_dagbag.py | 31 ++++++++++++++++++++++++++++++- 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/airflow/models/dagbag.py b/airflow/models/dagbag.py index f1ae55cdbe97a..efe03d7fb8033 100644 --- a/airflow/models/dagbag.py +++ b/airflow/models/dagbag.py @@ -175,6 +175,8 @@ def get_dag(self, dag_id, session: Session = None): # 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs) # 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated # 3. if (2) is yes, fetch the Serialized DAG. + # 4. if (2) returns None (i.e. Serialized DAG is deleted), remove dag from dagbag + # if it exists and return None. min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL) if ( dag_id in self.dags_last_fetched @@ -184,7 +186,14 @@ def get_dag(self, dag_id, session: Session = None): dag_id=dag_id, session=session, ) - if sd_last_updated_datetime and sd_last_updated_datetime > self.dags_last_fetched[dag_id]: + if not sd_last_updated_datetime: + self.log.warning("Serialized DAG %s no longer exists", dag_id) + del self.dags[dag_id] + del self.dags_last_fetched[dag_id] + del self.dags_hash[dag_id] + return None + + if sd_last_updated_datetime > self.dags_last_fetched[dag_id]: self._add_dag_from_db(dag_id=dag_id, session=session) return self.dags.get(dag_id) diff --git a/airflow/models/serialized_dag.py b/airflow/models/serialized_dag.py index e12b29ffbbc60..eea42eba8108b 100644 --- a/airflow/models/serialized_dag.py +++ b/airflow/models/serialized_dag.py @@ -253,7 +253,7 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None): @classmethod @provide_session - def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> datetime: + def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Optional[datetime]: """ Get the date when the Serialized DAG associated to DAG was last updated in serialized_dag table @@ -276,6 +276,6 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str: :param session: ORM Session :type session: Session :return: DAG Hash - :rtype: str + :rtype: str | None """ return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar() diff --git a/tests/models/test_dagbag.py b/tests/models/test_dagbag.py index 73b0bf0e41334..c7a73fe0afe38 100644 --- a/tests/models/test_dagbag.py +++ b/tests/models/test_dagbag.py @@ -19,7 +19,7 @@ import shutil import textwrap import unittest -from datetime import datetime, timezone +from datetime import datetime, timedelta, timezone from tempfile import NamedTemporaryFile, mkdtemp from unittest import mock from unittest.mock import patch @@ -33,6 +33,7 @@ from airflow.exceptions import SerializationError from airflow.models import DagBag, DagModel from airflow.models.serialized_dag import SerializedDagModel +from airflow.serialization.serialized_objects import SerializedDAG from airflow.utils.dates import timezone as tz from airflow.utils.session import create_session from tests import cluster_policies @@ -278,6 +279,34 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True): self.assertEqual(dag_id, dag.dag_id) self.assertEqual(2, dagbag.process_file_calls) + def test_dag_removed_if_serialized_dag_is_removed(self): + """ + Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed), + remove dags from the DagBag + """ + from airflow.operators.dummy import DummyOperator + + dag = models.DAG( + dag_id="test_dag_removed_if_serialized_dag_is_removed", + schedule_interval=None, + start_date=tz.datetime(2021, 10, 12), + ) + + with dag: + DummyOperator(task_id="task_1") + + dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True) + dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))} + dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))} + dagbag.dags_hash = {dag.dag_id: mock.ANY} + + assert SerializedDagModel.has_dag(dag.dag_id) is False + + assert dagbag.get_dag(dag.dag_id) is None + assert dag.dag_id not in dagbag.dags + assert dag.dag_id not in dagbag.dags_last_fetched + assert dag.dag_id not in dagbag.dags_hash + def process_dag(self, create_dag): """ Helper method to process a file generated from the input create_dag function.