Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse #1459

Merged
merged 5 commits into from Jan 31, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions starlette/middleware/base.py
Expand Up @@ -52,6 +52,9 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
assert message["type"] == "http.response.body"
yield message.get("body", b"")

if app_exc is not None:
raise app_exc

response = StreamingResponse(
status_code=message["status"], content=body_stream()
)
Expand Down
25 changes: 24 additions & 1 deletion tests/middleware/test_base.py
Expand Up @@ -3,7 +3,7 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route


Expand All @@ -28,6 +28,16 @@ def exc(request):
raise Exception("Exc")


@app.route("/exc-stream")
def exc_stream(request):
return StreamingResponse(_generate_faulty_stream())


def _generate_faulty_stream():
yield b"Ok"
raise Exception("Faulty Stream")


@app.route("/no-response")
class NoResponse:
def __init__(self, scope, receive, send):
Expand Down Expand Up @@ -56,6 +66,10 @@ def test_custom_middleware(test_client_factory):
response = client.get("/exc")
assert str(ctx.value) == "Exc"

with pytest.raises(Exception) as ctx:
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"

with pytest.raises(RuntimeError):
response = client.get("/no-response")

Expand Down Expand Up @@ -158,3 +172,12 @@ async def dispatch(self, request, call_next):
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"

o-fedorov marked this conversation as resolved.
Show resolved Hide resolved
def test_exception_on_mounted_apps(test_client_factory):
sub_app = Starlette(routes=[Route("/", exc)])
app.mount("/sub", sub_app)

client = test_client_factory(app)
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"