Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Promote Operator.output more #25617

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 12 additions & 10 deletions docs/apache-airflow/concepts/dynamic-task-mapping.rst
Original file line number Diff line number Diff line change
Expand Up @@ -180,17 +180,19 @@ It is possible to use ``partial`` and ``expand`` with classic style operators as
Mapping over result of classic operators
----------------------------------------

If you want to map over the result of a classic operator you will need to create an ``XComArg`` object manually.
If you want to map over the result of a classic operator, you should explicitly reference the *output*, instead of the operator itself.

.. code-block:: python

from airflow import XComArg
# Create a list of data inputs.
extract = ExtractOperator(task_id="extract")

task = MyOperator(task_id="source")
# Expand the operator to transform each input.
transform = TransformOperator.partial(task_id="transform").expand(input=extract.output)

downstream = MyOperator2.partial(task_id="consumer").expand(input=XComArg(task))
# Collect the transformed inputs, expand the operator to load each one of them to the target.
load = LoadOperator.partial(task_id="load").expand(input=transform.output)

.. note:: Only a return value can be mapped against. Therefore, it is not allowed to set the ``XComArg``'s ``key`` property manually.

Mixing TaskFlow and classic operators
=====================================
Expand All @@ -201,7 +203,7 @@ In this example you have a regular data delivery to an S3 bucket and want to app

from datetime import datetime

from airflow import DAG, XComArg
from airflow import DAG
from airflow.decorators import task
from airflow.providers.amazon.aws.hooks.s3 import S3Hook
from airflow.providers.amazon.aws.operators.s3 import S3ListOperator
Expand All @@ -225,7 +227,7 @@ In this example you have a regular data delivery to an S3 bucket and want to app
return sum(lines)

counts = count_lines.partial(aws_conn_id="aws_default", bucket=list_filenames.bucket).expand(
filename=XComArg(list_filenames)
filename=list_filenames.output
)

total(lines=counts)
Expand Down Expand Up @@ -266,7 +268,7 @@ Similar to ``expand``, you can also map against a XCom that returns a list of di
}


copy_kwargs = create_copy_kwargs.expand(filename=XComArg(list_filenames))
copy_kwargs = create_copy_kwargs.expand(filename=list_filenames.output)

# Copy files to another bucket, based on the file's extension.
copy_filenames = S3CopyObjectOperator.partial(
Expand Down Expand Up @@ -318,7 +320,7 @@ Since it is common to want to transform the output data format for task mapping,
}


copy_kwargs = XComArg(list_filenames).map(create_copy_kwargs)
copy_kwargs = list_filenames.output.map(create_copy_kwargs)

# Unchanged.
copy_filenames = S3CopyObjectOperator.partial(...).expand_kwargs(copy_kwargs)
Expand All @@ -345,7 +347,7 @@ This is especially useful for conditional logic in task mapping. For example, if
)
list_filenames_b = ["rename_1", "rename_2", "rename_3", ...]

filenames_a_b = XComArg(list_filenames_a).zip(list_filenames_b)
filenames_a_b = list_filenames_a.output.zip(list_filenames_b)


@task
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from airflow.models.baseoperator import BaseOperator
from airflow.models.dagbag import DagBag
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.security import permissions
from airflow.utils.platform import getuser
from airflow.utils.session import provide_session
Expand Down Expand Up @@ -91,7 +90,7 @@ def create_dag_runs_with_mapped_tasks(self, dag_maker, session, dags={}):
count = dags[dag_id]['success'] + dags[dag_id]['running']
with dag_maker(session=session, dag_id=dag_id, start_date=DEFAULT_DATETIME_1):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)

dr = dag_maker.create_dagrun(run_id=f"run_{dag_id}")

Expand Down
3 changes: 1 addition & 2 deletions tests/decorators/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,8 +735,7 @@ def fn(arg1, arg2):

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcom_arg = XComArg(task1)
mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=xcom_arg)
mapped = fn.partial(arg2='{{ ti.task_id }}').expand(arg1=task1.output)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down
7 changes: 3 additions & 4 deletions tests/models/test_dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from airflow.models import DAG, DagBag, DagModel, DagRun, TaskInstance as TI, clear_task_instances
from airflow.models.baseoperator import BaseOperator
from airflow.models.taskmap import TaskMap
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import ShortCircuitOperator
from airflow.serialization.serialized_objects import SerializedDAG
Expand Down Expand Up @@ -1036,7 +1035,7 @@ def task_2(arg2):
dag._remove_task('task_2')

