Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add priority to register_middleware method #2636

Merged
merged 4 commits into from Dec 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 48 additions & 9 deletions sanic/app.py
Expand Up @@ -61,14 +61,15 @@
URLBuildError,
)
from sanic.handlers import ErrorHandler
from sanic.helpers import Default
from sanic.helpers import Default, _default
from sanic.http import Stage
from sanic.log import (
LOGGING_CONFIG_DEFAULTS,
deprecation,
error_logger,
logger,
)
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.mixins.listeners import ListenerEvent
from sanic.mixins.startup import StartupMixin
from sanic.models.futures import (
Expand Down Expand Up @@ -294,8 +295,12 @@ def register_listener(
return listener

def register_middleware(
self, middleware: MiddlewareType, attach_to: str = "request"
) -> MiddlewareType:
self,
middleware: Union[MiddlewareType, Middleware],
attach_to: str = "request",
*,
priority: Union[Default, int] = _default,
) -> Union[MiddlewareType, Middleware]:
"""
Register an application level middleware that will be attached
to all the API URLs registered under this application.
Expand All @@ -311,19 +316,37 @@ def register_middleware(
**response** - Invoke before the response is returned back
:return: decorated method
"""
if attach_to == "request":
retval = middleware
ahopkins marked this conversation as resolved.
Show resolved Hide resolved
location = MiddlewareLocation[attach_to.upper()]

if not isinstance(middleware, Middleware):
middleware = Middleware(
middleware,
location=location,
priority=priority if isinstance(priority, int) else 0,
)
elif middleware.priority != priority and isinstance(priority, int):
ahopkins marked this conversation as resolved.
Show resolved Hide resolved
middleware = Middleware(
middleware.func,
location=middleware.location,
priority=priority,
)

if location is MiddlewareLocation.REQUEST:
if middleware not in self.request_middleware:
self.request_middleware.append(middleware)
if attach_to == "response":
if location is MiddlewareLocation.RESPONSE:
if middleware not in self.response_middleware:
self.response_middleware.appendleft(middleware)
return middleware
return retval

def register_named_middleware(
self,
middleware: MiddlewareType,
route_names: Iterable[str],
attach_to: str = "request",
*,
priority: Union[Default, int] = _default,
):
"""
Method for attaching middleware to specific routes. This is mainly an
Expand All @@ -337,19 +360,35 @@ def register_named_middleware(
defaults to "request"
:type attach_to: str, optional
"""
if attach_to == "request":
retval = middleware
location = MiddlewareLocation[attach_to.upper()]

if not isinstance(middleware, Middleware):
middleware = Middleware(
middleware,
location=location,
priority=priority if isinstance(priority, int) else 0,
)
elif middleware.priority != priority and isinstance(priority, int):
middleware = Middleware(
middleware.func,
location=middleware.location,
priority=priority,
)

if location is MiddlewareLocation.REQUEST:
for _rn in route_names:
if _rn not in self.named_request_middleware:
self.named_request_middleware[_rn] = deque()
if middleware not in self.named_request_middleware[_rn]:
self.named_request_middleware[_rn].append(middleware)
if attach_to == "response":
if location is MiddlewareLocation.RESPONSE:
for _rn in route_names:
if _rn not in self.named_response_middleware:
self.named_response_middleware[_rn] = deque()
if middleware not in self.named_response_middleware[_rn]:
self.named_response_middleware[_rn].appendleft(middleware)
return middleware
return retval

def _apply_exception_handler(
self,
Expand Down
3 changes: 3 additions & 0 deletions sanic/middleware.py
Expand Up @@ -32,6 +32,9 @@ def __init__(
def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __hash__(self) -> int:
return hash(self.func)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
Expand Down
82 changes: 81 additions & 1 deletion tests/test_middleware_priority.py
Expand Up @@ -3,7 +3,7 @@
import pytest

from sanic import Sanic
from sanic.middleware import Middleware
from sanic.middleware import Middleware, MiddlewareLocation
from sanic.response import json


Expand Down Expand Up @@ -40,6 +40,86 @@ def reset_middleware():
Middleware.reset_count()


def test_add_register_priority(app: Sanic):
def foo(*_):
...

app.register_middleware(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.register_middleware(foo, attach_to="response", priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore


def test_add_register_named_priority(app: Sanic):
def foo(*_):
...

app.register_named_middleware(foo, route_names=["foo"], priority=999)
assert len(app.named_request_middleware) == 1
assert len(app.named_response_middleware) == 0
assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore
app.register_named_middleware(
foo, attach_to="response", route_names=["foo"], priority=999
)
assert len(app.named_request_middleware) == 1
assert len(app.named_response_middleware) == 1
assert app.named_response_middleware["foo"][0].priority == 999 # type: ignore


def test_add_decorator_priority(app: Sanic):
def foo(*_):
...

app.middleware(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.middleware(foo, attach_to="response", priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore


def test_add_convenience_priority(app: Sanic):
def foo(*_):
...

app.on_request(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 0
assert app.request_middleware[0].priority == 999 # type: ignore
app.on_response(foo, priority=999)
assert len(app.request_middleware) == 1
assert len(app.response_middleware) == 1
assert app.response_middleware[0].priority == 999 # type: ignore


def test_add_conflicting_priority(app: Sanic):
def foo(*_):
...

middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998)
app.register_middleware(middleware=middleware, priority=999)
assert app.request_middleware[0].priority == 999 # type: ignore
middleware.priority == 998


def test_add_conflicting_priority_named(app: Sanic):
def foo(*_):
...

middleware = Middleware(foo, MiddlewareLocation.REQUEST, priority=998)
app.register_named_middleware(
middleware=middleware, route_names=["foo"], priority=999
)
assert app.named_request_middleware["foo"][0].priority == 999 # type: ignore
middleware.priority == 998


@pytest.mark.parametrize(
"expected,priorities",
PRIORITY_TEST_CASES,
Expand Down