From 38707b407d1ea0b511e325bfca6c1484bc9db93b Mon Sep 17 00:00:00 2001 From: Tzu-ping Chung Date: Wed, 5 Oct 2022 16:18:04 +0800 Subject: [PATCH] Clean-ups around task-mapping code There should be no change in functionality. A few things are involved: 1. The expand_mapped_literals closure is declaring a 'sequence' argument that is always None. Remove it. 2. The run_time_mapped_ti_count method is never used in isolation, but combined with parse_time_mapped_ti_count. We should just combine the calls -- actually, the run-time variant already encompasses the parse-time variant, so the latter call can simply be removed. 3. The TaskInstance import in _revise_mapped_task_indexes is redundant (the class is already imported globally) and is removed. 4. Various typing fixups. --- airflow/models/dagrun.py | 54 +++++++++++++------------------- airflow/models/mappedoperator.py | 10 +++++- tests/test_utils/mapping.py | 2 +- 3 files changed, 32 insertions(+), 34 deletions(-) diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py index ac98b1a32f49c..d6adde5f7b0aa 100644 --- a/airflow/models/dagrun.py +++ b/airflow/models/dagrun.py @@ -760,7 +760,7 @@ def _get_ready_tis( expansion_happened = True if schedulable.state in SCHEDULEABLE_STATES: task = schedulable.task - if isinstance(schedulable.task, MappedOperator): + if isinstance(task, MappedOperator): # Ensure the task indexes are complete created = self._revise_mapped_task_indexes(task, session=session) ready_tis.extend(created) @@ -873,8 +873,6 @@ def verify_integrity( hook_is_noop: Literal[True, False] = getattr(task_instance_mutation_hook, 'is_noop', False) dag = self.get_dag() - task_ids: set[str] = set() - task_ids = self._check_for_removed_or_restored_tasks( dag, task_instance_mutation_hook, session=session ) @@ -952,8 +950,8 @@ def _check_for_removed_or_restored_tasks( ti.state = State.REMOVED else: # What if it is _now_ dynamically mapped, but wasn't before? - task.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined] - total_length = task.run_time_mapped_ti_count(self.run_id, session=session) + task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined] + total_length = task.get_mapped_ti_count(self.run_id, session=session) if total_length is None: # Not all upstreams finished, so we can't tell what should be here. Remove everything. @@ -1046,19 +1044,13 @@ def _create_tasks( :param session: the session to use """ - def expand_mapped_literals( - task: Operator, sequence: Sequence[int] | None = None - ) -> tuple[Operator, Sequence[int]]: + def expand_mapped_literals(task: Operator) -> tuple[Operator, Sequence[int]]: if not task.is_mapped: return (task, (-1,)) task = cast("MappedOperator", task) - count = task.parse_time_mapped_ti_count or task.run_time_mapped_ti_count( - self.run_id, session=session - ) + count = task.get_mapped_ti_count(self.run_id, session=session) if not count: return (task, (-1,)) - if sequence: - return (task, sequence) return (task, range(count)) tasks_and_map_idxs = map(expand_mapped_literals, filter(task_filter, dag.task_dict.values())) @@ -1111,21 +1103,19 @@ def _create_task_instances( # TODO[HA]: We probably need to savepoint this so we can keep the transaction alive. session.rollback() - def _revise_mapped_task_indexes(self, task, session: Session): + def _revise_mapped_task_indexes(self, task: MappedOperator, session: Session) -> Iterable[TI]: """Check if task increased or reduced in length and handle appropriately""" - from airflow.models.taskinstance import TaskInstance from airflow.settings import task_instance_mutation_hook - task.run_time_mapped_ti_count.cache_clear() - total_length = ( - task.parse_time_mapped_ti_count - or task.run_time_mapped_ti_count(self.run_id, session=session) - or 0 - ) - query = session.query(TaskInstance.map_index).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.run_id == self.run_id, + task.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined] + total_length = task.get_mapped_ti_count(self.run_id, session=session) + if total_length is None: # Upstreams not ready, don't need to revise this yet. + return [] + + query = session.query(TI.map_index).filter( + TI.dag_id == self.dag_id, + TI.task_id == task.task_id, + TI.run_id == self.run_id, ) existing_indexes = {i for (i,) in query} missing_indexes = set(range(total_length)).difference(existing_indexes) @@ -1134,7 +1124,7 @@ def _revise_mapped_task_indexes(self, task, session: Session): if missing_indexes: for index in missing_indexes: - ti = TaskInstance(task, run_id=self.run_id, map_index=index, state=None) + ti = TI(task, run_id=self.run_id, map_index=index, state=None) self.log.debug("Expanding TIs upserted %s", ti) task_instance_mutation_hook(ti) ti = session.merge(ti) @@ -1142,12 +1132,12 @@ def _revise_mapped_task_indexes(self, task, session: Session): session.flush() created_tis.append(ti) elif removed_indexes: - session.query(TaskInstance).filter( - TaskInstance.dag_id == self.dag_id, - TaskInstance.task_id == task.task_id, - TaskInstance.run_id == self.run_id, - TaskInstance.map_index.in_(removed_indexes), - ).update({TaskInstance.state: TaskInstanceState.REMOVED}) + session.query(TI).filter( + TI.dag_id == self.dag_id, + TI.task_id == task.task_id, + TI.run_id == self.run_id, + TI.map_index.in_(removed_indexes), + ).update({TI.state: TaskInstanceState.REMOVED}) session.flush() return created_tis diff --git a/airflow/models/mappedoperator.py b/airflow/models/mappedoperator.py index deba1e9d8c44d..2d5d00cf3dfe9 100644 --- a/airflow/models/mappedoperator.py +++ b/airflow/models/mappedoperator.py @@ -728,15 +728,23 @@ def iter_mapped_dependencies(self) -> Iterator[Operator]: def parse_time_mapped_ti_count(self) -> int | None: """Number of mapped TaskInstances that can be created at DagRun create time. + This only considers literal mapped arguments, and would return *None* + when any non-literal values are used for mapping. + :return: None if non-literal mapped arg encountered, or the total number of mapped TIs this task should have. """ return self._get_specified_expand_input().get_parse_time_mapped_ti_count() @cache - def run_time_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None: + def get_mapped_ti_count(self, run_id: str, *, session: Session) -> int | None: """Number of mapped TaskInstances that can be created at run time. + This considers both literal and non-literal mapped arguments, and the + result is therefore available when all depended tasks have finished. The + return value should be identical to ``parse_time_mapped_ti_count`` if + all mapped arguments are literal. + :return: None if upstream tasks are not complete yet, or the total number of mapped TIs this task should have. """ diff --git a/tests/test_utils/mapping.py b/tests/test_utils/mapping.py index 5cfa230369167..984446343cb89 100644 --- a/tests/test_utils/mapping.py +++ b/tests/test_utils/mapping.py @@ -42,4 +42,4 @@ def expand_mapped_task( session.flush() mapped.expand_mapped_task(run_id, session=session) - mapped.run_time_mapped_ti_count.cache_clear() # type: ignore[attr-defined] + mapped.get_mapped_ti_count.cache_clear() # type: ignore[attr-defined]