Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Clean-ups around task-mapping code #26879

Merged
merged 1 commit into from Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
54 changes: 22 additions & 32 deletions airflow/models/dagrun.py
Expand Up @@ -760,7 +760,7 @@ def _get_ready_tis(
expansion_happened = True
if schedulable.state in SCHEDULEABLE_STATES:
task = schedulable.task
if isinstance(schedulable.task, MappedOperator):
if isinstance(task, MappedOperator):
# Ensure the task indexes are complete
created = self._revise_mapped_task_indexes(task, session=session)
ready_tis.extend(created)
Expand Down Expand Up @@ -873,8 +873,6 @@ def verify_integrity(
hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, 'is_noop', False)

dag = self.get_dag()
task_ids: set[str] = set()

task_ids = self._check_for_removed_or_restored_tasks(
dag, task_instance_mutation_hook, session=session
)
Expand Down Expand Up @@ -952,8 +950,8 @@ def _check_for_removed_or_restored_tasks(
ti.state = State.REMOVED
else:
# What if it is _now_ dynamically mapped, but wasn't before?
task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.run_time_mapped_ti_count(self.run_id, session=session)
task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.get_mapped_ti_count(self.run_id, session=session)

if total_length is None:
# Not all upstreams finished, so we can't tell what should be here. Remove everything.
Expand Down Expand Up @@ -1046,19 +1044,13 @@ def _create_tasks(
:param session: the session to use
"""

def expand_mapped_literals(
task: Operator, sequence: Sequence[int] | None = None
) -> tuple[Operator, Sequence[int]]:
def expand_mapped_literals(task: Operator) -> tuple[Operator, Sequence[int]]:
if not task.is_mapped:
return (task, (-1,))
task = cast("MappedOperator", task)
count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count(
self.run_id, session=session
)
count = task.get_mapped_ti_count(self.run_id, session=session)
if not count:
return (task, (-1,))
if sequence:
return (task, sequence)
return (task, range(count))

tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values()))
Expand Down Expand Up @@ -1111,21 +1103,19 @@ def _create_task_instances(
# TODO[HA]: We probably need to savepoint this so we can keep the transaction alive.
session.rollback()

def _revise_mapped_task_indexes(self, task, session: Session):
def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]:
"""Check if task increased or reduced in length and handle appropriately"""
from airflow.models.taskinstance import TaskInstance
from airflow.settings import task_instance_mutation_hook

task.run_time_mapped_ti_count.cache_clear()
total_length = (
task.parse_time_mapped_ti_count
or task.run_time_mapped_ti_count(self.run_id, session=session)
or 0
)
query = session.query(TaskInstance.map_index).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
total_length = task.get_mapped_ti_count(self.run_id, session=session)
if total_length is None: # Upstreams not ready, don't need to revise this yet.
return []

query = session.query(TI.map_index).filter(
TI.dag_id == self.dag_id,
TI.task_id == task.task_id,
TI.run_id == self.run_id,
)
existing_indexes = {i for (i,) in query}
missing_indexes = set(range(total_length)).difference(existing_indexes)
Expand All @@ -1134,20 +1124,20 @@ def _revise_mapped_task_indexes(self, task, session: Session):

if missing_indexes:
for index in missing_indexes:
ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None)
ti = TI(task, run_id=self.run_id, map_index=index, state=None)
self.log.debug("Expanding TIs upserted %s", ti)
task_instance_mutation_hook(ti)
ti = session.merge(ti)
ti.refresh_from_task(task)
session.flush()
created_tis.append(ti)
elif removed_indexes:
session.query(TaskInstance).filter(
TaskInstance.dag_id == self.dag_id,
TaskInstance.task_id == task.task_id,
TaskInstance.run_id == self.run_id,
TaskInstance.map_index.in_(removed_indexes),
).update({TaskInstance.state: TaskInstanceState.REMOVED})
session.query(TI).filter(
TI.dag_id == self.dag_id,
TI.task_id == task.task_id,
TI.run_id == self.run_id,
TI.map_index.in_(removed_indexes),
).update({TI.state: TaskInstanceState.REMOVED})
session.flush()
return created_tis

Expand Down
10 changes: 9 additions & 1 deletion airflow/models/mappedoperator.py
Expand Up @@ -728,15 +728,23 @@ def iter_mapped_dependencies(self) -> Iterator[Operator]:
def parse_time_mapped_ti_count(self) -> int | None:
"""Number of mapped TaskInstances that can be created at DagRun create time.

This only considers literal mapped arguments, and would return *None*
when any non-literal values are used for mapping.

:return: None if non-literal mapped arg encountered, or the total
number of mapped TIs this task should have.
"""
return self._get_specified_expand_input().get_parse_time_mapped_ti_count()

@cache
def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None:
"""Number of mapped TaskInstances that can be created at run time.

This considers both literal and non-literal mapped arguments, and the
result is therefore available when all depended tasks have finished. The
return value should be identical to ``parse_time_mapped_ti_count`` if
all mapped arguments are literal.

:return: None if upstream tasks are not complete yet, or the total
number of mapped TIs this task should have.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/test_utils/mapping.py
Expand Up @@ -42,4 +42,4 @@ def expand_mapped_task(
session.flush()

mapped.expand_mapped_task(run_id, session=session)
mapped.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined]
mapped.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]