Skip to content

Commit

Permalink
Remove select_column option in TaskInstance.get_task_instance (apache…
Browse files Browse the repository at this point in the history
…#38571)

Fundamentally what's going on here is we need a TaskInstance object instead of a Row object when sending over the wire in RPC call.  But the full story on this one is actually somewhat complicated.
It was back in 2.2.0 in apache#25312 when we converted to query with the column attrs instead of the TI object (apache#28900 only refactored this logic into a function).  The reason was to avoid locking the dag_run table since TI newly had a dag_run relationship attr.  Now, this causes a problem with AIP-44 because the RPC api does not know how to serialize a Row object.
This PR switches back to querying a TaskInstance object, but avoids locking dag_run by using lazy_load option.  Meanwhile, since try_number is a horrible attribute (which gives you a different answer depending on the state), we have to switch it back to look at the underlying private attr instead of the public accesor.
  • Loading branch information
dstandish authored and idantepper@gmail.com committed Apr 3, 2024
1 parent 7092ea1 commit 4c471dc
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 13 deletions.
24 changes: 11 additions & 13 deletions airflow/models/taskinstance.py
Expand Up @@ -61,7 +61,7 @@
)
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.ext.mutable import MutableDict
from sqlalchemy.orm import reconstructor, relationship
from sqlalchemy.orm import lazyload, reconstructor, relationship
from sqlalchemy.orm.attributes import NO_VALUE, set_committed_value
from sqlalchemy.sql.expression import case, select

Expand Down Expand Up @@ -523,7 +523,6 @@ def _refresh_from_db(
task_id=task_instance.task_id,
run_id=task_instance.run_id,
map_index=task_instance.map_index,
select_columns=True,
lock_for_update=lock_for_update,
session=session,
)
Expand All @@ -534,8 +533,7 @@ def _refresh_from_db(
task_instance.end_date = ti.end_date
task_instance.duration = ti.duration
task_instance.state = ti.state
# Since we selected columns, not the object, this is the raw value
task_instance.try_number = ti.try_number
task_instance.try_number = ti._try_number # private attr to get value unaltered by accessor
task_instance.max_tries = ti.max_tries
task_instance.hostname = ti.hostname
task_instance.unixname = ti.unixname
Expand Down Expand Up @@ -914,7 +912,7 @@ def _get_try_number(*, task_instance: TaskInstance | TaskInstancePydantic):
:meta private:
"""
if task_instance.state == TaskInstanceState.RUNNING.RUNNING:
if task_instance.state == TaskInstanceState.RUNNING:
return task_instance._try_number
return task_instance._try_number + 1

Expand Down Expand Up @@ -1798,18 +1796,18 @@ def get_task_instance(
run_id: str,
task_id: str,
map_index: int,
select_columns: bool = False,
lock_for_update: bool = False,
session: Session = NEW_SESSION,
) -> TaskInstance | TaskInstancePydantic | None:
query = (
session.query(*TaskInstance.__table__.columns) if select_columns else session.query(TaskInstance)
)
query = query.filter_by(
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
session.query(TaskInstance)
.options(lazyload("dag_run")) # lazy load dag run to avoid locking it
.filter_by(
dag_id=dag_id,
run_id=run_id,
task_id=task_id,
map_index=map_index,
)
)

if lock_for_update:
Expand Down
13 changes: 13 additions & 0 deletions tests/models/test_taskinstance.py
Expand Up @@ -4562,3 +4562,16 @@ def test_taskinstance_with_note(create_task_instance, session):

assert session.query(TaskInstance).filter_by(**filter_kwargs).one_or_none() is None
assert session.query(TaskInstanceNote).filter_by(**filter_kwargs).one_or_none() is None


def test__refresh_from_db_should_not_increment_try_number(dag_maker, session):
with dag_maker():
BashOperator(task_id="hello", bash_command="hi")
dag_maker.create_dagrun(state="success")
ti = session.scalar(select(TaskInstance))
assert ti.task_id == "hello" # just to confirm...
assert ti.try_number == 1 # starts out as 1
ti.refresh_from_db()
assert ti.try_number == 1 # stays 1
ti.refresh_from_db()
assert ti.try_number == 1 # stays 1

0 comments on commit 4c471dc

Please sign in to comment.