Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revise middleware handling #2521

Draft
wants to merge 12 commits into
base: main
Choose a base branch
from
162 changes: 78 additions & 84 deletions sanic/app.py
Expand Up @@ -701,7 +701,10 @@ def url_for(self, view_name: str, **kwargs):
# -------------------------------------------------------------------- #

async def handle_exception(
self, request: Request, exception: BaseException
self,
request: Request,
exception: BaseException,
run_middleware: bool = True,
): # no cov
"""
A handler that catches specific exceptions and outputs a response.
Expand All @@ -710,6 +713,7 @@ async def handle_exception(
:param exception: The exception that was raised
:raises ServerError: response 500
"""
response = None
await self.dispatch(
"http.lifecycle.exception",
inline=True,
Expand Down Expand Up @@ -750,9 +754,12 @@ async def handle_exception(
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=None
)
if (
run_middleware
and request.route
and request.route.extra.request_middleware
ahopkins marked this conversation as resolved.
Show resolved Hide resolved
):
response = await self._run_request_middleware(request)
# No middleware results
if not response:
try:
Expand Down Expand Up @@ -832,7 +839,12 @@ async def handle_request(self, request: Request): # no cov

# Define `response` var here to remove warnings about
# allocation before assignment below.
response = None
response: Optional[
Union[
BaseHTTPResponse,
Coroutine[Any, Any, Optional[BaseHTTPResponse]],
]
] = None
try:

await self.dispatch(
Expand Down Expand Up @@ -877,9 +889,8 @@ async def handle_request(self, request: Request): # no cov
# -------------------------------------------- #
# Request Middleware
# -------------------------------------------- #
response = await self._run_request_middleware(
request, request_name=route.name
)
if request.route.extra.request_middleware:
response = await self._run_request_middleware(request)

# No middleware results
if not response:
Expand Down Expand Up @@ -910,7 +921,7 @@ async def handle_request(self, request: Request): # no cov
if request.stream is not None:
response = request.stream.response
elif response is not None:
response = await request.respond(response)
response = await request.respond(response) # type: ignore
elif not hasattr(handler, "is_websocket"):
response = request.stream.response # type: ignore

Expand All @@ -928,7 +939,7 @@ async def handle_request(self, request: Request): # no cov
...
await response.send(end_stream=True)
elif isinstance(response, ResponseStream):
resp = await response(request)
resp = await response(request) # type: ignore
await self.dispatch(
"http.lifecycle.response",
inline=True,
Expand All @@ -937,7 +948,7 @@ async def handle_request(self, request: Request): # no cov
"response": resp,
},
)
await response.eof()
await response.eof() # type: ignore
else:
if not hasattr(handler, "is_websocket"):
raise ServerError(
Expand All @@ -949,7 +960,7 @@ async def handle_request(self, request: Request): # no cov
raise
except Exception as e:
# Response Generation Failed
await self.handle_exception(request, e)
await self.handle_exception(request, e, run_middleware=False)

async def _websocket_handler(
self, handler, request, *args, subprotocols=None, **kwargs
Expand Down Expand Up @@ -1017,87 +1028,69 @@ def asgi_client(self): # noqa
# Execution
# -------------------------------------------------------------------- #

async def _run_request_middleware(
self, request, request_name=None
): # no cov
# The if improves speed. I don't know why
named_middleware = self.named_request_middleware.get(
request_name, deque()
)
applicable_middleware = self.request_middleware + named_middleware

# request.request_middleware_started is meant as a stop-gap solution
# until RFC 1630 is adopted
if applicable_middleware and not request.request_middleware_started:
request.request_middleware_started = True
async def _run_request_middleware(self, request): # no cov
request._request_middleware_started = True

for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
for middleware in request.route.extra.request_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)

response = middleware(request)
if isawaitable(response):
response = await response
response = middleware(request)
if isawaitable(response):
response = await response

await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": None,
},
condition={"attach_to": "request"},
)

if response:
return response
if response:
return response
return None

async def _run_response_middleware(
self, request, response, request_name=None
): # no cov
named_middleware = self.named_response_middleware.get(
request_name, deque()
)
applicable_middleware = self.response_middleware + named_middleware
if applicable_middleware:
for middleware in applicable_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)
async def _run_response_middleware(self, request, response): # no cov
for middleware in request.route.extra.response_middleware:
await self.dispatch(
"http.middleware.before",
inline=True,
context={
"request": request,
"response": response,
},
condition={"attach_to": "response"},
)

_response = middleware(request, response)
if isawaitable(_response):
_response = await _response
_response = middleware(request, response)
if isawaitable(_response):
_response = await _response

await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)
await self.dispatch(
"http.middleware.after",
inline=True,
context={
"request": request,
"response": _response if _response else response,
},
condition={"attach_to": "response"},
)

if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
if _response:
response = _response
if isinstance(response, BaseHTTPResponse):
response = request.stream.respond(response)
break
return response

def _build_endpoint_name(self, *parts):
Expand Down Expand Up @@ -1495,6 +1488,7 @@ def finalize(self):
except FinalizationError as e:
if not Sanic.test_mode:
raise e
self.finalize_middleware()

def signalize(self, allow_fail_builtin=True):
self.signal_router.allow_fail_builtin = allow_fail_builtin
Expand Down
61 changes: 61 additions & 0 deletions sanic/middleware.py
@@ -0,0 +1,61 @@
from __future__ import annotations

from collections import deque
from enum import IntEnum, auto
from itertools import count
from typing import Deque, Optional, Sequence, Union

from sanic.models.handler_types import MiddlewareType


class MiddlewareLocation(IntEnum):
REQUEST = auto()
RESPONSE = auto()


class Middleware:
counter = count()

__slots__ = ("func", "priority", "location", "definition")

def __init__(
self,
func: MiddlewareType,
location: MiddlewareLocation,
priority: int = 0,
) -> None:
self.func = func
self.priority = priority
self.location = location
self.definition = next(Middleware.counter)

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def __repr__(self) -> str:
return (
f"{self.__class__.__name__}("
f"func=<function {self.func.__name__}>, "
f"priority={self.priority}, "
f"location={self.location.name})"
)

@property
def order(self):
return (self.priority, -self.definition)

@classmethod
def convert(
cls,
*middleware_collections: Sequence[Union[Middleware, MiddlewareType]],
location: MiddlewareLocation,
) -> Deque[Middleware]:
return deque(
[
middleware
if isinstance(middleware, Middleware)
else Middleware(middleware, location)
for collection in middleware_collections
for middleware in collection
]
)