From 2abe66b67086d398d58cd754cf5b158b812acebb Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Mon, 19 Dec 2022 19:14:46 +0200 Subject: [PATCH] Add priority to register_middleware method (#2636) --- sanic/app.py | 57 +++++++++++++++++---- sanic/middleware.py | 3 ++ tests/test_middleware_priority.py | 82 ++++++++++++++++++++++++++++++- 3 files changed, 132 insertions(+), 10 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 8d09866302..cbefcc4d9d 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -61,7 +61,7 @@ URLBuildError, ) from sanic.handlers import ErrorHandler -from sanic.helpers import Default +from sanic.helpers import Default, _default from sanic.http import Stage from sanic.log import ( LOGGING_CONFIG_DEFAULTS, @@ -69,6 +69,7 @@ error_logger, logger, ) +from sanic.middleware import Middleware, MiddlewareLocation from sanic.mixins.listeners import ListenerEvent from sanic.mixins.startup import StartupMixin from sanic.models.futures import ( @@ -294,8 +295,12 @@ def register_listener( return listener def register_middleware( - self, middleware: MiddlewareType, attach_to: str = "request" - ) -> MiddlewareType: + self, + middleware: Union[MiddlewareType, Middleware], + attach_to: str = "request", + *, + priority: Union[Default, int] = _default, + ) -> Union[MiddlewareType, Middleware]: """ Register an application level middleware that will be attached to all the API URLs registered under this application. @@ -311,19 +316,37 @@ def register_middleware( **response** - Invoke before the response is returned back :return: decorated method """ - if attach_to == "request": + retval = middleware + location = MiddlewareLocation[attach_to.upper()] + + if not isinstance(middleware, Middleware): + middleware = Middleware( + middleware, + location=location, + priority=priority if isinstance(priority, int) else 0, + ) + elif middleware.priority != priority and isinstance(priority, int): + middleware = Middleware( + middleware.func, + location=middleware.location, + priority=priority, + ) + + if location is MiddlewareLocation.REQUEST: if middleware not in self.request_middleware: self.request_middleware.append(middleware) - if attach_to == "response": + if location is MiddlewareLocation.RESPONSE: if middleware not in self.response_middleware: self.response_middleware.appendleft(middleware) - return middleware + return retval def register_named_middleware( self, middleware: MiddlewareType, route_names: Iterable[str], attach_to: str = "request", + *, + priority: Union[Default, int] = _default, ): """ Method for attaching middleware to specific routes. This is mainly an @@ -337,19 +360,35 @@ def register_named_middleware( defaults to "request" :type attach_to: str, optional """ - if attach_to == "request": + retval = middleware + location = MiddlewareLocation[attach_to.upper()] + + if not isinstance(middleware, Middleware): + middleware = Middleware( + middleware, + location=location, + priority=priority if isinstance(priority, int) else 0, + ) + elif middleware.priority != priority and isinstance(priority, int): + middleware = Middleware( + middleware.func, + location=middleware.location, + priority=priority, + ) + + if location is MiddlewareLocation.REQUEST: for _rn in route_names: if _rn not in self.named_request_middleware: self.named_request_middleware[_rn] = deque() if middleware not in self.named_request_middleware[_rn]: self.named_request_middleware[_rn].append(middleware) - if attach_to == "response": + if location is MiddlewareLocation.RESPONSE: for _rn in route_names: if _rn not in self.named_response_middleware: self.named_response_middleware[_rn] = deque() if middleware not in self.named_response_middleware[_rn]: self.named_response_middleware[_rn].appendleft(middleware) - return middleware + return retval def _apply_exception_handler( self, diff --git a/sanic/middleware.py b/sanic/middleware.py index 5bbd777b6c..0c6058fa5e 100644 --- a/sanic/middleware.py +++ b/sanic/middleware.py @@ -32,6 +32,9 @@ def __init__( def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) + def __hash__(self) -> int: + return hash(self.func) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" diff --git a/tests/test_middleware_priority.py b/tests/test_middleware_priority.py index 9646f6d0f9..c16f658b6d 100644 --- a/tests/test_middleware_priority.py +++ b/tests/test_middleware_priority.py @@ -3,7 +3,7 @@ import pytest from sanic import Sanic -from sanic.middleware import Middleware +from sanic.middleware import Middleware, MiddlewareLocation from sanic.response import json @@ -40,6 +40,86 @@ def reset_middleware(): Middleware.reset_count() +def test_add_register_priority(app: Sanic): + def foo(*_): + ... + + app.register_middleware(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + assert app.request_middleware[0].priority == 999 # type: ignore + app.register_middleware(foo, attach_to="response", priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + assert app.response_middleware[0].priority == 999 # type: ignore + + +def test_add_register_named_priority(app: Sanic): + def foo(*_): + ... + + app.register_named_middleware(foo, route_names=["foo"], priority=999) + assert len(app.named_request_middleware) == 1 + assert len(app.named_response_middleware) == 0 + assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore + app.register_named_middleware( + foo, attach_to="response", route_names=["foo"], priority=999 + ) + assert len(app.named_request_middleware) == 1 + assert len(app.named_response_middleware) == 1 + assert app.named_response_middleware["foo"][0].priority == 999 # type: ignore + + +def test_add_decorator_priority(app: Sanic): + def foo(*_): + ... + + app.middleware(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + assert app.request_middleware[0].priority == 999 # type: ignore + app.middleware(foo, attach_to="response", priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + assert app.response_middleware[0].priority == 999 # type: ignore + + +def test_add_convenience_priority(app: Sanic): + def foo(*_): + ... + + app.on_request(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + assert app.request_middleware[0].priority == 999 # type: ignore + app.on_response(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + assert app.response_middleware[0].priority == 999 # type: ignore + + +def test_add_conflicting_priority(app: Sanic): + def foo(*_): + ... + + middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998) + app.register_middleware(middleware=middleware, priority=999) + assert app.request_middleware[0].priority == 999 # type: ignore + middleware.priority == 998 + + +def test_add_conflicting_priority_named(app: Sanic): + def foo(*_): + ... + + middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998) + app.register_named_middleware( + middleware=middleware, route_names=["foo"], priority=999 + ) + assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore + middleware.priority == 998 + + @pytest.mark.parametrize( "expected,priorities", PRIORITY_TEST_CASES,