Commit
Add more tests
ephraimbuddy committed Aug 10, 2022
1 parent 2439835 commit 6235c94
Showing 4 changed files with 299 additions and 4 deletions.
3 changes: 3 additions & 0 deletions airflow/ti_deps/deps/ready_to_reschedule.py
@@ -42,6 +42,7 @@ def _get_dep_statuses(self, ti, session, dep_context):
"""
is_mapped = ti.task.is_mapped
if not is_mapped and not getattr(ti.task, "reschedule", False):
# Mapped sensors don't currently have the reschedule property, so this check only applies to unmapped tasks.
yield self._passing_status(reason="Task is not in reschedule mode.")
return

@@ -63,6 +64,8 @@ def _get_dep_statuses(self, ti, session, dep_context):
.first()
)
if not task_reschedule:
# Mapped sensors don't have the reschedule property, so this is their last-resort check
# and it needs a slightly different passing reason.
if is_mapped:
yield self._passing_status(reason="The task is mapped and not in reschedule mode")
return
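A minimal sketch (not part of this diff) of exercising the mapped-task fallback above; it mirrors the Mock-based unit tests added to tests/ti_deps/deps/test_ready_to_reschedule_dep.py further down in this commit and assumes an initialized Airflow test environment:

from unittest.mock import Mock, patch

from airflow.models import DAG, TaskInstance
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils.state import State

# Stand-in for a mapped sensor: a Mock task flagged as mapped. Deleting the
# `reschedule` attribute mimics the fact that mapped sensors don't expose it.
task = Mock(dag=DAG('test_dag'), reschedule=True, is_mapped=True)
ti = TaskInstance(task=task, state=State.UP_FOR_RESCHEDULE, run_id=None)
del ti.task.reschedule

with patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance') as mock_query:
    # With no TaskReschedule record, the is_mapped branch above yields the
    # "task is mapped and not in reschedule mode" passing status.
    mock_query.return_value.with_entities.return_value.first.return_value = []
    assert ReadyToRescheduleDep().is_met(ti=ti)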
217 changes: 217 additions & 0 deletions tests/models/test_taskinstance.py
@@ -801,6 +801,174 @@ def run_ti_and_assert(
done, fail = True, False
run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)

def test_mapped_reschedule_handling(self, dag_maker):
"""
Test that mapped task reschedules are handled properly
"""
# Return values of the python sensor callable, modified during tests
done = False
fail = False

def func():
if fail:
raise AirflowException()
return done

with dag_maker(dag_id='test_reschedule_handling') as dag:

task = PythonSensor.partial(
task_id='test_reschedule_handling_sensor',
mode='reschedule',
python_callable=func,
retries=1,
retry_delay=datetime.timedelta(seconds=0),
).expand(poke_interval=[0])

ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]

ti.task = task
assert ti._try_number == 0
assert ti.try_number == 1

def run_ti_and_assert(
run_date,
expected_start_date,
expected_end_date,
expected_duration,
expected_state,
expected_try_number,
expected_task_reschedule_count,
):
ti.refresh_from_task(task)
with freeze_time(run_date):
try:
ti.run()
except AirflowException:
if not fail:
raise
ti.refresh_from_db()
assert ti.state == expected_state
assert ti._try_number == expected_try_number
assert ti.try_number == expected_try_number + 1
assert ti.start_date == expected_start_date
assert ti.end_date == expected_end_date
assert ti.duration == expected_duration
trs = TaskReschedule.find_for_task_instance(ti)
assert len(trs) == expected_task_reschedule_count

date1 = timezone.utcnow()
date2 = date1 + datetime.timedelta(minutes=1)
date3 = date2 + datetime.timedelta(minutes=1)
date4 = date3 + datetime.timedelta(minutes=1)

# Run with multiple reschedules.
# During reschedule the try number remains the same, but each reschedule is recorded.
# The start date is expected to remain the initial date, hence the duration increases.
# When finished the try number is incremented and there is no reschedule expected
# for this try.

done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

done, fail = False, False
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RESCHEDULE, 0, 2)

done, fail = False, False
run_ti_and_assert(date3, date1, date3, 120, State.UP_FOR_RESCHEDULE, 0, 3)

done, fail = True, False
run_ti_and_assert(date4, date1, date4, 180, State.SUCCESS, 1, 0)

# Clear the task instance.
dag.clear()
ti.refresh_from_db()
assert ti.state == State.NONE
assert ti._try_number == 1

# Run again after clearing with reschedules and a retry.
# The retry increments the try number, and for that try no reschedule is expected.
# After the retry the start date is reset, hence the duration is also reset.

done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 1, 1)

done, fail = False, True
run_ti_and_assert(date2, date1, date2, 60, State.UP_FOR_RETRY, 2, 0)

done, fail = False, False
run_ti_and_assert(date3, date3, date3, 0, State.UP_FOR_RESCHEDULE, 2, 1)

done, fail = True, False
run_ti_and_assert(date4, date3, date4, 60, State.SUCCESS, 3, 0)

@pytest.mark.usefixtures('test_pool')
def test_mapped_task_reschedule_handling_clear_reschedules(self, dag_maker):
"""
Test that clearing reschedules of a mapped task is handled properly
"""
# Return values of the python sensor callable, modified during tests
done = False
fail = False

def func():
if fail:
raise AirflowException()
return done

with dag_maker(dag_id='test_reschedule_handling') as dag:
task = PythonSensor.partial(
task_id='test_reschedule_handling_sensor',
mode='reschedule',
python_callable=func,
retries=1,
retry_delay=datetime.timedelta(seconds=0),
pool='test_pool',
).expand(poke_interval=[0])
ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]
ti.task = task
assert ti._try_number == 0
assert ti.try_number == 1

def run_ti_and_assert(
run_date,
expected_start_date,
expected_end_date,
expected_duration,
expected_state,
expected_try_number,
expected_task_reschedule_count,
):
ti.refresh_from_task(task)
with freeze_time(run_date):
try:
ti.run()
except AirflowException:
if not fail:
raise
ti.refresh_from_db()
assert ti.state == expected_state
assert ti._try_number == expected_try_number
assert ti.try_number == expected_try_number + 1
assert ti.start_date == expected_start_date
assert ti.end_date == expected_end_date
assert ti.duration == expected_duration
trs = TaskReschedule.find_for_task_instance(ti)
assert len(trs) == expected_task_reschedule_count

date1 = timezone.utcnow()

done, fail = False, False
run_ti_and_assert(date1, date1, date1, 0, State.UP_FOR_RESCHEDULE, 0, 1)

# Clear the task instance.
dag.clear()
ti.refresh_from_db()
assert ti.state == State.NONE
assert ti._try_number == 0
# Check that reschedules for ti have also been cleared.
trs = TaskReschedule.find_for_task_instance(ti)
assert not trs

@pytest.mark.usefixtures('test_pool')
def test_reschedule_handling_clear_reschedules(self, dag_maker):
"""
@@ -2541,6 +2709,55 @@ def timeout():
assert ti.state == State.FAILED


@pytest.mark.parametrize("mode", ["poke", "reschedule"])
@pytest.mark.parametrize("retries", [0, 1])
def test_mapped_sensor_timeout(mode, retries, dag_maker):
"""
Test that AirflowSensorTimeout does not cause a mapped sensor to retry.
"""

def timeout():
raise AirflowSensorTimeout

mock_on_failure = mock.MagicMock()
with dag_maker(dag_id=f'test_sensor_timeout_{mode}_{retries}'):
PythonSensor.partial(
task_id='test_raise_sensor_timeout',
python_callable=timeout,
on_failure_callback=mock_on_failure,
retries=retries,
).expand(mode=[mode])
ti = dag_maker.create_dagrun(execution_date=timezone.utcnow()).task_instances[0]

with pytest.raises(AirflowSensorTimeout):
ti.run()

assert mock_on_failure.called
assert ti.state == State.FAILED


@pytest.mark.parametrize("mode", ["poke", "reschedule"])
@pytest.mark.parametrize("retries", [0, 1])
def test_mapped_sensor_works(mode, retries, dag_maker):
"""
Test that a mapped sensor reaches the success state.
"""

def timeout(ti):
return 1

with dag_maker(dag_id=f'test_sensor_timeout_{mode}_{retries}'):
PythonSensor.partial(
task_id='test_raise_sensor_timeout',
python_callable=timeout,
retries=retries,
).expand(mode=[mode])
ti = dag_maker.create_dagrun().task_instances[0]

ti.run()
assert ti.state == State.SUCCESS


class TestTaskInstanceRecordTaskMapXComPush:
"""Test TI.xcom_push() correctly records return values for task-mapping."""

6 changes: 2 additions & 4 deletions tests/serialization/test_dag_serialization.py
@@ -1614,7 +1614,7 @@ def poke(self, context: Context):
assert op.deps == serialized_op.deps

@pytest.mark.parametrize("mode", ["poke", "reschedule"])
- def test_serialize_mapped_sensor(self, mode):
+ def test_serialize_mapped_sensor_has_reschedule_dep(self, mode):
from airflow.sensors.base import BaseSensorOperator

class DummySensor(BaseSensorOperator):
@@ -1626,9 +1626,7 @@ def poke(self, context: Context):
blob = SerializedBaseOperator.serialize_mapped_operator(op)
assert "deps" in blob

- serialized_op = SerializedBaseOperator.deserialize_operator(blob)
- assert serialized_op.reschedule == (mode == "reschedule")
- assert op.deps == serialized_op.deps
+ assert 'airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep' in blob['deps']

@pytest.mark.parametrize(
"passed_success_callback, expected_value",
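For illustration only (the blob layout is not shown in this diff), the new assertion simply checks that the dep's import path appears among the serialized deps, which are assumed here to be stored as plain strings:

# Hypothetical shape of the relevant part of a serialized mapped-sensor blob;
# only the presence of the ReadyToRescheduleDep import path is asserted above.
blob = {
    "deps": [
        "airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep",
        # ... other dep import paths elided ...
    ],
}
assert 'airflow.ti_deps.deps.ready_to_reschedule.ReadyToRescheduleDep' in blob['deps']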
77 changes: 77 additions & 0 deletions tests/ti_deps/deps/test_ready_to_reschedule_dep.py
@@ -47,6 +47,24 @@ def _get_task_reschedule(self, reschedule_date):
)
return reschedule

def _get_mapped_task_instance(self, state):
dag = DAG('test_dag')
task = Mock(dag=dag, reschedule=True, is_mapped=True)
ti = TaskInstance(task=task, state=state, run_id=None)
return ti

def _get_mapped_task_reschedule(self, reschedule_date):
task = Mock(dag_id='test_dag', task_id='test_task', is_mapped=True)
reschedule = TaskReschedule(
task=task,
run_id=None,
try_number=None,
start_date=reschedule_date,
end_date=reschedule_date,
reschedule_date=reschedule_date,
)
return reschedule

def test_should_pass_if_ignore_in_reschedule_period_is_set(self):
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
dep_context = DepContext(ignore_in_reschedule_period=True)
@@ -103,3 +121,62 @@ def test_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance):
][-1]
ti = self._get_task_instance(State.UP_FOR_RESCHEDULE)
assert not ReadyToRescheduleDep().is_met(ti=ti)

def test_mapped_task_should_pass_if_ignore_in_reschedule_period_is_set(self):
ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
dep_context = DepContext(ignore_in_reschedule_period=True)
assert ReadyToRescheduleDep().is_met(ti=ti, dep_context=dep_context)

@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_mapped_task_should_pass_if_not_reschedule_mode(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = []
ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
del ti.task.reschedule
assert ReadyToRescheduleDep().is_met(ti=ti)

def test_mapped_task_should_pass_if_not_in_none_state(self):
ti = self._get_mapped_task_instance(State.UP_FOR_RETRY)
assert ReadyToRescheduleDep().is_met(ti=ti)

@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_mapped_should_pass_if_no_reschedule_record_exists(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = []
ti = self._get_mapped_task_instance(State.NONE)
assert ReadyToRescheduleDep().is_met(ti=ti)

@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_mapped_should_pass_after_reschedule_date_one(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = (
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1))
)
ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
assert ReadyToRescheduleDep().is_met(ti=ti)

@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_mapped_task_should_pass_after_reschedule_date_multiple(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=21)),
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=11)),
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=1)),
][-1]
ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
assert ReadyToRescheduleDep().is_met(ti=ti)

@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_mapped_task_should_fail_before_reschedule_date_one(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = (
self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1))
)

ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
assert not ReadyToRescheduleDep().is_met(ti=ti)

@patch('airflow.models.taskreschedule.TaskReschedule.query_for_task_instance')
def test_mapped_task_should_fail_before_reschedule_date_multiple(self, mock_query_for_task_instance):
mock_query_for_task_instance.return_value.with_entities.return_value.first.return_value = [
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=19)),
self._get_mapped_task_reschedule(utcnow() - timedelta(minutes=9)),
self._get_mapped_task_reschedule(utcnow() + timedelta(minutes=1)),
][-1]
ti = self._get_mapped_task_instance(State.UP_FOR_RESCHEDULE)
assert not ReadyToRescheduleDep().is_met(ti=ti)
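As a quick way to exercise the new dep tests locally, one might run something like the following from the repository root with a working Airflow dev environment (a hedged example, not part of the commit):

# Run only the mapped-sensor cases of the ReadyToRescheduleDep tests.
import pytest

pytest.main(["tests/ti_deps/deps/test_ready_to_reschedule_dep.py", "-k", "mapped", "-q"])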
