Skip to content

Commit

Permalink
Don't show stale Serialized DAGs if they are deleted in DB (apache#16368
Browse files Browse the repository at this point in the history
)

If `DagBag.get_dag()` is called currently, it will return the DAG
even if the DAG does not exist in `serialized_dag` table.

This PR changes that logic to remove the dag from local cache too
when `DagBag.get_dag()` is called. This happens after
`min_serialized_dag_fetch_secs`.

(cherry picked from commit e3b3c1f)
(cherry picked from commit 9d14b1d)
(cherry picked from commit c95f6d9)
  • Loading branch information
kaxil committed Jun 23, 2021
1 parent 9c76e8f commit df68d5c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 4 deletions.
11 changes: 10 additions & 1 deletion airflow/models/dagbag.py
Expand Up @@ -175,6 +175,8 @@ def get_dag(self, dag_id, session: Session = None):
# 1. if time has come to check if DAG is updated (controlled by min_serialized_dag_fetch_secs)
# 2. check the last_updated column in SerializedDag table to see if Serialized DAG is updated
# 3. if (2) is yes, fetch the Serialized DAG.
# 4. if (2) returns None (i.e. Serialized DAG is deleted), remove dag from dagbag
# if it exists and return None.
min_serialized_dag_fetch_secs = timedelta(seconds=settings.MIN_SERIALIZED_DAG_FETCH_INTERVAL)
if (
dag_id in self.dags_last_fetched
Expand All @@ -184,7 +186,14 @@ def get_dag(self, dag_id, session: Session = None):
dag_id=dag_id,
session=session,
)
if sd_last_updated_datetime and sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
if not sd_last_updated_datetime:
self.log.warning("Serialized DAG %s no longer exists", dag_id)
del self.dags[dag_id]
del self.dags_last_fetched[dag_id]
del self.dags_hash[dag_id]
return None

if sd_last_updated_datetime > self.dags_last_fetched[dag_id]:
self._add_dag_from_db(dag_id=dag_id, session=session)

return self.dags.get(dag_id)
Expand Down
4 changes: 2 additions & 2 deletions airflow/models/serialized_dag.py
Expand Up @@ -253,7 +253,7 @@ def bulk_sync_to_db(dags: List[DAG], session: Session = None):

@classmethod
@provide_session
def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> datetime:
def get_last_updated_datetime(cls, dag_id: str, session: Session = None) -> Optional[datetime]:
"""
Get the date when the Serialized DAG associated to DAG was last updated
in serialized_dag table
Expand All @@ -276,6 +276,6 @@ def get_latest_version_hash(cls, dag_id: str, session: Session = None) -> str:
:param session: ORM Session
:type session: Session
:return: DAG Hash
:rtype: str
:rtype: str | None
"""
return session.query(cls.dag_hash).filter(cls.dag_id == dag_id).scalar()
31 changes: 30 additions & 1 deletion tests/models/test_dagbag.py
Expand Up @@ -19,7 +19,7 @@
import shutil
import textwrap
import unittest
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from tempfile import NamedTemporaryFile, mkdtemp
from unittest import mock
from unittest.mock import patch
Expand All @@ -33,6 +33,7 @@
from airflow.exceptions import SerializationError
from airflow.models import DagBag, DagModel
from airflow.models.serialized_dag import SerializedDagModel
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.dates import timezone as tz
from airflow.utils.session import create_session
from tests import cluster_policies
Expand Down Expand Up @@ -278,6 +279,34 @@ def process_file(self, filepath, only_if_updated=True, safe_mode=True):
self.assertEqual(dag_id, dag.dag_id)
self.assertEqual(2, dagbag.process_file_calls)

def test_dag_removed_if_serialized_dag_is_removed(self):
"""
Test that if a DAG does not exist in serialized_dag table (as the DAG file was removed),
remove dags from the DagBag
"""
from airflow.operators.dummy import DummyOperator

dag = models.DAG(
dag_id="test_dag_removed_if_serialized_dag_is_removed",
schedule_interval=None,
start_date=tz.datetime(2021, 10, 12),
)

with dag:
DummyOperator(task_id="task_1")

dagbag = DagBag(dag_folder=self.empty_dir, include_examples=False, read_dags_from_db=True)
dagbag.dags = {dag.dag_id: SerializedDAG.from_dict(SerializedDAG.to_dict(dag))}
dagbag.dags_last_fetched = {dag.dag_id: (tz.utcnow() - timedelta(minutes=2))}
dagbag.dags_hash = {dag.dag_id: mock.ANY}

assert SerializedDagModel.has_dag(dag.dag_id) is False

assert dagbag.get_dag(dag.dag_id) is None
assert dag.dag_id not in dagbag.dags
assert dag.dag_id not in dagbag.dags_last_fetched
assert dag.dag_id not in dagbag.dags_hash

def process_dag(self, create_dag):
"""
Helper method to process a file generated from the input create_dag function.
Expand Down

0 comments on commit df68d5c

Please sign in to comment.