Skip to content

Commit

Permalink
Clean-ups around task-mapping code (#26879)
Browse files Browse the repository at this point in the history
(cherry picked from commit a2d8724)
  • Loading branch information
uranusjr authored and ephraimbuddy committed Oct 18, 2022
1 parent 0007186 commit 05408cb
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 34 deletions.
54 changes: 22 additions & 32 deletions airflow/models/dagrun.py
Expand Up @@ -759,7 +759,7 @@ def _get_ready_tis(
expansion_happened = True
if schedulable.state in SCHEDULEABLE_STATES:
task = schedulable.task
if isinstance(schedulable.task, MappedOperator):
if isinstance(task, MappedOperator):
# Ensure the task indexes are complete
created = self._revise_mapped_task_indexes(task, session=session)
ready_tis.extend(created)
Expand Down Expand Up @@ -872,8 +872,6 @@ def verify_integrity(
hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, 'is_noop', False)

dag = self.get_dag()
task_ids: set[str] = set()

task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)
Expand Down Expand Up @@ -951,8 +949,8 @@ def _check_for_removed_or_restored_tasks(
ti.state = State.REMOVED
else:
# What if it is _now_ dynamically mapped, but wasn't before?
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.get_mapped_ti_count(self.run_id, session=session)

if total_length is None:
# Not all upstreams finished, so we can't tell what should be here. Remove everything.
Expand Down Expand Up @@ -1045,19 +1043,13 @@ def _create_tasks(
:param session: the session to use
"""

def expand_mapped_literals(
task: Operator, sequence: Sequence[int] | None = None
) -> tuple[Operator, Sequence[int]]:
def expand_mapped_literals(task: Operator) -> tuple[Operator, Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
task = cast("MappedOperator", task)
count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count(
self.run_id, session=session
)
count = task.get_mapped_ti_count(self.run_id, session=session)
if not count:
return (task, (-1,))
if sequence:
return (task, sequence)
return (task, range(count))

tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
Expand Down Expand Up @@ -1110,21 +1102,19 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _revise_mapped_task_indexes(self, task, session: Session):
def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]:
"""Check if task increased or reduced in length and handle appropriately"""
from airflow.models.taskinstance import TaskInstance
from airflow.settings import task_instance_mutation_hook

task.run_time_mapped_ti_count.cache_clear()
total_length = (
task.parse_time_mapped_ti_count
or task.run_time_mapped_ti_count(self.run_id, session=session)
or 0
)
query = session.query(TaskInstance.map_index).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.get_mapped_ti_count(self.run_id, session=session)
if total_length is None: # Upstreams not ready, don't need to revise this yet.
return []

query = session.query(TI.map_index).filter(
TI.dag_id == self.dag_id,
TI.task_id == task.task_id,
TI.run_id == self.run_id,
)
existing_indexes = {i for (i,) in query}
missing_indexes = set(range(total_length)).difference(existing_indexes)
Expand All @@ -1133,20 +1123,20 @@ def _revise_mapped_task_indexes(self, task, session: Session):

if missing_indexes:
for index in missing_indexes:
ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
ti = TI(task, run_id=self.run_id, map_index=index, state=None)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
ti.refresh_from_task(task)
session.flush()
created_tis.append(ti)
elif removed_indexes:
session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
TaskInstance.map_index.in_(removed_indexes),
).update({TaskInstance.state: TaskInstanceState.REMOVED})
session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == task.task_id,
TI.run_id == self.run_id,
TI.map_index.in_(removed_indexes),
).update({TI.state: TaskInstanceState.REMOVED})
session.flush()
return created_tis

Expand Down
10 changes: 9 additions & 1 deletion airflow/models/mappedoperator.py
Expand Up @@ -727,15 +727,23 @@ def iter_mapped_dependencies(self) -> Iterator[Operator]:
def parse_time_mapped_ti_count(self) -> int | None:
"""Number of mapped TaskInstances that can be created at DagRun create time.
This only considers literal mapped arguments, and would return *None*
when any non-literal values are used for mapping.
:return: None if non-literal mapped arg encountered, or the total
number of mapped TIs this task should have.
"""
return self._get_specified_expand_input().get_parse_time_mapped_ti_count()

@cache
def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
"""Number of mapped TaskInstances that can be created at run time.
This considers both literal and non-literal mapped arguments, and the
result is therefore available when all depended tasks have finished. The
return value should be identical to ``parse_time_mapped_ti_count`` if
all mapped arguments are literal.
:return: None if upstream tasks are not complete yet, or the total
number of mapped TIs this task should have.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/mapping.py
Expand Up @@ -42,4 +42,4 @@ def expand_mapped_task(
session.flush()

mapped.expand_mapped_task(run_id, session=session)
mapped.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
mapped.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]

0 comments on commit 05408cb

Please sign in to comment.