diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index 8f2f02ea06..c9880206e2 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -2,8 +2,9 @@ import anyio +from starlette.background import BackgroundTask from starlette.requests import Request -from starlette.responses import Response, StreamingResponse +from starlette.responses import ContentStream, Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] @@ -75,6 +76,9 @@ async def coro() -> None: try: message = await recv_stream.receive() + info = message.get("info", None) + if message["type"] == "http.response.debug" and info is not None: + message = await recv_stream.receive() except anyio.EndOfStream: if app_exc is not None: raise app_exc @@ -93,8 +97,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: if app_exc is not None: raise app_exc - response = StreamingResponse( - status_code=message["status"], content=body_stream() + response = _StreamingResponse( + status_code=message["status"], content=body_stream(), info=info ) response.raw_headers = message["headers"] return response @@ -109,3 +113,22 @@ async def dispatch( self, request: Request, call_next: RequestResponseEndpoint ) -> Response: raise NotImplementedError() # pragma: no cover + + +class _StreamingResponse(StreamingResponse): + def __init__( + self, + content: ContentStream, + status_code: int = 200, + headers: typing.Optional[typing.Mapping[str, str]] = None, + media_type: typing.Optional[str] = None, + background: typing.Optional[BackgroundTask] = None, + info: typing.Optional[typing.Mapping[str, typing.Any]] = None, + ) -> None: + self._info = info + super().__init__(content, status_code, headers, media_type, background) + + async def stream_response(self, send: Send) -> None: + if self._info: + await send({"type": "http.response.debug", "info": self._info}) + return await super().stream_response(send) diff --git a/starlette/templating.py b/starlette/templating.py index 7c46a65a0b..ecea4f3a3d 100644 --- a/starlette/templating.py +++ b/starlette/templating.py @@ -42,12 +42,14 @@ def __init__( async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: request = self.context.get("request", {}) extensions = request.get("extensions", {}) - if "http.response.template" in extensions: + if "http.response.debug" in extensions: await send( { - "type": "http.response.template", - "template": self.template, - "context": self.context, + "type": "http.response.debug", + "info": { + "template": self.template, + "context": self.context, + }, } ) await super().__call__(scope, receive, send) diff --git a/starlette/testclient.py b/starlette/testclient.py index 48afcb6734..549fa76219 100644 --- a/starlette/testclient.py +++ b/starlette/testclient.py @@ -259,7 +259,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response: "headers": headers, "client": ["testclient", 50000], "server": [host, port], - "extensions": {"http.response.template": {}}, + "extensions": {"http.response.debug": {}}, } request_complete = False @@ -324,9 +324,9 @@ async def send(message: Message) -> None: if not more_body: raw_kwargs["stream"].seek(0) response_complete.set() - elif message["type"] == "http.response.template": - template = message["template"] - context = message["context"] + elif message["type"] == "http.response.debug": + template = message["info"]["template"] + context = message["info"]["context"] try: with self.portal_factory() as portal: diff --git a/tests/test_templates.py b/tests/test_templates.py index 0bf4bce075..2d918f7d0c 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -3,6 +3,8 @@ import pytest from starlette.applications import Starlette +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware from starlette.routing import Route from starlette.templating import Jinja2Templates @@ -60,3 +62,29 @@ def hello_world_processor(request): assert response.text == "Hello World" assert response.template.name == "index.html" assert set(response.context.keys()) == {"request", "username"} + + +def test_template_with_middleware(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "index.html") + with open(path, "w") as file: + file.write("Hello, world") + + async def homepage(request): + return templates.TemplateResponse("index.html", {"request": request}) + + class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + return await call_next(request) + + app = Starlette( + debug=True, + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(CustomMiddleware)], + ) + templates = Jinja2Templates(directory=str(tmpdir)) + + client = test_client_factory(app) + response = client.get("/") + assert response.text == "Hello, world" + assert response.template.name == "index.html" + assert set(response.context.keys()) == {"request"}