From d1e2559ef0bf93b6ff5783f99aba3878037f10c0 Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 18 Dec 2022 15:03:38 +0200 Subject: [PATCH 1/3] Add priority to register_middleware method --- sanic/app.py | 39 +++++++++++++---- sanic/middleware.py | 3 ++ tests/test_middleware_priority.py | 70 +++++++++++++++++++++++++++++++ 3 files changed, 104 insertions(+), 8 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index 8d09866302..f367add57b 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -69,6 +69,7 @@ error_logger, logger, ) +from sanic.middleware import Middleware, MiddlewareLocation from sanic.mixins.listeners import ListenerEvent from sanic.mixins.startup import StartupMixin from sanic.models.futures import ( @@ -294,8 +295,12 @@ def register_listener( return listener def register_middleware( - self, middleware: MiddlewareType, attach_to: str = "request" - ) -> MiddlewareType: + self, + middleware: Union[MiddlewareType, Middleware], + attach_to: str = "request", + *, + priority=0, + ) -> Union[MiddlewareType, Middleware]: """ Register an application level middleware that will be attached to all the API URLs registered under this application. @@ -311,19 +316,29 @@ def register_middleware( **response** - Invoke before the response is returned back :return: decorated method """ - if attach_to == "request": + retval = middleware + location = MiddlewareLocation[attach_to.upper()] + + if not isinstance(middleware, Middleware): + middleware = Middleware( + middleware, location=location, priority=priority + ) + + if location is MiddlewareLocation.REQUEST: if middleware not in self.request_middleware: self.request_middleware.append(middleware) - if attach_to == "response": + if location is MiddlewareLocation.RESPONSE: if middleware not in self.response_middleware: self.response_middleware.appendleft(middleware) - return middleware + return retval def register_named_middleware( self, middleware: MiddlewareType, route_names: Iterable[str], attach_to: str = "request", + *, + priority=0, ): """ Method for attaching middleware to specific routes. This is mainly an @@ -337,19 +352,27 @@ def register_named_middleware( defaults to "request" :type attach_to: str, optional """ - if attach_to == "request": + retval = middleware + location = MiddlewareLocation[attach_to.upper()] + + if not isinstance(middleware, Middleware): + middleware = Middleware( + middleware, location=location, priority=priority + ) + + if location is MiddlewareLocation.REQUEST: for _rn in route_names: if _rn not in self.named_request_middleware: self.named_request_middleware[_rn] = deque() if middleware not in self.named_request_middleware[_rn]: self.named_request_middleware[_rn].append(middleware) - if attach_to == "response": + if location is MiddlewareLocation.RESPONSE: for _rn in route_names: if _rn not in self.named_response_middleware: self.named_response_middleware[_rn] = deque() if middleware not in self.named_response_middleware[_rn]: self.named_response_middleware[_rn].appendleft(middleware) - return middleware + return retval def _apply_exception_handler( self, diff --git a/sanic/middleware.py b/sanic/middleware.py index 5bbd777b6c..0c6058fa5e 100644 --- a/sanic/middleware.py +++ b/sanic/middleware.py @@ -32,6 +32,9 @@ def __init__( def __call__(self, *args, **kwargs): return self.func(*args, **kwargs) + def __hash__(self) -> int: + return hash(self.func) + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(" diff --git a/tests/test_middleware_priority.py b/tests/test_middleware_priority.py index 9646f6d0f9..b4dc09ffaf 100644 --- a/tests/test_middleware_priority.py +++ b/tests/test_middleware_priority.py @@ -40,6 +40,76 @@ def reset_middleware(): Middleware.reset_count() +def test_add_register_priority(app: Sanic): + def foo(*_): + ... + + app.register_middleware(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + assert app.request_middleware[0].priority == 999 # type: ignore + app.register_middleware(foo, attach_to="response", priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + assert app.response_middleware[0].priority == 999 # type: ignore + + +def test_add_register_named_priority(app: Sanic): + def foo(*_): + ... + + app.register_named_middleware(foo, route_names=["foo"], priority=999) + assert len(app.named_request_middleware) == 1 + assert len(app.named_response_middleware) == 0 + assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore + app.register_named_middleware( + foo, attach_to="response", route_names=["foo"], priority=999 + ) + assert len(app.named_request_middleware) == 1 + assert len(app.named_response_middleware) == 1 + assert app.named_response_middleware["foo"][0].priority == 999 # type: ignore + + +def test_add_decorator_priority(app: Sanic): + def foo(*_): + ... + + app.middleware(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + assert app.request_middleware[0].priority == 999 # type: ignore + app.middleware(foo, attach_to="response", priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + assert app.response_middleware[0].priority == 999 # type: ignore + + +def test_add_convenience_priority(app: Sanic): + def foo(*_): + ... + + app.on_request(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + assert app.request_middleware[0].priority == 999 # type: ignore + app.on_response(foo, priority=999) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + assert app.response_middleware[0].priority == 999 # type: ignore + + +def test_add_convenience_priority(app: Sanic): + def foo(*_): + ... + + app.on_request(foo) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 0 + app.on_response(foo) + assert len(app.request_middleware) == 1 + assert len(app.response_middleware) == 1 + + @pytest.mark.parametrize( "expected,priorities", PRIORITY_TEST_CASES, From df89b9b89866bab40c0b87f6d75182d1fb1a32da Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 18 Dec 2022 15:15:10 +0200 Subject: [PATCH 2/3] Allow overrides --- sanic/app.py | 27 ++++++++++++++++++++++----- tests/test_middleware_priority.py | 26 ++++++++++++++++++-------- 2 files changed, 40 insertions(+), 13 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index f367add57b..06dcaba1b3 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -61,7 +61,7 @@ URLBuildError, ) from sanic.handlers import ErrorHandler -from sanic.helpers import Default +from sanic.helpers import Default, _default from sanic.http import Stage from sanic.log import ( LOGGING_CONFIG_DEFAULTS, @@ -299,7 +299,7 @@ def register_middleware( middleware: Union[MiddlewareType, Middleware], attach_to: str = "request", *, - priority=0, + priority: Union[Default, int] = _default, ) -> Union[MiddlewareType, Middleware]: """ Register an application level middleware that will be attached @@ -319,9 +319,18 @@ def register_middleware( retval = middleware location = MiddlewareLocation[attach_to.upper()] + print(">>>>>", priority) if not isinstance(middleware, Middleware): middleware = Middleware( - middleware, location=location, priority=priority + middleware, + location=location, + priority=priority if isinstance(priority, int) else 0, + ) + elif middleware.priority != priority and isinstance(priority, int): + middleware = Middleware( + middleware.func, + location=middleware.location, + priority=priority, ) if location is MiddlewareLocation.REQUEST: @@ -338,7 +347,7 @@ def register_named_middleware( route_names: Iterable[str], attach_to: str = "request", *, - priority=0, + priority: Union[Default, int] = _default, ): """ Method for attaching middleware to specific routes. This is mainly an @@ -357,7 +366,15 @@ def register_named_middleware( if not isinstance(middleware, Middleware): middleware = Middleware( - middleware, location=location, priority=priority + middleware, + location=location, + priority=priority if isinstance(priority, int) else 0, + ) + elif middleware.priority != priority and isinstance(priority, int): + middleware = Middleware( + middleware.func, + location=middleware.location, + priority=priority, ) if location is MiddlewareLocation.REQUEST: diff --git a/tests/test_middleware_priority.py b/tests/test_middleware_priority.py index b4dc09ffaf..c16f658b6d 100644 --- a/tests/test_middleware_priority.py +++ b/tests/test_middleware_priority.py @@ -3,7 +3,7 @@ import pytest from sanic import Sanic -from sanic.middleware import Middleware +from sanic.middleware import Middleware, MiddlewareLocation from sanic.response import json @@ -98,16 +98,26 @@ def foo(*_): assert app.response_middleware[0].priority == 999 # type: ignore -def test_add_convenience_priority(app: Sanic): +def test_add_conflicting_priority(app: Sanic): def foo(*_): ... - app.on_request(foo) - assert len(app.request_middleware) == 1 - assert len(app.response_middleware) == 0 - app.on_response(foo) - assert len(app.request_middleware) == 1 - assert len(app.response_middleware) == 1 + middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998) + app.register_middleware(middleware=middleware, priority=999) + assert app.request_middleware[0].priority == 999 # type: ignore + middleware.priority == 998 + + +def test_add_conflicting_priority_named(app: Sanic): + def foo(*_): + ... + + middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998) + app.register_named_middleware( + middleware=middleware, route_names=["foo"], priority=999 + ) + assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore + middleware.priority == 998 @pytest.mark.parametrize( From 22f296d47e432a5dcda955d5552698ca5de0cc6a Mon Sep 17 00:00:00 2001 From: Adam Hopkins Date: Sun, 18 Dec 2022 15:15:31 +0200 Subject: [PATCH 3/3] squash --- sanic/app.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sanic/app.py b/sanic/app.py index 06dcaba1b3..cbefcc4d9d 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -319,7 +319,6 @@ def register_middleware( retval = middleware location = MiddlewareLocation[attach_to.upper()] - print(">>>>>", priority) if not isinstance(middleware, Middleware): middleware = Middleware( middleware,