diff --git a/airflow/utils/log/file_task_handler.py b/airflow/utils/log/file_task_handler.py index 692c279d0ce84..e20da8cebc7dc 100644 --- a/airflow/utils/log/file_task_handler.py +++ b/airflow/utils/log/file_task_handler.py @@ -29,13 +29,13 @@ from airflow.exceptions import RemovedInAirflow3Warning from airflow.utils.context import Context from airflow.utils.helpers import parse_template_string, render_template_to_string -from airflow.utils.log.logging_mixin import DISABLE_PROPOGATE from airflow.utils.log.non_caching_file_handler import NonCachingFileHandler from airflow.utils.session import create_session from airflow.utils.state import State if TYPE_CHECKING: from airflow.models import TaskInstance + from airflow.utils.log.logging_mixin import SetContextPropagate class FileTaskHandler(logging.Handler): @@ -62,7 +62,7 @@ def __init__(self, base_log_folder: str, filename_template: str | None = None): stacklevel=(2 if type(self) == FileTaskHandler else 3), ) - def set_context(self, ti: TaskInstance): + def set_context(self, ti: TaskInstance) -> None | SetContextPropagate: """ Provide task_instance context to airflow task handler. @@ -73,8 +73,7 @@ def set_context(self, ti: TaskInstance): if self.formatter: self.handler.setFormatter(self.formatter) self.handler.setLevel(self.level) - - return DISABLE_PROPOGATE + return None def emit(self, record): if self.handler: diff --git a/airflow/utils/log/logging_mixin.py b/airflow/utils/log/logging_mixin.py index 8127bde9f7ed8..b8f5a0871c8a9 100644 --- a/airflow/utils/log/logging_mixin.py +++ b/airflow/utils/log/logging_mixin.py @@ -18,18 +18,35 @@ from __future__ import annotations import abc +import enum import logging import re import sys from io import IOBase from logging import Handler, Logger, StreamHandler -from typing import IO +from typing import IO, cast # 7-bit C1 ANSI escape sequences ANSI_ESCAPE = re.compile(r"\x1B[@-_][0-?]*[ -/]*[@-~]") -# Private: A sentinel object -DISABLE_PROPOGATE = object() + +# Private: A sentinel objects +class SetContextPropagate(enum.Enum): + """:meta private:""" + + # If a `set_context` function wants to _keep_ propagation set on it's logger it needs to return this + # special value. + MAINTAIN_PROPAGATE = object() + # Don't use this one anymore! + DISABLE_PROPAGATE = object() + + +def __getattr__(name): + if name in ("DISABLE_PROPOGATE", "DISABLE_PROPAGATE"): + # Compat for spelling on off chance someone is using this directly + # And old object that isn't needed anymore + return SetContextPropagate.DISABLE_PROPAGATE + raise AttributeError(f"module {__name__} has no attribute {name}") def remove_escape_codes(text: str) -> str: @@ -183,13 +200,23 @@ def set_context(logger, value): :param value: value to set """ while logger: + orig_propagate = logger.propagate for handler in logger.handlers: # Not all handlers need to have context passed in so we ignore # the error when handlers do not have set_context defined. - set_context = getattr(handler, "set_context", None) - if set_context and set_context(value) is DISABLE_PROPOGATE: - logger.propagate = False - if logger.propagate is True: + + # Don't use getatrr so we have type checking. And we don't care if handler is actually a + # FileTaskHandler, it just needs to have a set_context function! + if hasattr(handler, "set_context"): + from airflow.utils.log.file_task_handler import FileTaskHandler + + flag = cast(FileTaskHandler, handler).set_context(value) + # By default we disable propagate once we have configured the logger, unless that handler + # explicitly asks us to keep it on. + if flag is not SetContextPropagate.MAINTAIN_PROPAGATE: + logger.propagate = False + if orig_propagate is True: + # If we were set to propagate before we turned if off, then keep passing set_context up logger = logger.parent else: break diff --git a/tests/utils/test_logging_mixin.py b/tests/utils/test_logging_mixin.py index 464c5773d306c..a1ffa3d629792 100644 --- a/tests/utils/test_logging_mixin.py +++ b/tests/utils/test_logging_mixin.py @@ -17,29 +17,64 @@ # under the License. from __future__ import annotations +import logging +import sys import warnings from unittest import mock -from airflow.utils.log.logging_mixin import StreamLogWriter, set_context +import pytest + +from airflow.utils.log.logging_mixin import SetContextPropagate, StreamLogWriter, set_context + + +@pytest.fixture +def logger(): + parent = logging.getLogger(__name__) + parent.propagate = False + yield parent + + parent.propagate = True + + +@pytest.fixture +def child_logger(logger): + yield logger.getChild("child") + + +@pytest.fixture +def parent_child_handlers(child_logger): + parent_handler = logging.NullHandler() + parent_handler.handle = mock.MagicMock(name="parent_handler.handle") + + child_handler = logging.NullHandler() + child_handler.handle = mock.MagicMock(name="handler.handle") + + logger = child_logger.parent + logger.addHandler(parent_handler) + + child_logger.addHandler(child_handler), + child_logger.propagate = True + + yield parent_handler, child_handler + + logger.removeHandler(parent_handler) + child_logger.removeHandler(child_handler) class TestLoggingMixin: def setup_method(self): warnings.filterwarnings(action="always") - def test_set_context(self): - handler1 = mock.MagicMock() - handler2 = mock.MagicMock() - parent = mock.MagicMock() + def test_set_context(self, child_logger, parent_child_handlers): + handler1, handler2 = parent_child_handlers + handler1.set_context = mock.MagicMock() + handler2.set_context = mock.MagicMock() + + parent = logging.getLogger(__name__) parent.propagate = False - parent.handlers = [ - handler1, - ] - log = mock.MagicMock() - log.handlers = [ - handler2, - ] - log.parent = parent + parent.addHandler(handler1) + log = parent.getChild("child") + log.addHandler(handler2), log.propagate = True value = "test" @@ -105,3 +140,37 @@ def test_iobase_compatibility(self): assert not log.closed # has no specific effect log.close() + + +@pytest.mark.parametrize(["maintain_propagate"], [[SetContextPropagate.MAINTAIN_PROPAGATE], [None]]) +def test_set_context_propagation(parent_child_handlers, child_logger, maintain_propagate): + # Test the behaviour of set_context and logger propagation and the MAINTAIN_PROPAGATE return + + parent_handler, handler = parent_child_handlers + handler.set_context = mock.MagicMock(return_value=maintain_propagate) + + # Before settting_context, ensure logs make it to the parent + line = sys._getframe().f_lineno + 1 + record = child_logger.makeRecord( + child_logger.name, logging.INFO, __file__, line, "test message", [], None + ) + child_logger.handle(record) + + handler.handle.assert_called_once_with(record) + # Should call the parent handler too in the default/unconfigured case + parent_handler.handle.assert_called_once_with(record) + + parent_handler.handle.reset_mock() + handler.handle.reset_mock() + + # Ensure that once we've called set_context on the handler we disable propagation to parent loggers by + # default! + set_context(child_logger, {}) + + child_logger.handle(record) + + handler.handle.assert_called_once_with(record) + if maintain_propagate is SetContextPropagate.MAINTAIN_PROPAGATE: + parent_handler.handle.assert_called_once_with(record) + else: + parent_handler.handle.assert_not_called()