From 7a7d2ac11e36f2b1e7cdfc38a7c0b0d9c42348ab Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Tue, 26 Jul 2022 16:53:56 +0100 Subject: [PATCH 1/3] Don't mistakenly take a lock on DagRun via ti.refresh_from_db In 2.2.0 we made TI.dag_run be automatically join-loaded, which is fine for most cases, but for `refresh_from_db` we don't need that (we don't access anything under ti.dag_run) and it's possible that when `lock_for_update=True` is passed we are locking more than we want to and _might_ cause deadlocks. Even if it doesn't, selecting more than we need is wasteful. --- airflow/models/taskinstance.py | 24 ++++++++++++++---------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 725070500c459..2973aa4a8c9aa 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -848,28 +848,32 @@ def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool """ self.log.debug("Refreshing TaskInstance %s from DB", self) - qry = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == self.run_id, - TaskInstance.map_index == self.map_index, + qry = ( + # To avoid joining any relationships by default select the all + # columns, not the object. 
This also means we get (effectively) a + # namedtuple back, not a TI object + session.query(*TaskInstance.__table__.columns).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == self.run_id, + TaskInstance.map_index == self.map_index, + ) ) if lock_for_update: for attempt in run_with_db_retries(logger=self.log): with attempt: - ti: Optional[TaskInstance] = qry.with_for_update().first() + ti: Optional[TaskInstance] = qry.with_for_update().one_or_none() else: - ti = qry.first() + ti = qry.one_or_none() if ti: # Fields ordered per model definition self.start_date = ti.start_date self.end_date = ti.end_date self.duration = ti.duration self.state = ti.state - # Get the raw value of try_number column, don't read through the - # accessor here otherwise it will be incremented by one already. - self.try_number = ti._try_number + # Since we selected columns, not the object, this is the raw value + self.try_number = ti.try_number self.max_tries = ti.max_tries self.hostname = ti.hostname self.unixname = ti.unixname From 098466038359606ef5e10a0fe318aca4a303ce5c Mon Sep 17 00:00:00 2001 From: Ash Berlin-Taylor Date: Thu, 28 Jul 2022 11:46:34 +0100 Subject: [PATCH 2/3] fixup! 
Don't mistakenly take a lock on DagRun via ti.refresh_from_db --- airflow/models/taskinstance.py | 4 ++++ tests/jobs/test_scheduler_job.py | 4 +++- 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2973aa4a8c9aa..b64d1a5f47887 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -288,6 +288,7 @@ def clear_task_instances( if dag_run_state == DagRunState.QUEUED: dr.last_scheduling_decision = None dr.start_date = None + session.flush() class _LazyXComAccessIterator(collections.abc.Iterator): @@ -848,6 +849,9 @@ def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool """ self.log.debug("Refreshing TaskInstance %s from DB", self) + if self in session: + session.refresh(self, TaskInstance.__mapper__.column_attrs.keys()) + qry = ( # To avoid joining any relationships by default select the all # columns, not the object. This also means we get (effectively) a diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index df8c0f6c7160f..8d6a779c1d30f 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -458,6 +458,7 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker) ti1.state = State.SCHEDULED self.scheduler_job._critical_section_enqueue_task_instances(session) + session.flush() ti1.refresh_from_db(session=session) assert State.SCHEDULED == ti1.state session.rollback() @@ -1316,7 +1317,8 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state( with patch.object(BaseExecutor, 'queue_command') as mock_queue_command: self.scheduler_job._enqueue_task_instances_with_queued_state([ti], session=session) - ti.refresh_from_db() + session.flush() + ti.refresh_from_db(session=session) assert ti.state == State.NONE mock_queue_command.assert_not_called() From 5cc7608274e48e93c495326f9d2b0ee588b59ab8 Mon Sep 17 00:00:00 2001 From: Ash 
Berlin-Taylor Date: Thu, 28 Jul 2022 17:48:05 +0100 Subject: [PATCH 3/3] Update airflow/models/taskinstance.py Co-authored-by: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> --- airflow/models/taskinstance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index b64d1a5f47887..d8ed3d8b8d24e 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -853,7 +853,7 @@ def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool session.refresh(self, TaskInstance.__mapper__.column_attrs.keys()) qry = ( - # To avoid joining any relationships by default select the all + # To avoid joining any relationships, by default select all # columns, not the object. This also means we get (effectively) a # namedtuple back, not a TI object session.query(*TaskInstance.__table__.columns).filter(