diff --git a/sanic/blueprints.py b/sanic/blueprints.py index e5e1d33327..e13cafcdb5 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -4,6 +4,7 @@ from collections import defaultdict from copy import deepcopy +from enum import Enum from types import SimpleNamespace from typing import ( TYPE_CHECKING, @@ -144,7 +145,7 @@ def exception(self, *args, **kwargs): kwargs["apply"] = False return super().exception(*args, **kwargs) - def signal(self, event: str, *args, **kwargs): + def signal(self, event: Union[str, Enum], *args, **kwargs): kwargs["apply"] = False return super().signal(event, *args, **kwargs) diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index 2be9fee2e6..57b01b46e8 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -1,4 +1,5 @@ -from typing import Any, Callable, Dict, Optional, Set +from enum import Enum +from typing import Any, Callable, Dict, Optional, Set, Union from sanic.models.futures import FutureSignal from sanic.models.handler_types import SignalHandler @@ -19,7 +20,7 @@ def _apply_signal(self, signal: FutureSignal) -> Signal: def signal( self, - event: str, + event: Union[str, Enum], *, apply: bool = True, condition: Dict[str, Any] = None, @@ -41,13 +42,11 @@ async def signal_handler(thing, **kwargs): filtering, defaults to None :type condition: Dict[str, Any], optional """ + event_value = str(event.value) if isinstance(event, Enum) else event def decorator(handler: SignalHandler): - nonlocal event - nonlocal apply - future_signal = FutureSignal( - handler, event, HashableDict(condition or {}) + handler, event_value, HashableDict(condition or {}) ) self._future_signals.add(future_signal) diff --git a/sanic/signals.py b/sanic/signals.py index 9da7eccded..7bb510fa8a 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -2,6 +2,7 @@ import asyncio +from enum import Enum from inspect import isawaitable from typing import Any, Dict, List, Optional, Tuple, Union @@ -14,29 +15,47 @@ from sanic.models.handler_types import SignalHandler +class Event(Enum): + SERVER_INIT_AFTER = "server.init.after" + SERVER_INIT_BEFORE = "server.init.before" + SERVER_SHUTDOWN_AFTER = "server.shutdown.after" + SERVER_SHUTDOWN_BEFORE = "server.shutdown.before" + HTTP_LIFECYCLE_BEGIN = "http.lifecycle.begin" + HTTP_LIFECYCLE_COMPLETE = "http.lifecycle.complete" + HTTP_LIFECYCLE_EXCEPTION = "http.lifecycle.exception" + HTTP_LIFECYCLE_HANDLE = "http.lifecycle.handle" + HTTP_LIFECYCLE_READ_BODY = "http.lifecycle.read_body" + HTTP_LIFECYCLE_READ_HEAD = "http.lifecycle.read_head" + HTTP_LIFECYCLE_REQUEST = "http.lifecycle.request" + HTTP_LIFECYCLE_RESPONSE = "http.lifecycle.response" + HTTP_ROUTING_AFTER = "http.routing.after" + HTTP_ROUTING_BEFORE = "http.routing.before" + HTTP_LIFECYCLE_SEND = "http.lifecycle.send" + HTTP_MIDDLEWARE_AFTER = "http.middleware.after" + HTTP_MIDDLEWARE_BEFORE = "http.middleware.before" + + RESERVED_NAMESPACES = { "server": ( - # "server.main.start", - # "server.main.stop", - "server.init.before", - "server.init.after", - "server.shutdown.before", - "server.shutdown.after", + Event.SERVER_INIT_AFTER.value, + Event.SERVER_INIT_BEFORE.value, + Event.SERVER_SHUTDOWN_AFTER.value, + Event.SERVER_SHUTDOWN_BEFORE.value, ), "http": ( - "http.lifecycle.begin", - "http.lifecycle.complete", - "http.lifecycle.exception", - "http.lifecycle.handle", - "http.lifecycle.read_body", - "http.lifecycle.read_head", - "http.lifecycle.request", - "http.lifecycle.response", - "http.routing.after", - "http.routing.before", - "http.lifecycle.send", - "http.middleware.after", - "http.middleware.before", + Event.HTTP_LIFECYCLE_BEGIN.value, + Event.HTTP_LIFECYCLE_COMPLETE.value, + Event.HTTP_LIFECYCLE_EXCEPTION.value, + Event.HTTP_LIFECYCLE_HANDLE.value, + Event.HTTP_LIFECYCLE_READ_BODY.value, + Event.HTTP_LIFECYCLE_READ_HEAD.value, + Event.HTTP_LIFECYCLE_REQUEST.value, + Event.HTTP_LIFECYCLE_RESPONSE.value, + Event.HTTP_ROUTING_AFTER.value, + Event.HTTP_ROUTING_BEFORE.value, + Event.HTTP_LIFECYCLE_SEND.value, + Event.HTTP_MIDDLEWARE_AFTER.value, + Event.HTTP_MIDDLEWARE_BEFORE.value, ), } diff --git a/tests/test_signals.py b/tests/test_signals.py index 9b8a94953a..51aea3c868 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -1,5 +1,6 @@ import asyncio +from enum import Enum from inspect import isawaitable import pytest @@ -50,6 +51,25 @@ def handler(): ... +@pytest.mark.asyncio +async def test_dispatch_signal_with_enum_event(app): + counter = 0 + + class FooEnum(Enum): + FOO_BAR_BAZ = "foo.bar.baz" + + @app.signal(FooEnum.FOO_BAR_BAZ) + def sync_signal(*_): + nonlocal counter + + counter += 1 + + app.signal_router.finalize() + + await app.dispatch("foo.bar.baz") + assert counter == 1 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_multiple_handlers(app): counter = 0