Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Debug extension #1991

Merged
merged 4 commits into from Feb 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
29 changes: 26 additions & 3 deletions starlette/middleware/base.py
Expand Up @@ -2,8 +2,9 @@

import anyio

from starlette.background import BackgroundTask
from starlette.requests import Request
from starlette.responses import Response, StreamingResponse
from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send

RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
Expand Down Expand Up @@ -75,6 +76,9 @@ async def coro() -> None:

try:
message = await recv_stream.receive()
info = message.get("info", None)
if message["type"] == "http.response.debug" and info is not None:
message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
Expand All @@ -93,8 +97,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
if app_exc is not None:
raise app_exc

response = StreamingResponse(
status_code=message["status"], content=body_stream()
response = _StreamingResponse(
status_code=message["status"], content=body_stream(), info=info
)
response.raw_headers = message["headers"]
return response
Expand All @@ -109,3 +113,22 @@ async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
raise NotImplementedError() # pragma: no cover


class _StreamingResponse(StreamingResponse):
def __init__(
self,
content: ContentStream,
status_code: int = 200,
headers: typing.Optional[typing.Mapping[str, str]] = None,
media_type: typing.Optional[str] = None,
background: typing.Optional[BackgroundTask] = None,
info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
) -> None:
self._info = info
super().__init__(content, status_code, headers, media_type, background)

async def stream_response(self, send: Send) -> None:
if self._info:
await send({"type": "http.response.debug", "info": self._info})
return await super().stream_response(send)
10 changes: 6 additions & 4 deletions starlette/templating.py
Expand Up @@ -41,12 +41,14 @@ def __init__(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = self.context.get("request", {})
extensions = request.get("extensions", {})
if "http.response.template" in extensions:
if "http.response.debug" in extensions:
await send(
{
"type": "http.response.template",
"template": self.template,
"context": self.context,
"type": "http.response.debug",
"info": {
"template": self.template,
"context": self.context,
},
}
)
await super().__call__(scope, receive, send)
Expand Down
8 changes: 4 additions & 4 deletions starlette/testclient.py
Expand Up @@ -259,7 +259,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
"extensions": {"http.response.template": {}},
"extensions": {"http.response.debug": {}},
}

request_complete = False
Expand Down Expand Up @@ -324,9 +324,9 @@ async def send(message: Message) -> None:
if not more_body:
raw_kwargs["stream"].seek(0)
response_complete.set()
elif message["type"] == "http.response.template":
template = message["template"]
context = message["context"]
elif message["type"] == "http.response.debug":
template = message["info"]["template"]
context = message["info"]["context"]

try:
with self.portal_factory() as portal:
Expand Down
28 changes: 28 additions & 0 deletions tests/test_templates.py
Expand Up @@ -3,6 +3,8 @@
import pytest

from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
from starlette.templating import Jinja2Templates

Expand Down Expand Up @@ -60,3 +62,29 @@ def hello_world_processor(request):
assert response.text == "<html>Hello World</html>"
assert response.template.name == "index.html"
assert set(response.context.keys()) == {"request", "username"}


def test_template_with_middleware(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "index.html")
with open(path, "w") as file:
file.write("<html>Hello, <a href='{{ url_for('homepage') }}'>world</a></html>")

async def homepage(request):
return templates.TemplateResponse("index.html", {"request": request})

class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request, call_next):
return await call_next(request)

app = Starlette(
debug=True,
routes=[Route("/", endpoint=homepage)],
middleware=[Middleware(CustomMiddleware)],
)
templates = Jinja2Templates(directory=str(tmpdir))

client = test_client_factory(app)
response = client.get("/")
assert response.text == "<html>Hello, <a href='http://testserver/'>world</a></html>"
assert response.template.name == "index.html"
assert set(response.context.keys()) == {"request"}