From 96c027bad526d16c781b90c16a258b85e0eb34cd Mon Sep 17 00:00:00 2001 From: Zhiwei <43905414+ChihweiLHBird@users.noreply.github.com> Date: Thu, 9 Dec 2021 03:00:18 -0700 Subject: [PATCH] Prevent sending multiple or mixed responses on a single request (#2327) Co-authored-by: Adam Hopkins Co-authored-by: Adam Hopkins --- sanic/app.py | 59 ++++++++- sanic/asgi.py | 17 ++- sanic/http.py | 5 + sanic/request.py | 30 ++++- sanic/response.py | 20 ++- tests/conftest.py | 16 ++- tests/test_exceptions_handler.py | 75 +++++++++-- tests/test_middleware.py | 24 ++++ tests/test_requests.py | 4 +- tests/test_response.py | 208 ++++++++++++++++++++++++++----- 10 files changed, 405 insertions(+), 53 deletions(-) diff --git a/sanic/app.py b/sanic/app.py index e78e53da8f..f02301657f 100644 --- a/sanic/app.py +++ b/sanic/app.py @@ -42,7 +42,7 @@ Union, ) from urllib.parse import urlencode, urlunparse -from warnings import filterwarnings +from warnings import filterwarnings, warn from sanic_routing.exceptions import ( # type: ignore FinalizationError, @@ -67,6 +67,7 @@ URLBuildError, ) from sanic.handlers import ErrorHandler +from sanic.http import Stage from sanic.log import LOGGING_CONFIG_DEFAULTS, Colors, error_logger, logger from sanic.mixins.listeners import ListenerEvent from sanic.models.futures import ( @@ -736,6 +737,50 @@ async def handle_exception( context={"request": request, "exception": exception}, ) + if ( + request.stream is not None + and request.stream.stage is not Stage.HANDLER + ): + error_logger.exception(exception, exc_info=True) + logger.error( + "The error response will not be sent to the client for " + f'the following exception:"{exception}". A previous response ' + "has at least partially been sent." + ) + + # ----------------- deprecated ----------------- + handler = self.error_handler._lookup( + exception, request.name if request else None + ) + if handler: + warn( + "An error occurred while handling the request after at " + "least some part of the response was sent to the client. " + "Therefore, the response from your custom exception " + f"handler {handler.__name__} will not be sent to the " + "client. Beginning in v22.6, Sanic will stop executing " + "custom exception handlers in this scenario. Exception " + "handlers should only be used to generate the exception " + "responses. If you would like to perform any other " + "action on a raised exception, please consider using a " + "signal handler like " + '`@app.signal("http.lifecycle.exception")`\n' + "For further information, please see the docs: " + "https://sanicframework.org/en/guide/advanced/" + "signals.html", + DeprecationWarning, + ) + try: + response = self.error_handler.response(request, exception) + if isawaitable(response): + response = await response + except BaseException as e: + logger.error("An error occurred in the exception handler.") + error_logger.exception(e) + # ---------------------------------------------- + + return + # -------------------------------------------- # # Request Middleware # -------------------------------------------- # @@ -765,6 +810,7 @@ async def handle_exception( ) if response is not None: try: + request.reset_response() response = await request.respond(response) except BaseException: # Skip response middleware @@ -874,7 +920,16 @@ async def handle_request(self, request: Request): # no cov if isawaitable(response): response = await response - if response is not None: + if request.responded: + if response is not None: + error_logger.error( + "The response object returned by the route handler " + "will not be sent to client. The request has already " + "been responded to." + ) + if request.stream is not None: + response = request.stream.response + elif response is not None: response = await request.respond(response) elif not hasattr(handler, "is_websocket"): response = request.stream.response # type: ignore diff --git a/sanic/asgi.py b/sanic/asgi.py index 55c18d5cf5..00b181dcde 100644 --- a/sanic/asgi.py +++ b/sanic/asgi.py @@ -7,8 +7,10 @@ from sanic.compat import Header from sanic.exceptions import ServerError +from sanic.http import Stage from sanic.models.asgi import ASGIReceive, ASGIScope, ASGISend, MockTransport from sanic.request import Request +from sanic.response import BaseHTTPResponse from sanic.server import ConnInfo from sanic.server.websockets.connection import WebSocketConnection @@ -83,6 +85,8 @@ class ASGIApp: transport: MockTransport lifespan: Lifespan ws: Optional[WebSocketConnection] + stage: Stage + response: Optional[BaseHTTPResponse] def __init__(self) -> None: self.ws = None @@ -95,6 +99,8 @@ async def create( instance.sanic_app = sanic_app instance.transport = MockTransport(scope, receive, send) instance.transport.loop = sanic_app.loop + instance.stage = Stage.IDLE + instance.response = None setattr(instance.transport, "add_task", sanic_app.loop.create_task) headers = Header( @@ -149,6 +155,8 @@ async def read(self) -> Optional[bytes]: """ Read and stream the body in chunks from an incoming ASGI message. """ + if self.stage is Stage.IDLE: + self.stage = Stage.REQUEST message = await self.transport.receive() body = message.get("body", b"") if not message.get("more_body", False): @@ -163,11 +171,17 @@ async def __aiter__(self): if data: yield data - def respond(self, response): + def respond(self, response: BaseHTTPResponse): + if self.stage is not Stage.HANDLER: + self.stage = Stage.FAILED + raise RuntimeError("Response already started") + if self.response is not None: + self.response.stream = None response.stream, self.response = self, response return response async def send(self, data, end_stream): + self.stage = Stage.IDLE if end_stream else Stage.RESPONSE if self.response: response, self.response = self.response, None await self.transport.send( @@ -195,6 +209,7 @@ async def __call__(self) -> None: Handle the incoming request. """ try: + self.stage = Stage.HANDLER await self.sanic_app.handle_request(self.request) except Exception as e: await self.sanic_app.handle_exception(self.request, e) diff --git a/sanic/http.py b/sanic/http.py index 6f59ef250f..86f23fe3e3 100644 --- a/sanic/http.py +++ b/sanic/http.py @@ -584,6 +584,11 @@ def respond(self, response: BaseHTTPResponse) -> BaseHTTPResponse: self.stage = Stage.FAILED raise RuntimeError("Response already started") + # Disconnect any earlier but unused response object + if self.response is not None: + self.response.stream = None + + # Connect and return the response self.response, response.stream = response, self return response diff --git a/sanic/request.py b/sanic/request.py index 68c2725724..ddec6e825d 100644 --- a/sanic/request.py +++ b/sanic/request.py @@ -18,7 +18,6 @@ if TYPE_CHECKING: from sanic.server import ConnInfo from sanic.app import Sanic - from sanic.http import Http import email.utils import uuid @@ -32,7 +31,7 @@ from sanic.compat import CancelledErrors, Header from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE -from sanic.exceptions import InvalidUsage +from sanic.exceptions import InvalidUsage, ServerError from sanic.headers import ( AcceptContainer, Options, @@ -42,6 +41,7 @@ parse_host, parse_xforwarded, ) +from sanic.http import Http, Stage from sanic.log import error_logger, logger from sanic.models.protocol_types import TransportProtocol from sanic.response import BaseHTTPResponse, HTTPResponse @@ -104,6 +104,7 @@ class Request: "parsed_json", "parsed_forwarded", "raw_url", + "responded", "request_middleware_started", "route", "stream", @@ -155,6 +156,7 @@ def __init__( self.stream: Optional[Http] = None self.route: Optional[Route] = None self._protocol = None + self.responded: bool = False def __repr__(self): class_name = self.__class__.__name__ @@ -164,6 +166,21 @@ def __repr__(self): def generate_id(*_): return uuid.uuid4() + def reset_response(self): + try: + if ( + self.stream is not None + and self.stream.stage is not Stage.HANDLER + ): + raise ServerError( + "Cannot reset response because previous response was sent." + ) + self.stream.response.stream = None + self.stream.response = None + self.responded = False + except AttributeError: + pass + async def respond( self, response: Optional[BaseHTTPResponse] = None, @@ -172,13 +189,19 @@ async def respond( headers: Optional[Union[Header, Dict[str, str]]] = None, content_type: Optional[str] = None, ): + try: + if self.stream is not None and self.stream.response: + raise ServerError("Second respond call is not allowed.") + except AttributeError: + pass # This logic of determining which response to use is subject to change if response is None: - response = (self.stream and self.stream.response) or HTTPResponse( + response = HTTPResponse( status=status, headers=headers, content_type=content_type, ) + # Connect the response if isinstance(response, BaseHTTPResponse) and self.stream: response = self.stream.respond(response) @@ -193,6 +216,7 @@ async def respond( error_logger.exception( "Exception occurred in one of response middleware handlers" ) + self.responded = True return response async def receive_body(self): diff --git a/sanic/response.py b/sanic/response.py index 1da4486a1d..357668e682 100644 --- a/sanic/response.py +++ b/sanic/response.py @@ -3,6 +3,7 @@ from os import path from pathlib import PurePath from typing import ( + TYPE_CHECKING, Any, AnyStr, Callable, @@ -19,11 +20,15 @@ from sanic.compat import Header, open_async from sanic.constants import DEFAULT_HTTP_CONTENT_TYPE from sanic.cookies import CookieJar +from sanic.exceptions import SanicException, ServerError from sanic.helpers import has_message_body, remove_entity_headers from sanic.http import Http from sanic.models.protocol_types import HTMLProtocol, Range +if TYPE_CHECKING: + from sanic.asgi import ASGIApp + try: from ujson import dumps as json_dumps except ImportError: @@ -45,7 +50,7 @@ def __init__(self): self.asgi: bool = False self.body: Optional[bytes] = None self.content_type: Optional[str] = None - self.stream: Http = None + self.stream: Optional[Union[Http, ASGIApp]] = None self.status: int = None self.headers = Header({}) self._cookies: Optional[CookieJar] = None @@ -112,8 +117,17 @@ async def send( """ if data is None and end_stream is None: end_stream = True - if end_stream and not data and self.stream.send is None: - return + if self.stream is None: + raise SanicException( + "No stream is connected to the response object instance." + ) + if self.stream.send is None: + if end_stream and not data: + return + raise ServerError( + "Response stream was ended, no more response data is " + "allowed to be sent." + ) data = ( data.encode() # type: ignore if hasattr(data, "encode") diff --git a/tests/conftest.py b/tests/conftest.py index 175e967efa..292914cd4d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,7 +6,8 @@ import sys import uuid -from typing import Tuple +from logging import LogRecord +from typing import Callable, List, Tuple import pytest @@ -170,3 +171,16 @@ def run(app): return caplog.record_tuples return run + + +@pytest.fixture(scope="function") +def message_in_records(): + def msg_in_log(records: List[LogRecord], msg: str): + error_captured = False + for record in records: + if msg in record.message: + error_captured = True + break + return error_captured + + return msg_in_log diff --git a/tests/test_exceptions_handler.py b/tests/test_exceptions_handler.py index 371baa8dd9..a0de67373a 100644 --- a/tests/test_exceptions_handler.py +++ b/tests/test_exceptions_handler.py @@ -1,15 +1,18 @@ import asyncio import logging +from typing import Callable, List from unittest.mock import Mock import pytest from bs4 import BeautifulSoup +from pytest import LogCaptureFixture, MonkeyPatch from sanic import Sanic, handlers from sanic.exceptions import Forbidden, InvalidUsage, NotFound, ServerError from sanic.handlers import ErrorHandler +from sanic.request import Request from sanic.response import stream, text @@ -90,35 +93,35 @@ async def some_request_middleware(request): return exception_handler_app -def test_invalid_usage_exception_handler(exception_handler_app): +def test_invalid_usage_exception_handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/1") assert response.status == 400 -def test_server_error_exception_handler(exception_handler_app): +def test_server_error_exception_handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/2") assert response.status == 200 assert response.text == "OK" -def test_not_found_exception_handler(exception_handler_app): +def test_not_found_exception_handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/3") assert response.status == 200 -def test_text_exception__handler(exception_handler_app): +def test_text_exception__handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/random") assert response.status == 200 assert response.text == "Done." -def test_async_exception_handler(exception_handler_app): +def test_async_exception_handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/7") assert response.status == 200 assert response.text == "foo,bar" -def test_html_traceback_output_in_debug_mode(exception_handler_app): +def test_html_traceback_output_in_debug_mode(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/4", debug=True) assert response.status == 500 soup = BeautifulSoup(response.body, "html.parser") @@ -133,12 +136,12 @@ def test_html_traceback_output_in_debug_mode(exception_handler_app): ) == summary_text -def test_inherited_exception_handler(exception_handler_app): +def test_inherited_exception_handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get("/5") assert response.status == 200 -def test_chained_exception_handler(exception_handler_app): +def test_chained_exception_handler(exception_handler_app: Sanic): request, response = exception_handler_app.test_client.get( "/6/0", debug=True ) @@ -157,7 +160,7 @@ def test_chained_exception_handler(exception_handler_app): ) == summary_text -def test_exception_handler_lookup(exception_handler_app): +def test_exception_handler_lookup(exception_handler_app: Sanic): class CustomError(Exception): pass @@ -205,13 +208,17 @@ class ModuleNotFoundError(ImportError): ) -def test_exception_handler_processed_request_middleware(exception_handler_app): +def test_exception_handler_processed_request_middleware( + exception_handler_app: Sanic, +): request, response = exception_handler_app.test_client.get("/8") assert response.status == 200 assert response.text == "Done." -def test_single_arg_exception_handler_notice(exception_handler_app, caplog): +def test_single_arg_exception_handler_notice( + exception_handler_app: Sanic, caplog: LogCaptureFixture +): class CustomErrorHandler(ErrorHandler): def lookup(self, exception): return super().lookup(exception, None) @@ -233,7 +240,9 @@ def lookup(self, exception): assert response.status == 400 -def test_error_handler_noisy_log(exception_handler_app, monkeypatch): +def test_error_handler_noisy_log( + exception_handler_app: Sanic, monkeypatch: MonkeyPatch +): err_logger = Mock() monkeypatch.setattr(handlers, "error_logger", err_logger) @@ -246,3 +255,45 @@ def test_error_handler_noisy_log(exception_handler_app, monkeypatch): err_logger.exception.assert_called_with( "Exception occurred while handling uri: %s", repr(request.url) ) + + +def test_exception_handler_response_was_sent( + app: Sanic, + caplog: LogCaptureFixture, + message_in_records: Callable[[List[logging.LogRecord], str], bool], +): + exception_handler_ran = False + + @app.exception(ServerError) + async def exception_handler(request, exception): + nonlocal exception_handler_ran + exception_handler_ran = True + return text("Error") + + @app.route("/1") + async def handler1(request: Request): + response = await request.respond() + await response.send("some text") + raise ServerError("Exception") + + @app.route("/2") + async def handler2(request: Request): + response = await request.respond() + raise ServerError("Exception") + + with caplog.at_level(logging.WARNING): + _, response = app.test_client.get("/1") + assert "some text" in response.text + + # Change to assert warning not in the records in the future version. + message_in_records( + caplog.records, + ( + "An error occurred while handling the request after at " + "least some part of the response was sent to the client. " + "Therefore, the response from your custom exception " + ), + ) + + _, response = app.test_client.get("/2") + assert "Error" in response.text diff --git a/tests/test_middleware.py b/tests/test_middleware.py index c19386e7e4..2163e47c28 100644 --- a/tests/test_middleware.py +++ b/tests/test_middleware.py @@ -297,3 +297,27 @@ async def handler(request): _, response = app.test_client.get("/") assert response.json["foo"] == "bar" + + +def test_middleware_return_response(app): + response_middleware_run_count = 0 + request_middleware_run_count = 0 + + @app.on_response + def response(_, response): + nonlocal response_middleware_run_count + response_middleware_run_count += 1 + + @app.on_request + def request(_): + nonlocal request_middleware_run_count + request_middleware_run_count += 1 + + @app.get("/") + async def handler(request): + resp1 = await request.respond() + return resp1 + + _, response = app.test_client.get("/") + assert response_middleware_run_count == 1 + assert request_middleware_run_count == 1 diff --git a/tests/test_requests.py b/tests/test_requests.py index e5db9d20db..c8f6e3f016 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -15,8 +15,8 @@ ) from sanic import Blueprint, Sanic -from sanic.exceptions import ServerError -from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, RequestParameters +from sanic.exceptions import SanicException, ServerError +from sanic.request import DEFAULT_HTTP_CONTENT_TYPE, Request, RequestParameters from sanic.response import html, json, text diff --git a/tests/test_response.py b/tests/test_response.py index 0676b8851f..8d301abfee 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -3,15 +3,18 @@ import os from collections import namedtuple +from logging import ERROR, LogRecord from mimetypes import guess_type from random import choice +from typing import Callable, List from urllib.parse import unquote import pytest from aiofiles import os as async_os +from pytest import LogCaptureFixture -from sanic import Sanic +from sanic import Request, Sanic from sanic.response import ( HTTPResponse, empty, @@ -33,7 +36,7 @@ def test_response_body_not_a_string(app): random_num = choice(range(1000)) @app.route("/hello") - async def hello_route(request): + async def hello_route(request: Request): return text(random_num) request, response = app.test_client.get("/hello") @@ -51,7 +54,7 @@ def test_method_not_allowed(): app = Sanic("app") @app.get("/") - async def test_get(request): + async def test_get(request: Request): return response.json({"hello": "world"}) request, response = app.test_client.head("/") @@ -67,7 +70,7 @@ async def test_get(request): app.router.reset() @app.post("/") - async def test_post(request): + async def test_post(request: Request): return response.json({"hello": "world"}) request, response = app.test_client.head("/") @@ -89,7 +92,7 @@ async def test_post(request): def test_response_header(app): @app.get("/") - async def test(request): + async def test(request: Request): return json({"ok": True}, headers={"CONTENT-TYPE": "application/json"}) request, response = app.test_client.get("/") @@ -102,14 +105,14 @@ async def test(request): def test_response_content_length(app): @app.get("/response_with_space") - async def response_with_space(request): + async def response_with_space(request: Request): return json( {"message": "Data", "details": "Some Details"}, headers={"CONTENT-TYPE": "application/json"}, ) @app.get("/response_without_space") - async def response_without_space(request): + async def response_without_space(request: Request): return json( {"message": "Data", "details": "Some Details"}, headers={"CONTENT-TYPE": "application/json"}, @@ -135,7 +138,7 @@ async def response_without_space(request): def test_response_content_length_with_different_data_types(app): @app.get("/") - async def get_data_with_different_types(request): + async def get_data_with_different_types(request: Request): # Indentation issues in the Response is intentional. Please do not fix return json( {"bool": True, "none": None, "string": "string", "number": -1}, @@ -149,23 +152,23 @@ async def get_data_with_different_types(request): @pytest.fixture def json_app(app): @app.route("/") - async def test(request): + async def test(request: Request): return json(JSON_DATA) @app.get("/no-content") - async def no_content_handler(request): + async def no_content_handler(request: Request): return json(JSON_DATA, status=204) @app.get("/no-content/unmodified") - async def no_content_unmodified_handler(request): + async def no_content_unmodified_handler(request: Request): return json(None, status=304) @app.get("/unmodified") - async def unmodified_handler(request): + async def unmodified_handler(request: Request): return json(JSON_DATA, status=304) @app.delete("/") - async def delete_handler(request): + async def delete_handler(request: Request): return json(None, status=204) return app @@ -207,7 +210,7 @@ def test_no_content(json_app): @pytest.fixture def streaming_app(app): @app.route("/") - async def test(request): + async def test(request: Request): return stream( sample_streaming_fn, content_type="text/csv", @@ -219,7 +222,7 @@ async def test(request): @pytest.fixture def non_chunked_streaming_app(app): @app.route("/") - async def test(request): + async def test(request: Request): return stream( sample_streaming_fn, headers={"Content-Length": "7"}, @@ -276,7 +279,7 @@ def test_non_chunked_streaming_returns_correct_content( def test_stream_response_with_cookies(app): @app.route("/") - async def test(request): + async def test(request: Request): response = stream(sample_streaming_fn, content_type="text/csv") response.cookies["test"] = "modified" response.cookies["test"] = "pass" @@ -288,7 +291,7 @@ async def test(request): def test_stream_response_without_cookies(app): @app.route("/") - async def test(request): + async def test(request: Request): return stream(sample_streaming_fn, content_type="text/csv") request, response = app.test_client.get("/") @@ -314,7 +317,7 @@ def get_file_content(static_file_directory, file_name): "file_name", ["test.file", "decode me.txt", "python.png"] ) @pytest.mark.parametrize("status", [200, 401]) -def test_file_response(app, file_name, static_file_directory, status): +def test_file_response(app: Sanic, file_name, static_file_directory, status): @app.route("/files/", methods=["GET"]) def file_route(request, filename): file_path = os.path.join(static_file_directory, filename) @@ -340,7 +343,7 @@ def file_route(request, filename): ], ) def test_file_response_custom_filename( - app, source, dest, static_file_directory + app: Sanic, source, dest, static_file_directory ): @app.route("/files/", methods=["GET"]) def file_route(request, filename): @@ -358,7 +361,7 @@ def file_route(request, filename): @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"]) -def test_file_head_response(app, file_name, static_file_directory): +def test_file_head_response(app: Sanic, file_name, static_file_directory): @app.route("/files/", methods=["GET", "HEAD"]) async def file_route(request, filename): file_path = os.path.join(static_file_directory, filename) @@ -391,7 +394,7 @@ async def file_route(request, filename): @pytest.mark.parametrize( "file_name", ["test.file", "decode me.txt", "python.png"] ) -def test_file_stream_response(app, file_name, static_file_directory): +def test_file_stream_response(app: Sanic, file_name, static_file_directory): @app.route("/files/", methods=["GET"]) def file_route(request, filename): file_path = os.path.join(static_file_directory, filename) @@ -417,7 +420,7 @@ def file_route(request, filename): ], ) def test_file_stream_response_custom_filename( - app, source, dest, static_file_directory + app: Sanic, source, dest, static_file_directory ): @app.route("/files/", methods=["GET"]) def file_route(request, filename): @@ -435,7 +438,9 @@ def file_route(request, filename): @pytest.mark.parametrize("file_name", ["test.file", "decode me.txt"]) -def test_file_stream_head_response(app, file_name, static_file_directory): +def test_file_stream_head_response( + app: Sanic, file_name, static_file_directory +): @app.route("/files/", methods=["GET", "HEAD"]) async def file_route(request, filename): file_path = os.path.join(static_file_directory, filename) @@ -479,7 +484,7 @@ async def file_route(request, filename): "size,start,end", [(1024, 0, 1024), (4096, 1024, 8192)] ) def test_file_stream_response_range( - app, file_name, static_file_directory, size, start, end + app: Sanic, file_name, static_file_directory, size, start, end ): Range = namedtuple("Range", ["size", "start", "end", "total"]) @@ -508,7 +513,7 @@ def file_route(request, filename): def test_raw_response(app): @app.get("/test") - def handler(request): + def handler(request: Request): return raw(b"raw_response") request, response = app.test_client.get("/test") @@ -518,7 +523,7 @@ def handler(request): def test_empty_response(app): @app.get("/test") - def handler(request): + def handler(request: Request): return empty() request, response = app.test_client.get("/test") @@ -526,17 +531,162 @@ def handler(request): assert response.body == b"" -def test_direct_response_stream(app): +def test_direct_response_stream(app: Sanic): @app.route("/") - async def test(request): + async def test(request: Request): response = await request.respond(content_type="text/csv") await response.send("foo,") await response.send("bar") await response.eof() - return response _, response = app.test_client.get("/") assert response.text == "foo,bar" assert response.headers["Transfer-Encoding"] == "chunked" assert response.headers["Content-Type"] == "text/csv" assert "Content-Length" not in response.headers + + +def test_two_respond_calls(app: Sanic): + @app.route("/") + async def handler(request: Request): + response = await request.respond() + await response.send("foo,") + await response.send("bar") + await response.eof() + + +def test_multiple_responses( + app: Sanic, + caplog: LogCaptureFixture, + message_in_records: Callable[[List[LogRecord], str], bool], +): + @app.route("/1") + async def handler(request: Request): + response = await request.respond() + await response.send("foo") + response = await request.respond() + + @app.route("/2") + async def handler(request: Request): + response = await request.respond() + response = await request.respond() + await response.send("foo") + + @app.get("/3") + async def handler(request: Request): + response = await request.respond() + await response.send("foo,") + response = await request.respond() + await response.send("bar") + + @app.get("/4") + async def handler(request: Request): + response = await request.respond(headers={"one": "one"}) + return json({"foo": "bar"}, headers={"one": "two"}) + + @app.get("/5") + async def handler(request: Request): + response = await request.respond(headers={"one": "one"}) + await response.send("foo") + return json({"foo": "bar"}, headers={"one": "two"}) + + @app.get("/6") + async def handler(request: Request): + response = await request.respond(headers={"one": "one"}) + await response.send("foo, ") + json_response = json({"foo": "bar"}, headers={"one": "two"}) + await response.send("bar") + return json_response + + error_msg0 = "Second respond call is not allowed." + + error_msg1 = ( + "The error response will not be sent to the client for the following " + 'exception:"Second respond call is not allowed.". A previous ' + "response has at least partially been sent." + ) + + error_msg2 = ( + "The response object returned by the route handler " + "will not be sent to client. The request has already " + "been responded to." + ) + + error_msg3 = ( + "Response stream was ended, no more " + "response data is allowed to be sent." + ) + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/1") + assert response.status == 200 + assert message_in_records(caplog.records, error_msg0) + assert message_in_records(caplog.records, error_msg1) + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/2") + assert response.status == 500 + assert "500 — Internal Server Error" in response.text + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/3") + assert response.status == 200 + assert "foo," in response.text + assert message_in_records(caplog.records, error_msg0) + assert message_in_records(caplog.records, error_msg1) + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/4") + print(response.json) + assert response.status == 200 + assert "foo" not in response.text + assert "one" in response.headers + assert response.headers["one"] == "one" + + print(response.headers) + assert message_in_records(caplog.records, error_msg2) + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/5") + assert response.status == 200 + assert "foo" in response.text + assert "one" in response.headers + assert response.headers["one"] == "one" + assert message_in_records(caplog.records, error_msg2) + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/6") + assert "foo, bar" in response.text + assert "one" in response.headers + assert response.headers["one"] == "one" + assert message_in_records(caplog.records, error_msg2) + + +def send_response_after_eof_should_fail( + app: Sanic, + caplog: LogCaptureFixture, + message_in_records: Callable[[List[LogRecord], str], bool], +): + @app.get("/") + async def handler(request: Request): + response = await request.respond() + await response.send("foo, ") + await response.eof() + await response.send("bar") + + error_msg1 = ( + "The error response will not be sent to the client for the following " + 'exception:"Second respond call is not allowed.". A previous ' + "response has at least partially been sent." + ) + + error_msg2 = ( + "Response stream was ended, no more " + "response data is allowed to be sent." + ) + + with caplog.at_level(ERROR): + _, response = app.test_client.get("/") + assert "foo, " in response.text + assert message_in_records(caplog.records, error_msg1) + assert message_in_records(caplog.records, error_msg2)