with dag:
mapped = task_2.expand(arg2=XComArg(t1)).operator
mapped = task_2.expand(arg2=t1.output).operator

# At this point, we need to test that the change works on the serialized
# DAG (which is what the scheduler operates on)
Expand Down Expand Up @@ -1667,7 +1666,7 @@ def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
literal = [1, 2, 3, 4]
with dag_maker(session=session):
task = BaseOperator(task_id='task_1')
mapped = MockOperator.partial(task_id='task_2').expand(arg1=literal, arg2=XComArg(task))
mapped = MockOperator.partial(task_id='task_2').expand(arg1=literal, arg2=task.output)

dr = dag_maker.create_dagrun()
query = (
Expand All @@ -1686,7 +1685,7 @@ def test_mapped_mixed__literal_not_expanded_at_create(dag_maker, session):
def test_ti_scheduling_mapped_zero_length(dag_maker, session):
with dag_maker(session=session):
task = BaseOperator(task_id='task_1')
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task.output)

dr: DagRun = dag_maker.create_dagrun()
ti1, ti2 = sorted(dr.task_instances, key=lambda ti: ti.task_id)
Expand Down
20 changes: 9 additions & 11 deletions tests/models/test_mappedoperator.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def test_map_xcom_arg():
"""Test that dependencies are correct when mapping with an XComArg"""
with DAG("test-dag", start_date=DEFAULT_DATE):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)
finish = MockOperator(task_id="finish")

mapped >> finish
Expand All @@ -116,8 +116,8 @@ def execute(self, context):
with dag_maker("test-dag", session=session, start_date=DEFAULT_DATE) as dag:
upstream_return = [1, 2, 3]
task1 = PushExtraXComOperator(return_value=upstream_return, task_id="task_1")
task2 = PushExtraXComOperator.partial(task_id='task_2').expand(return_value=XComArg(task1))
task3 = PushExtraXComOperator.partial(task_id='task_3').expand(return_value=XComArg(task2))
task2 = PushExtraXComOperator.partial(task_id='task_2').expand(return_value=task1.output)
task3 = PushExtraXComOperator.partial(task_id='task_3').expand(return_value=task2.output)

dr = dag_maker.create_dagrun()
ti_1 = dr.get_task_instance("task_1", session)
Expand Down Expand Up @@ -184,7 +184,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
literal = [1, 2, {'a': 'b'}]
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)

dr = dag_maker.create_dagrun()

Expand Down Expand Up @@ -228,7 +228,7 @@ def test_expand_mapped_task_instance(dag_maker, session, num_existing_tis, expec
def test_expand_mapped_task_instance_skipped_on_zero(dag_maker, session):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand(arg2=XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand(arg2=task1.output)

dr = dag_maker.create_dagrun()

Expand Down Expand Up @@ -281,10 +281,8 @@ def __init__(self, value, arg1, **kwargs):

with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
xcom_arg = XComArg(task1)
mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand(
value=xcom_arg, arg1=xcom_arg
)
output1 = task1.output
mapped = MyOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand(value=output1, arg1=output1)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down Expand Up @@ -357,7 +355,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
literal = [{"arg1": "a"}, {"arg1": "b"}, {"arg1": "c"}]
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='task_2').expand_kwargs(XComArg(task1))
mapped = MockOperator.partial(task_id='task_2').expand_kwargs(task1.output)

dr = dag_maker.create_dagrun()

Expand Down Expand Up @@ -408,7 +406,7 @@ def test_expand_kwargs_mapped_task_instance(dag_maker, session, num_existing_tis
def test_expand_kwargs_render_template_fields_validating_operator(dag_maker, session, map_index, expected):
with dag_maker(session=session):
task1 = BaseOperator(task_id="op1")
mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(XComArg(task1))
mapped = MockOperator.partial(task_id='a', arg2='{{ ti.task_id }}').expand_kwargs(task1.output)

dr = dag_maker.create_dagrun()
ti: TaskInstance = dr.get_task_instance(task1.task_id, session=session)
Expand Down