Skip to content

Commit

Permalink
Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse (#…
Browse files Browse the repository at this point in the history
…1459)

* Prevent BaseHTTPMiddleware from hiding errors of StreamingResponse

* Apply notes from PR:

* remove `nonlocal app_exc`;
* add extra test.

* Fix formatting

Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
o-fedorov and Kludex committed Jan 31, 2022
1 parent d6269e2 commit aff548a
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
3 changes: 3 additions & 0 deletions starlette/middleware/base.py
Expand Up @@ -52,6 +52,9 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
assert message["type"] == "http.response.body"
yield message.get("body", b"")

if app_exc is not None:
raise app_exc

response = StreamingResponse(
status_code=message["status"], content=body_stream()
)
Expand Down
26 changes: 25 additions & 1 deletion tests/middleware/test_base.py
Expand Up @@ -3,7 +3,7 @@
from starlette.applications import Starlette
from starlette.middleware import Middleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.responses import PlainTextResponse
from starlette.responses import PlainTextResponse, StreamingResponse
from starlette.routing import Route


Expand All @@ -28,6 +28,16 @@ def exc(request):
raise Exception("Exc")


@app.route("/exc-stream")
def exc_stream(request):
return StreamingResponse(_generate_faulty_stream())


def _generate_faulty_stream():
yield b"Ok"
raise Exception("Faulty Stream")


@app.route("/no-response")
class NoResponse:
def __init__(self, scope, receive, send):
Expand Down Expand Up @@ -56,6 +66,10 @@ def test_custom_middleware(test_client_factory):
response = client.get("/exc")
assert str(ctx.value) == "Exc"

with pytest.raises(Exception) as ctx:
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"

with pytest.raises(RuntimeError):
response = client.get("/no-response")

Expand Down Expand Up @@ -158,3 +172,13 @@ async def dispatch(self, request, call_next):
client = test_client_factory(app)
response = client.get("/does_not_exist")
assert response.text == "Custom"


def test_exception_on_mounted_apps(test_client_factory):
sub_app = Starlette(routes=[Route("/", exc)])
app.mount("/sub", sub_app)

client = test_client_factory(app)
with pytest.raises(Exception) as ctx:
client.get("/sub/")
assert str(ctx.value) == "Exc"

0 comments on commit aff548a

Please sign in to comment.