Skip to content

Commit

Permalink
Handle streaming from ASGI (#2937)
Browse files Browse the repository at this point in the history
  • Loading branch information
ahopkins committed Apr 8, 2024
1 parent 7331ced commit 7eea12c
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 14 deletions.
19 changes: 13 additions & 6 deletions sanic/asgi.py
Expand Up @@ -219,19 +219,26 @@ def respond(self, response: BaseHTTPResponse):
return response

async def send(self, data, end_stream):
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
if self.response:
response, self.response = self.response, None
if self.stage is Stage.IDLE:
if not end_stream or data:
raise RuntimeError(
"There is no request to respond to, either the "
"response has already been sent or the "
"request has not been received yet."
)
return
if self.response and self.stage is Stage.HANDLER:
await self.transport.send(
{
"type": "http.response.start",
"status": response.status,
"headers": response.processed_headers,
"status": self.response.status,
"headers": self.response.processed_headers,
}
)
response_body = getattr(response, "body", None)
response_body = getattr(self.response, "body", None)
if response_body:
data = response_body + data if data else response_body
self.stage = Stage.IDLE if end_stream else Stage.RESPONSE
await self.transport.send(
{
"type": "http.response.body",
Expand Down
53 changes: 45 additions & 8 deletions tests/test_response.py
Expand Up @@ -575,14 +575,20 @@ async def test(request: Request):
assert "Content-Length" not in response.headers


def test_two_respond_calls(app: Sanic):
@pytest.mark.asyncio
async def test_direct_response_stream_asgi(app: Sanic):
@app.route("/")
async def handler(request: Request):
response = await request.respond()
async def test(request: Request):
response = await request.respond(content_type="text/csv")
await response.send("foo,")
await response.send("bar")
await response.eof()

_, response = await app.asgi_client.get("/")
assert response.text == "foo,bar"
assert response.headers["Content-Type"] == "text/csv"
assert "Content-Length" not in response.headers


def test_multiple_responses(
app: Sanic,
Expand Down Expand Up @@ -684,7 +690,7 @@ async def handler6(request: Request):
assert message_in_records(caplog.records, error_msg2)


def send_response_after_eof_should_fail(
def test_send_response_after_eof_should_fail(
app: Sanic,
caplog: LogCaptureFixture,
message_in_records: Callable[[List[LogRecord], str], bool],
Expand All @@ -698,17 +704,48 @@ async def handler(request: Request):

error_msg1 = (
"The error response will not be sent to the client for the following "
'exception:"Second respond call is not allowed.". A previous '
'exception:"Response stream was ended, no more response '
'data is allowed to be sent.". A previous '
"response has at least partially been sent."
)

error_msg2 = "Response stream was ended, no more response data is allowed to be sent."

with caplog.at_level(ERROR):
_, response = app.test_client.get("/")
assert "foo, " in response.text
assert message_in_records(caplog.records, error_msg1)
assert message_in_records(caplog.records, error_msg2)


@pytest.mark.asyncio
async def test_send_response_after_eof_should_fail_asgi(
app: Sanic,
caplog: LogCaptureFixture,
message_in_records: Callable[[List[LogRecord], str], bool],
):
@app.get("/")
async def handler(request: Request):
response = await request.respond()
await response.send("foo, ")
await response.eof()
await response.send("bar")

error_msg1 = (
"The error response will not be sent to the client for the "
'following exception:"There is no request to respond to, '
"either the response has already been sent or the request "
'has not been received yet.". A previous response has '
"at least partially been sent."
)

error_msg2 = (
"Response stream was ended, no more "
"response data is allowed to be sent."
"There is no request to respond to, either the response has "
"already been sent or the request has not been received yet."
)

with caplog.at_level(ERROR):
_, response = app.test_client.get("/")
_, response = await app.asgi_client.get("/")
assert "foo, " in response.text
assert message_in_records(caplog.records, error_msg1)
assert message_in_records(caplog.records, error_msg2)
Expand Down

0 comments on commit 7eea12c

Please sign in to comment.