diff --git a/sanic/blueprints.py b/sanic/blueprints.py index 7ef8f58d56..521f77f658 100644 --- a/sanic/blueprints.py +++ b/sanic/blueprints.py @@ -400,8 +400,9 @@ def register(self, app, options): for future in self._future_signals: if (self, future) in app._future_registry: continue - future.condition.update({"blueprint": self.name}) - app._apply_signal(future) + future.condition.update({"__blueprint__": self.name}) + # Force exclusive to be False + app._apply_signal(tuple((*future[:-1], False))) self.routes += [route for route in routes if isinstance(route, Route)] self.websocket_routes += [ @@ -426,7 +427,7 @@ def register(self, app, options): async def dispatch(self, *args, **kwargs): condition = kwargs.pop("condition", {}) - condition.update({"blueprint": self.name}) + condition.update({"__blueprint__": self.name}) kwargs["condition"] = condition await asyncio.gather( *[app.dispatch(*args, **kwargs) for app in self.apps] diff --git a/sanic/mixins/signals.py b/sanic/mixins/signals.py index a838316a5c..601c4b18b0 100644 --- a/sanic/mixins/signals.py +++ b/sanic/mixins/signals.py @@ -21,6 +21,7 @@ def signal( *, apply: bool = True, condition: Dict[str, Any] = None, + exclusive: bool = True, ) -> Callable[[SignalHandler], SignalHandler]: """ For creating a signal handler, used similar to a route handler: @@ -33,17 +34,22 @@ async def signal_handler(thing, **kwargs): :param event: Representation of the event in ``one.two.three`` form :type event: str - :param apply: For lazy evaluation, defaults to True + :param apply: For lazy evaluation, defaults to ``True`` :type apply: bool, optional :param condition: For use with the ``condition`` argument in dispatch - filtering, defaults to None + filtering, defaults to ``None`` + :param exclusive: When ``True``, the signal can only be dispatched + when the condition has been met. When ``False``, the signal can + be dispatched either with or without it. *THIS IS INAPPLICABLE TO + BLUEPRINT SIGNALS. THEY ARE ALWAYS NON-EXCLUSIVE*, defaults + to ``True`` :type condition: Dict[str, Any], optional """ event_value = str(event.value) if isinstance(event, Enum) else event def decorator(handler: SignalHandler): future_signal = FutureSignal( - handler, event_value, HashableDict(condition or {}) + handler, event_value, HashableDict(condition or {}), exclusive ) self._future_signals.add(future_signal) @@ -59,6 +65,7 @@ def add_signal( handler: Optional[Callable[..., Any]], event: str, condition: Dict[str, Any] = None, + exclusive: bool = True, ): if not handler: @@ -66,7 +73,9 @@ async def noop(): ... handler = noop - self.signal(event=event, condition=condition)(handler) + self.signal(event=event, condition=condition, exclusive=exclusive)( + handler + ) return handler def event(self, event: str): diff --git a/sanic/models/futures.py b/sanic/models/futures.py index f25c0270f0..e97a54b046 100644 --- a/sanic/models/futures.py +++ b/sanic/models/futures.py @@ -62,6 +62,7 @@ class FutureSignal(NamedTuple): handler: SignalHandler event: str condition: Optional[Dict[str, str]] + exclusive: bool class FutureRegistry(set): diff --git a/sanic/signals.py b/sanic/signals.py index 7bb510fa8a..f4061b69cc 100644 --- a/sanic/signals.py +++ b/sanic/signals.py @@ -4,7 +4,7 @@ from enum import Enum from inspect import isawaitable -from typing import Any, Dict, List, Optional, Tuple, Union +from typing import Any, Dict, List, Optional, Tuple, Union, cast from sanic_routing import BaseRouter, Route, RouteGroup # type: ignore from sanic_routing.exceptions import NotFound # type: ignore @@ -142,12 +142,21 @@ async def _dispatch( if context: params.update(context) + signals = group.routes if not reverse: - handlers = handlers[::-1] + signals = signals[::-1] try: - for handler in handlers: - if condition is None or condition == handler.__requirements__: - maybe_coroutine = handler(**params) + for signal in signals: + params.pop("__trigger__", None) + if ( + (condition is None and signal.ctx.exclusive is False) + or ( + condition is None + and not signal.handler.__requirements__ + ) + or (condition == signal.handler.__requirements__) + ) and (signal.ctx.trigger or event == signal.ctx.definition): + maybe_coroutine = signal.handler(**params) if isawaitable(maybe_coroutine): retval = await maybe_coroutine if retval: @@ -190,23 +199,36 @@ def add( # type: ignore handler: SignalHandler, event: str, condition: Optional[Dict[str, Any]] = None, + exclusive: bool = True, ) -> Signal: + event_definition = event parts = self._build_event_parts(event) if parts[2].startswith("<"): name = ".".join([*parts[:-1], "*"]) + trigger = self._clean_trigger(parts[2]) else: name = event + trigger = "" + + if not trigger: + event = ".".join([*parts[:2], "<__trigger__>"]) handler.__requirements__ = condition # type: ignore + handler.__trigger__ = trigger # type: ignore - return super().add( + signal = super().add( event, handler, - requirements=condition, name=name, append=True, ) # type: ignore + signal.ctx.exclusive = exclusive + signal.ctx.trigger = trigger + signal.ctx.definition = event_definition + + return cast(Signal, signal) + def finalize(self, do_compile: bool = True, do_optimize: bool = False): self.add(_blank, "sanic.__signal__.__init__") @@ -238,3 +260,9 @@ def _build_event_parts(self, event: str) -> Tuple[str, str, str]: "Cannot declare reserved signal event: %s" % event ) return parts + + def _clean_trigger(self, trigger: str) -> str: + trigger = trigger[1:-1] + if ":" in trigger: + trigger, _ = trigger.split(":") + return trigger diff --git a/tests/test_signals.py b/tests/test_signals.py index 51aea3c868..9835430967 100644 --- a/tests/test_signals.py +++ b/tests/test_signals.py @@ -145,6 +145,23 @@ def sync_signal(*_): assert counter == 1 +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_with_requirements_exclusive(app): + counter = 0 + + @app.signal("foo.bar.baz", condition={"one": "two"}, exclusive=False) + def sync_signal(*_): + nonlocal counter + counter += 1 + + app.signal_router.finalize() + + await app.dispatch("foo.bar.baz") + assert counter == 1 + await app.dispatch("foo.bar.baz", condition={"one": "two"}) + assert counter == 2 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_with_context(app): counter = 0 @@ -204,6 +221,24 @@ def bp_signal(): assert bp_counter == 2 +@pytest.mark.asyncio +async def test_dispatch_signal_triggers_on_bp_alone(app): + bp = Blueprint("bp") + + bp_counter = 0 + + @bp.signal("foo.bar.baz") + def bp_signal(): + nonlocal bp_counter + bp_counter += 1 + + app.blueprint(bp) + app.signal_router.finalize() + await app.dispatch("foo.bar.baz") + await bp.dispatch("foo.bar.baz") + assert bp_counter == 2 + + @pytest.mark.asyncio async def test_dispatch_signal_triggers_event(app): app_counter = 0