diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py
index 8f2f02ea06..c9880206e2 100644
--- a/starlette/middleware/base.py
+++ b/starlette/middleware/base.py
@@ -2,8 +2,9 @@
import anyio
+from starlette.background import BackgroundTask
from starlette.requests import Request
-from starlette.responses import Response, StreamingResponse
+from starlette.responses import ContentStream, Response, StreamingResponse
from starlette.types import ASGIApp, Message, Receive, Scope, Send
RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]]
@@ -75,6 +76,9 @@ async def coro() -> None:
try:
message = await recv_stream.receive()
+ info = message.get("info", None)
+ if message["type"] == "http.response.debug" and info is not None:
+ message = await recv_stream.receive()
except anyio.EndOfStream:
if app_exc is not None:
raise app_exc
@@ -93,8 +97,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
if app_exc is not None:
raise app_exc
- response = StreamingResponse(
- status_code=message["status"], content=body_stream()
+ response = _StreamingResponse(
+ status_code=message["status"], content=body_stream(), info=info
)
response.raw_headers = message["headers"]
return response
@@ -109,3 +113,22 @@ async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
raise NotImplementedError() # pragma: no cover
+
+
+class _StreamingResponse(StreamingResponse):
+ def __init__(
+ self,
+ content: ContentStream,
+ status_code: int = 200,
+ headers: typing.Optional[typing.Mapping[str, str]] = None,
+ media_type: typing.Optional[str] = None,
+ background: typing.Optional[BackgroundTask] = None,
+ info: typing.Optional[typing.Mapping[str, typing.Any]] = None,
+ ) -> None:
+ self._info = info
+ super().__init__(content, status_code, headers, media_type, background)
+
+ async def stream_response(self, send: Send) -> None:
+ if self._info:
+ await send({"type": "http.response.debug", "info": self._info})
+ return await super().stream_response(send)
diff --git a/starlette/templating.py b/starlette/templating.py
index 7c46a65a0b..ecea4f3a3d 100644
--- a/starlette/templating.py
+++ b/starlette/templating.py
@@ -42,12 +42,14 @@ def __init__(
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
request = self.context.get("request", {})
extensions = request.get("extensions", {})
- if "http.response.template" in extensions:
+ if "http.response.debug" in extensions:
await send(
{
- "type": "http.response.template",
- "template": self.template,
- "context": self.context,
+ "type": "http.response.debug",
+ "info": {
+ "template": self.template,
+ "context": self.context,
+ },
}
)
await super().__call__(scope, receive, send)
diff --git a/starlette/testclient.py b/starlette/testclient.py
index 48afcb6734..549fa76219 100644
--- a/starlette/testclient.py
+++ b/starlette/testclient.py
@@ -259,7 +259,7 @@ def handle_request(self, request: httpx.Request) -> httpx.Response:
"headers": headers,
"client": ["testclient", 50000],
"server": [host, port],
- "extensions": {"http.response.template": {}},
+ "extensions": {"http.response.debug": {}},
}
request_complete = False
@@ -324,9 +324,9 @@ async def send(message: Message) -> None:
if not more_body:
raw_kwargs["stream"].seek(0)
response_complete.set()
- elif message["type"] == "http.response.template":
- template = message["template"]
- context = message["context"]
+ elif message["type"] == "http.response.debug":
+ template = message["info"]["template"]
+ context = message["info"]["context"]
try:
with self.portal_factory() as portal:
diff --git a/tests/test_templates.py b/tests/test_templates.py
index 0bf4bce075..2d918f7d0c 100644
--- a/tests/test_templates.py
+++ b/tests/test_templates.py
@@ -3,6 +3,8 @@
import pytest
from starlette.applications import Starlette
+from starlette.middleware import Middleware
+from starlette.middleware.base import BaseHTTPMiddleware
from starlette.routing import Route
from starlette.templating import Jinja2Templates
@@ -60,3 +62,29 @@ def hello_world_processor(request):
assert response.text == "Hello World"
assert response.template.name == "index.html"
assert set(response.context.keys()) == {"request", "username"}
+
+
+def test_template_with_middleware(tmpdir, test_client_factory):
+ path = os.path.join(tmpdir, "index.html")
+ with open(path, "w") as file:
+ file.write("Hello, world")
+
+ async def homepage(request):
+ return templates.TemplateResponse("index.html", {"request": request})
+
+ class CustomMiddleware(BaseHTTPMiddleware):
+ async def dispatch(self, request, call_next):
+ return await call_next(request)
+
+ app = Starlette(
+ debug=True,
+ routes=[Route("/", endpoint=homepage)],
+ middleware=[Middleware(CustomMiddleware)],
+ )
+ templates = Jinja2Templates(directory=str(tmpdir))
+
+ client = test_client_factory(app)
+ response = client.get("/")
+ assert response.text == "Hello, world"
+ assert response.template.name == "index.html"
+ assert set(response.context.keys()) == {"request"}