diff --git a/docs/middleware.md b/docs/middleware.md index e84c361a7..5dda069b4 100644 --- a/docs/middleware.md +++ b/docs/middleware.md @@ -180,6 +180,8 @@ The following arguments are supported: * `minimum_size` - Do not GZip responses that are smaller than this minimum size in bytes. Defaults to `500`. +The middleware won't GZip responses that already have a `Content-Encoding` set, to prevent them from being encoded twice. + ## BaseHTTPMiddleware An abstract class that allows you to write ASGI middleware against a request/response diff --git a/starlette/middleware/gzip.py b/starlette/middleware/gzip.py index 9d69ee7ca..cbb0f4a5b 100644 --- a/starlette/middleware/gzip.py +++ b/starlette/middleware/gzip.py @@ -33,6 +33,7 @@ def __init__(self, app: ASGIApp, minimum_size: int, compresslevel: int = 9) -> N self.send: Send = unattached_send self.initial_message: Message = {} self.started = False + self.content_encoding_set = False self.gzip_buffer = io.BytesIO() self.gzip_file = gzip.GzipFile( mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel @@ -48,6 +49,13 @@ async def send_with_gzip(self, message: Message) -> None: # Don't send the initial message until we've determined how to # modify the outgoing headers correctly. self.initial_message = message + headers = Headers(raw=self.initial_message["headers"]) + self.content_encoding_set = "content-encoding" in headers + elif message_type == "http.response.body" and self.content_encoding_set: + if not self.started: + self.started = True + await self.send(self.initial_message) + await self.send(message) elif message_type == "http.response.body" and not self.started: self.started = True body = message.get("body", b"") diff --git a/tests/middleware/test_gzip.py b/tests/middleware/test_gzip.py index a9cedb33b..74b09e4dd 100644 --- a/tests/middleware/test_gzip.py +++ b/tests/middleware/test_gzip.py @@ -76,3 +76,27 @@ async def generator(bytes, count): assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "gzip" assert "Content-Length" not in response.headers + + +def test_gzip_ignored_for_responses_with_encoding_set(test_client_factory): + def homepage(request): + async def generator(bytes, count): + for index in range(count): + yield bytes + + streaming = generator(bytes=b"x" * 400, count=10) + return StreamingResponse( + streaming, status_code=200, headers={"Content-Encoding": "br"} + ) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(GZipMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/", headers={"accept-encoding": "gzip, br"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert response.headers["Content-Encoding"] == "br" + assert "Content-Length" not in response.headers