Skip to content

Commit

Permalink
Add max_wait for exponential_backoff in BaseSensor (#27597)
Browse files Browse the repository at this point in the history
Co-authored-by: Daniel Standish <15932138+dstandish@users.noreply.github.com>
  • Loading branch information
hussein-awala and dstandish committed Nov 11, 2022
1 parent 1059de6 commit cc4cde9
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 1 deletion.
4 changes: 3 additions & 1 deletion airflow/decorators/__init__.pyi
Expand Up @@ -14,12 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# This file provides better type hinting and editor autocompletion support for
# dynamically generated task decorators. Functions declared in this stub do not
# necessarily exist at run time. See "Creating Custom @task Decorators"
# documentation for more details.

from datetime import timedelta
from typing import Any, Callable, Iterable, Mapping, Union, overload

from kubernetes.client import models as k8s
Expand Down Expand Up @@ -421,6 +421,7 @@ class TaskDecoratorCollection:
soft_fail: bool = False,
mode: str = ...,
exponential_backoff: bool = False,
max_wait: timedelta | float | None = None,
**kwargs,
) -> TaskDecorator:
"""
Expand All @@ -444,6 +445,7 @@ class TaskDecoratorCollection:
prevent too much load on the scheduler.
:param exponential_backoff: allow progressive longer waits between
pokes by using exponential backoff algorithm
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
"""
@overload
def sensor(self, python_callable: Optional[FParams, FReturn] = None) -> Task[FParams, FReturn]: ...
Expand Down
15 changes: 15 additions & 0 deletions airflow/sensors/base.py
Expand Up @@ -97,6 +97,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin):
prevent too much load on the scheduler.
:param exponential_backoff: allow progressive longer waits between
pokes by using exponential backoff algorithm
:param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds
"""

ui_color: str = "#e6f1f2"
Expand All @@ -114,6 +115,7 @@ def __init__(
soft_fail: bool = False,
mode: str = "poke",
exponential_backoff: bool = False,
max_wait: timedelta | float | None = None,
**kwargs,
) -> None:
super().__init__(**kwargs)
Expand All @@ -122,8 +124,17 @@ def __init__(
self.timeout = timeout
self.mode = mode
self.exponential_backoff = exponential_backoff
self.max_wait = self._coerce_max_wait(max_wait)
self._validate_input_values()

@staticmethod
def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None:
if max_wait is None or isinstance(max_wait, timedelta):
return max_wait
if isinstance(max_wait, (int, float)) and max_wait >= 0:
return timedelta(seconds=max_wait)
raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number")

def _validate_input_values(self) -> None:
if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0:
raise AirflowException("The poke_interval must be a non-negative number")
Expand Down Expand Up @@ -233,6 +244,10 @@ def _get_next_poke_interval(

delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1)
new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds)

if self.max_wait:
new_interval = min(self.max_wait.total_seconds(), new_interval)

self.log.info("new %s interval is %s", self.mode, new_interval)
return new_interval

Expand Down
22 changes: 22 additions & 0 deletions tests/sensors/test_base.py
Expand Up @@ -509,6 +509,28 @@ def run_duration():
assert interval2 >= sensor.poke_interval
assert interval2 > interval1

def test_sensor_with_exponential_backoff_on_and_max_wait(self):

sensor = DummySensor(
task_id=SENSOR_OP,
return_value=None,
poke_interval=10,
timeout=60,
exponential_backoff=True,
max_wait=timedelta(seconds=30),
)

with patch("airflow.utils.timezone.utcnow") as mock_utctime:
mock_utctime.return_value = DEFAULT_DATE

started_at = timezone.utcnow() - timedelta(seconds=10)

def run_duration():
return (timezone.utcnow - started_at).total_seconds()

for idx, expected in enumerate([2, 6, 13, 30, 30, 30, 30, 30]):
assert sensor._get_next_poke_interval(started_at, run_duration, idx) == expected

@pytest.mark.backend("mysql")
def test_reschedule_poke_interval_too_long_on_mysql(self, make_sensor):
with pytest.raises(AirflowException) as ctx:
Expand Down

0 comments on commit cc4cde9

Please sign in to comment.