Skip to content

Commit

Permalink
Add priority to register_middleware method (#2636)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Dec 19, 2022
1 parent 911485d commit 2abe66b
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 10 deletions.
57 changes: 48 additions & 9 deletions sanic/app.py
Expand Up @@ -61,14 +61,15 @@
URLBuildError,
)
from sanic.handlers import ErrorHandler
from sanic.helpers import Default
from sanic.helpers import Default, _default
from sanic.http import Stage
from sanic.log import (
LOGGING_CONFIG_DEFAULTS,
deprecation,
error_logger,
logger,
)
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.mixins.listeners import ListenerEvent
from sanic.mixins.startup import StartupMixin
from sanic.models.futures import (
Expand Down Expand Up @@ -294,8 +295,12 @@ def register_listener(
return listener

def register_middleware(
self, middleware: MiddlewareType, attach_to: str = "request"
) -> MiddlewareType:
self,
middleware: Union[MiddlewareType, Middleware],
attach_to: str = "request",
*,
priority: Union[Default, int] = _default,
) -> Union[MiddlewareType, Middleware]:
"""
Register an application level middleware that will be attached
to all the API URLs registered under this application.
Expand All @@ -311,19 +316,37 @@ def register_middleware(
**response** - Invoke before the response is returned back
:return: decorated method
"""
if attach_to == "request":
retval = middleware
location = MiddlewareLocation[attach_to.upper()]

if not isinstance(middleware, Middleware):
middleware = Middleware(
middleware,
location=location,
priority=priority if isinstance(priority, int) else 0,
)
elif middleware.priority != priority and isinstance(priority, int):
middleware = Middleware(
middleware.func,
location=middleware.location,
priority=priority,
)

if location is MiddlewareLocation.REQUEST:
if middleware not in self.request_middleware:
self.request_middleware.append(middleware)
if attach_to == "response":
if location is MiddlewareLocation.RESPONSE:
if middleware not in self.response_middleware:
self.response_middleware.appendleft(middleware)
return middleware
return retval

def register_named_middleware(
self,
middleware: MiddlewareType,
route_names: Iterable[str],
attach_to: str = "request",
*,
priority: Union[Default, int] = _default,
):
"""
Method for attaching middleware to specific routes. This is mainly an
Expand All @@ -337,19 +360,35 @@ def register_named_middleware(
defaults to "request"
:type attach_to: str, optional
"""
if attach_to == "request":
retval = middleware
location = MiddlewareLocation[attach_to.upper()]

if not isinstance(middleware, Middleware):
middleware = Middleware(
middleware,
location=location,
priority=priority if isinstance(priority, int) else 0,
)
elif middleware.priority != priority and isinstance(priority, int):
middleware = Middleware(
middleware.func,
location=middleware.location,
priority=priority,
)

if location is MiddlewareLocation.REQUEST:
for _rn in route_names:
if _rn not in self.named_request_middleware:
self.named_request_middleware[_rn] = deque()
if middleware not in self.named_request_middleware[_rn]:
self.named_request_middleware[_rn].append(middleware)
if attach_to == "response":
if location is MiddlewareLocation.RESPONSE:
for _rn in route_names:
if _rn not in self.named_response_middleware:
self.named_response_middleware[_rn] = deque()
if middleware not in self.named_response_middleware[_rn]:
self.named_response_middleware[_rn].appendleft(middleware)
return middleware
return retval

def _apply_exception_handler(
self,
Expand Down
3 changes: 3 additions & 0 deletions sanic/middleware.py
Expand Up @@ -32,6 +32,9 @@ def __init__(
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __hash__(self) -> int:
return hash(self.func)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
Expand Down
82 changes: 81 additions & 1 deletion tests/test_middleware_priority.py
Expand Up @@ -3,7 +3,7 @@
import pytest

from sanic import Sanic
from sanic.middleware import Middleware
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.response import json


Expand Down Expand Up @@ -40,6 +40,86 @@ def reset_middleware():
Middleware.reset_count()


def test_add_register_priority(app: Sanic):
def foo(*_):
...

app.register_middleware(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.register_middleware(foo, attach_to="response", priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore


def test_add_register_named_priority(app: Sanic):
def foo(*_):
...

app.register_named_middleware(foo, route_names=["foo"], priority=999)
assert len(app.named_request_middleware) == 1
assert len(app.named_response_middleware) == 0
assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore
app.register_named_middleware(
foo, attach_to="response", route_names=["foo"], priority=999
)
assert len(app.named_request_middleware) == 1
assert len(app.named_response_middleware) == 1
assert app.named_response_middleware["foo"][0].priority == 999 # type: ignore


def test_add_decorator_priority(app: Sanic):
def foo(*_):
...

app.middleware(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.middleware(foo, attach_to="response", priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore


def test_add_convenience_priority(app: Sanic):
def foo(*_):
...

app.on_request(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.on_response(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore


def test_add_conflicting_priority(app: Sanic):
def foo(*_):
...

middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998)
app.register_middleware(middleware=middleware, priority=999)
assert app.request_middleware[0].priority == 999 # type: ignore
middleware.priority == 998


def test_add_conflicting_priority_named(app: Sanic):
def foo(*_):
...

middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998)
app.register_named_middleware(
middleware=middleware, route_names=["foo"], priority=999
)
assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore
middleware.priority == 998


@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
Expand Down

0 comments on commit 2abe66b

Please sign in to comment.