Skip to content

Commit

Permalink
Promote Operator.output more (#25617)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr committed Sep 1, 2022
1 parent 9b2a859 commit c4d0581
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 29 deletions.
22 changes: 12 additions & 10 deletions docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,19 @@ It is possible to use ``partial`` and ``expand`` with classic style operators as
Mapping over result of classic operators
----------------------------------------

If you want to map over the result of a classic operator you will need to create an ``XComArg`` object manually.
If you want to map over the result of a classic operator, you should explicitly reference the *output*, instead of the operator itself.

.. code-block:: python
from airflow import XComArg
# Create a list of data inputs.
extract = ExtractOperator(task_id="extract")
task = MyOperator(task_id="source")
# Expand the operator to transform each input.
transform = TransformOperator.partial(task_id="transform").expand(input=extract.output)
downstream = MyOperator2.partial(task_id="consumer").expand(input=XComArg(task))
# Collect the transformed inputs, expand the operator to load each one of them to the target.
load = LoadOperator.partial(task_id="load").expand(input=transform.output)
.. note:: Only a return value can be mapped against. Therefore, it is not allowed to set the ``XComArg``'s ``key`` property manually.
Mixing TaskFlow and classic operators
=====================================
Expand All @@ -201,7 +203,7 @@ In this example you have a regular data delivery to an S3 bucket and want to app
from datetime import datetime
from airflow import DAG, XComArg
from airflow import DAG
from airflow.decorators import task
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.s3 import S3ListOperator
Expand All @@ -225,7 +227,7 @@ In this example you have a regular data delivery to an S3 bucket and want to app
return sum(lines)
counts = count_lines.partial(aws_conn_id="aws_default", bucket=list_filenames.bucket).expand(
filename=XComArg(list_filenames)
filename=list_filenames.output
)
total(lines=counts)
Expand Down Expand Up @@ -266,7 +268,7 @@ Similar to ``expand``, you can also map against a XCom that returns a list of di
}
copy_kwargs = create_copy_kwargs.expand(filename=XComArg(list_filenames))
copy_kwargs = create_copy_kwargs.expand(filename=list_filenames.output)
# Copy files to another bucket, based on the file's extension.
copy_filenames = S3CopyObjectOperator.partial(
Expand Down Expand Up @@ -318,7 +320,7 @@ Since it is common to want to transform the output data format for task mapping,
}
copy_kwargs = XComArg(list_filenames).map(create_copy_kwargs)
copy_kwargs = list_filenames.output.map(create_copy_kwargs)
# Unchanged.
copy_filenames = S3CopyObjectOperator.partial(...).expand_kwargs(copy_kwargs)
Expand All @@ -345,7 +347,7 @@ This is especially useful for conditional logic in task mapping. For example, if
)
list_filenames_b = ["rename_1", "rename_2", "rename_3", ...]
filenames_a_b = XComArg(list_filenames_a).zip(list_filenames_b)
filenames_a_b = list_filenames_a.output.zip(list_filenames_b)
@task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.security import permissions
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -91,7 +90,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags={}):
count = dags[dag_id]['success'] + dags[dag_id]['running']
with dag_maker(session=session, dag_id=dag_id, start_date=DEFAULT_DATETIME_1):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)

dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}")

Expand Down
3 changes: 1 addition & 2 deletions tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,7 @@ def fn(arg1, arg2):

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcom_arg = XComArg(task1)
mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=xcom_arg)
mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=task1.output)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down
7 changes: 3 additions & 4 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, clear_task_instances
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.serialization.serialized_objects import SerializedDAG
Expand Down Expand Up @@ -1036,7 +1035,7 @@ def task_2(arg2):
dag._remove_task('task_2')

with dag:
mapped = task_2.expand(arg2=XComArg(t1)).operator
mapped = task_2.expand(arg2=t1.output).operator

# At this point, we need to test that the change works on the serialized
# DAG (which is what the scheduler operates on)
Expand Down Expand Up @@ -1667,7 +1666,7 @@ def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
literal = [1, 2, 3, 4]
with dag_maker(session=session):
task = BaseOperator(task_id='task_1')
mapped = MockOperator.partial(task_id='task_2').expand(arg1=literal, arg2=XComArg(task))
mapped = MockOperator.partial(task_id='task_2').expand(arg1=literal, arg2=task.output)

dr = dag_maker.create_dagrun()
query = (
Expand All @@ -1686,7 +1685,7 @@ def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
def test_ti_scheduling_mapped_zero_length(dag_maker, session):
with dag_maker(session=session):
task = BaseOperator(task_id='task_1')
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task.output)

dr: DagRun = dag_maker.create_dagrun()
ti1, ti2 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
Expand Down
20 changes: 9 additions & 11 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_map_xcom_arg():
"""Test that dependencies are correct when mapping with an XComArg"""
with DAG("test-dag", start_date=DEFAULT_DATE):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)
finish = MockOperator(task_id="finish")

mapped >> finish
Expand All @@ -116,8 +116,8 @@ def execute(self, context):
with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag:
upstream_return = [1, 2, 3]
task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1")
task2 = PushExtraXComOperator.partial(task_id='task_2').expand(return_value=XComArg(task1))
task3 = PushExtraXComOperator.partial(task_id='task_3').expand(return_value=XComArg(task2))
task2 = PushExtraXComOperator.partial(task_id='task_2').expand(return_value=task1.output)
task3 = PushExtraXComOperator.partial(task_id='task_3').expand(return_value=task2.output)

dr = dag_maker.create_dagrun()
ti_1 = dr.get_task_instance("task_1", session)
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
literal = [1, 2, {'a': 'b'}]
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)

dr = dag_maker.create_dagrun()

Expand Down Expand Up @@ -228,7 +228,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)

dr = dag_maker.create_dagrun()

Expand Down Expand Up @@ -281,10 +281,8 @@ def __init__(self, value, arg1, **kwargs):

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcom_arg = XComArg(task1)
mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand(
value=xcom_arg, arg1=xcom_arg
)
output1 = task1.output
mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand(value=output1, arg1=output1)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down Expand Up @@ -357,7 +355,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}]
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand_kwargs(task1.output)

dr = dag_maker.create_dagrun()

Expand Down Expand Up @@ -408,7 +406,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(XComArg(task1))
mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(task1.output)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down

0 comments on commit c4d0581

Please sign in to comment.