Skip to content

Commit

Permalink
Add Request contextvars (#2475)
Browse files Browse the repository at this point in the history
* Add Request contextvars

* Add missing contextvar setter

* Move location of context setter
  • Loading branch information
ahopkins committed Jun 16, 2022
1 parent a744041 commit ce926a3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 2 deletions.
1 change: 1 addition & 0 deletions sanic/http.py
Expand Up @@ -265,6 +265,7 @@ async def http1_request_header(self): # no cov
transport=self.protocol.transport,
app=self.protocol.app,
)
self.protocol.request_class._current.set(request)
await self.dispatch(
"http.lifecycle.request",
inline=True,
Expand Down
12 changes: 11 additions & 1 deletion sanic/request.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -35,7 +36,7 @@

from sanic.compat import CancelledErrors, Header
from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE
from sanic.exceptions import BadRequest, BadURL, ServerError
from sanic.exceptions import BadRequest, BadURL, SanicException, ServerError
from sanic.headers import (
AcceptContainer,
Options,
Expand Down Expand Up @@ -82,6 +83,8 @@ class Request:
Properties of an HTTP request such as URL, headers, etc.
"""

_current: ContextVar[Request] = ContextVar("request")

__slots__ = (
"__weakref__",
"_cookies",
Expand Down Expand Up @@ -174,6 +177,13 @@ def __repr__(self):
class_name = self.__class__.__name__
return f"<{class_name}: {self.method} {self.path}>"

@classmethod
def get_current(cls) -> Request:
request = cls._current.get(None)
if not request:
raise SanicException("No current request")
return request

@classmethod
def generate_id(*_):
return uuid.uuid4()
Expand Down
16 changes: 15 additions & 1 deletion tests/test_request.py
Expand Up @@ -4,7 +4,7 @@
import pytest

from sanic import Sanic, response
from sanic.exceptions import BadURL
from sanic.exceptions import BadURL, SanicException
from sanic.request import Request, uuid
from sanic.server import HttpProtocol

Expand Down Expand Up @@ -217,3 +217,17 @@ async def get(request):
assert request.scope is not None
assert request.scope["method"].lower() == "get"
assert request.scope["path"].lower() == "/"


def test_cannot_get_request_outside_of_cycle():
with pytest.raises(SanicException, match="No current request"):
Request.get_current()


def test_get_current_request(app):
@app.get("/")
async def get(request):
return response.json({"same": request is Request.get_current()})

_, resp = app.test_client.get("/")
assert resp.json["same"]

0 comments on commit ce926a3

Please sign in to comment.