From 4c471dcd853d4a97a92e9b602f70a2d149ea9aca Mon Sep 17 00:00:00 2001 From: Daniel Standish <15932138+dstandish@users.noreply.github.com> Date: Tue, 2 Apr 2024 09:16:58 -0700 Subject: [PATCH] Remove select_column option in TaskInstance.get_task_instance (#38571) Fundamentally what's going on here is we need a TaskInstance object instead of a Row object when sending over the wire in an RPC call. But the full story on this one is actually somewhat complicated. It was back in 2.2.0 in #25312 when we converted to query with the column attrs instead of the TI object (#28900 only refactored this logic into a function). The reason was to avoid locking the dag_run table since TI newly had a dag_run relationship attr. Now, this causes a problem with AIP-44 because the RPC API does not know how to serialize a Row object. This PR switches back to querying a TaskInstance object, but avoids locking dag_run by using the `lazyload` option. Meanwhile, since try_number is a horrible attribute (which gives you a different answer depending on the state), we have to switch it back to look at the underlying private attr instead of the public accessor. 
--- airflow/models/taskinstance.py | 24 +++++++++++------------- tests/models/test_taskinstance.py | 13 +++++++++++++ 2 files changed, 24 insertions(+), 13 deletions(-) diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py index 2107781041018..14fc0fc8f7ffb 100644 --- a/airflow/models/taskinstance.py +++ b/airflow/models/taskinstance.py @@ -61,7 +61,7 @@ ) from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.mutable import MutableDict -from sqlalchemy.orm import reconstructor, relationship +from sqlalchemy.orm import lazyload, reconstructor, relationship from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value from sqlalchemy.sql.expression import case, select @@ -523,7 +523,6 @@ def _refresh_from_db( task_id=task_instance.task_id, run_id=task_instance.run_id, map_index=task_instance.map_index, - select_columns=True, lock_for_update=lock_for_update, session=session, ) @@ -534,8 +533,7 @@ def _refresh_from_db( task_instance.end_date = ti.end_date task_instance.duration = ti.duration task_instance.state = ti.state - # Since we selected columns, not the object, this is the raw value - task_instance.try_number = ti.try_number + task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor task_instance.max_tries = ti.max_tries task_instance.hostname = ti.hostname task_instance.unixname = ti.unixname @@ -914,7 +912,7 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic): :meta private: """ - if task_instance.state == TaskInstanceState.RUNNING.RUNNING: + if task_instance.state == TaskInstanceState.RUNNING: return task_instance._try_number return task_instance._try_number + 1 @@ -1798,18 +1796,18 @@ def get_task_instance( run_id: str, task_id: str, map_index: int, - select_columns: bool = False, lock_for_update: bool = False, session: Session = NEW_SESSION, ) -> TaskInstance | TaskInstancePydantic | None: query = ( - 
session.query(*TaskInstance.__table__.columns) if select_columns else session.query(TaskInstance) - ) - query = query.filter_by( - dag_id=dag_id, - run_id=run_id, - task_id=task_id, - map_index=map_index, + session.query(TaskInstance) + .options(lazyload("dag_run")) # lazy load dag run to avoid locking it + .filter_by( + dag_id=dag_id, + run_id=run_id, + task_id=task_id, + map_index=map_index, + ) ) if lock_for_update: diff --git a/tests/models/test_taskinstance.py b/tests/models/test_taskinstance.py index 8dacc839cb8a3..46654d564d422 100644 --- a/tests/models/test_taskinstance.py +++ b/tests/models/test_taskinstance.py @@ -4562,3 +4562,16 @@ def test_taskinstance_with_note(create_task_instance, session): assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None + + +def test__refresh_from_db_should_not_increment_try_number(dag_maker, session): + with dag_maker(): + BashOperator(task_id="hello", bash_command="hi") + dag_maker.create_dagrun(state="success") + ti = session.scalar(select(TaskInstance)) + assert ti.task_id == "hello" # just to confirm... + assert ti.try_number == 1 # starts out as 1 + ti.refresh_from_db() + assert ti.try_number == 1 # stays 1 + ti.refresh_from_db() + assert ti.try_number == 1 # stays 1