From cc4cde987cbc073a223b531a7674856cb2847f9a Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Fri, 11 Nov 2022 02:22:51 +0100 Subject: [PATCH] Add max_wait for exponential_backoff in BaseSensor (#27597) Co-authored-by: Daniel Standish <15932138+dstandish@users.noreply.github.com> --- airflow/decorators/__init__.pyi | 4 +++- airflow/sensors/base.py | 15 +++++++++++++++ tests/sensors/test_base.py | 22 ++++++++++++++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/airflow/decorators/__init__.pyi b/airflow/decorators/__init__.pyi index fd17efa174e80..0a6d534b247fa 100644 --- a/airflow/decorators/__init__.pyi +++ b/airflow/decorators/__init__.pyi @@ -14,12 +14,12 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - # This file provides better type hinting and editor autocompletion support for # dynamically generated task decorators. Functions declared in this stub do not # necessarily exist at run time. See "Creating Custom @task Decorators" # documentation for more details. +from datetime import timedelta from typing import Any, Callable, Iterable, Mapping, Union, overload from kubernetes.client import models as k8s @@ -421,6 +421,7 @@ class TaskDecoratorCollection: soft_fail: bool = False, mode: str = ..., exponential_backoff: bool = False, + max_wait: timedelta | float | None = None, **kwargs, ) -> TaskDecorator: """ @@ -444,6 +445,7 @@ class TaskDecoratorCollection: prevent too much load on the scheduler. :param exponential_backoff: allow progressive longer waits between pokes by using exponential backoff algorithm + :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds """ @overload def sensor(self, python_callable: Optional[FParams, FReturn] = None) -> Task[FParams, FReturn]: ... 
diff --git a/airflow/sensors/base.py b/airflow/sensors/base.py index 3cedd31ed4adf..0df1a5f4b7052 100644 --- a/airflow/sensors/base.py +++ b/airflow/sensors/base.py @@ -97,6 +97,7 @@ class BaseSensorOperator(BaseOperator, SkipMixin): prevent too much load on the scheduler. :param exponential_backoff: allow progressive longer waits between pokes by using exponential backoff algorithm + :param max_wait: maximum wait interval between pokes, can be ``timedelta`` or ``float`` seconds """ ui_color: str = "#e6f1f2" @@ -114,6 +115,7 @@ def __init__( soft_fail: bool = False, mode: str = "poke", exponential_backoff: bool = False, + max_wait: timedelta | float | None = None, **kwargs, ) -> None: super().__init__(**kwargs) @@ -122,8 +124,17 @@ def __init__( self.timeout = timeout self.mode = mode self.exponential_backoff = exponential_backoff + self.max_wait = self._coerce_max_wait(max_wait) self._validate_input_values() + @staticmethod + def _coerce_max_wait(max_wait: float | timedelta | None) -> timedelta | None: + if max_wait is None or isinstance(max_wait, timedelta): + return max_wait + if isinstance(max_wait, (int, float)) and max_wait >= 0: + return timedelta(seconds=max_wait) + raise AirflowException("Operator arg `max_wait` must be timedelta object or a non-negative number") + def _validate_input_values(self) -> None: if not isinstance(self.poke_interval, (int, float)) or self.poke_interval < 0: raise AirflowException("The poke_interval must be a non-negative number") @@ -233,6 +244,10 @@ def _get_next_poke_interval( delay_backoff_in_seconds = min(modded_hash, timedelta.max.total_seconds() - 1) new_interval = min(self.timeout - int(run_duration()), delay_backoff_in_seconds) + + if self.max_wait: + new_interval = min(self.max_wait.total_seconds(), new_interval) + self.log.info("new %s interval is %s", self.mode, new_interval) return new_interval diff --git a/tests/sensors/test_base.py b/tests/sensors/test_base.py index ec0907db5d4d1..8545b99f591c6 100644 --- 
a/tests/sensors/test_base.py +++ b/tests/sensors/test_base.py @@ -509,6 +509,28 @@ def run_duration(): assert interval2 >= sensor.poke_interval assert interval2 > interval1 + def test_sensor_with_exponential_backoff_on_and_max_wait(self): + + sensor = DummySensor( + task_id=SENSOR_OP, + return_value=None, + poke_interval=10, + timeout=60, + exponential_backoff=True, + max_wait=timedelta(seconds=30), + ) + + with patch("airflow.utils.timezone.utcnow") as mock_utctime: + mock_utctime.return_value = DEFAULT_DATE + + started_at = timezone.utcnow() - timedelta(seconds=10) + + def run_duration(): + return (timezone.utcnow() - started_at).total_seconds() + + for idx, expected in enumerate([2, 6, 13, 30, 30, 30, 30, 30]): + assert sensor._get_next_poke_interval(started_at, run_duration, idx) == expected + @pytest.mark.backend("mysql") def test_reschedule_poke_interval_too_long_on_mysql(self, make_sensor): with pytest.raises(AirflowException) as ctx: