diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 656d7754561a2..7930d91fb8a64 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -287,6 +287,7 @@ def clear_task_instances( if dag_run_state == DagRunState.QUEUED: dr.last_scheduling_decision = None dr.start_date = None + session.flush() class _LazyXComAccessIterator(collections.abc.Iterator): @@ -848,28 +849,35 @@ def refresh_from_db(self, session: Session = NEW_SESSION, lock_for_update: bool """ self.log.debug("Refreshing TaskInstance %s from DB", self) - qry = session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == self.task_id, - TaskInstance.run_id == self.run_id, - TaskInstance.map_index == self.map_index, + if self in session: + session.refresh(self, TaskInstance.__mapper__.column_attrs.keys()) + + qry = ( + # To avoid joining any relationships, by default select all + # columns, not the object. This also means we get (effectively) a + # namedtuple back, not a TI object + session.query(*TaskInstance.__table__.columns).filter( + TaskInstance.dag_id == self.dag_id, + TaskInstance.task_id == self.task_id, + TaskInstance.run_id == self.run_id, + TaskInstance.map_index == self.map_index, + ) ) if lock_for_update: for attempt in run_with_db_retries(logger=self.log): with attempt: - ti: Optional[TaskInstance] = qry.with_for_update().first() + ti: Optional[TaskInstance] = qry.with_for_update().one_or_none() else: - ti = qry.first() + ti = qry.one_or_none() if ti: # Fields ordered per model definition self.start_date = ti.start_date self.end_date = ti.end_date self.duration = ti.duration self.state = ti.state - # Get the raw value of try_number column, don't read through the - # accessor here otherwise it will be incremented by one already. - self.try_number = ti._try_number + # Since we selected columns, not the object, this is the raw value + self.try_number = ti.try_number self.max_tries = ti.max_tries self.hostname = ti.hostname self.unixname = ti.unixname diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py index f43351751103d..ba055b295cfe5 100644 --- a/tests/jobs/test_scheduler_job.py +++ b/tests/jobs/test_scheduler_job.py @@ -457,6 +457,7 @@ def test_execute_task_instances_is_paused_wont_execute(self, session, dag_maker) ti1.state = State.SCHEDULED self.scheduler_job._critical_section_enqueue_task_instances(session) + session.flush() ti1.refresh_from_db(session=session) assert State.SCHEDULED == ti1.state session.rollback() @@ -1315,7 +1316,8 @@ def test_enqueue_task_instances_sets_ti_state_to_None_if_dagrun_in_finish_state( with patch.object(BaseExecutor, 'queue_command') as mock_queue_command: self.scheduler_job._enqueue_task_instances_with_queued_state([ti], session=session) - ti.refresh_from_db() + session.flush() + ti.refresh_from_db(session=session) assert ti.state == State.NONE mock_queue_command.assert_not_called()