diff --git a/asgi_gzip/__init__.py b/asgi_gzip/__init__.py index d199007..a3cfd6c 100644 --- a/asgi_gzip/__init__.py +++ b/asgi_gzip/__init__.py @@ -30,6 +30,7 @@ def __init__(self, app, minimum_size: int, compresslevel: int = 9) -> None: self.send = unattached_send self.initial_message = {} self.started = False + self.content_encoding_set = False self.gzip_buffer = io.BytesIO() self.gzip_file = gzip.GzipFile( mode="wb", fileobj=self.gzip_buffer, compresslevel=compresslevel @@ -45,6 +46,13 @@ async def send_with_gzip(self, message) -> None: # Don't send the initial message until we've determined how to # modify the outgoing headers correctly. self.initial_message = message + headers = Headers(raw=self.initial_message["headers"]) + self.content_encoding_set = "content-encoding" in headers + elif message_type == "http.response.body" and self.content_encoding_set: + if not self.started: + self.started = True + await self.send(self.initial_message) + await self.send(message) elif message_type == "http.response.body" and not self.started: self.started = True body = message.get("body", b"") diff --git a/tests/test_asgi_gzip.py b/tests/test_asgi_gzip.py index a44c480..66d62a1 100644 --- a/tests/test_asgi_gzip.py +++ b/tests/test_asgi_gzip.py @@ -76,3 +76,27 @@ async def generator(bytes, count): assert response.text == "x" * 4000 assert response.headers["Content-Encoding"] == "gzip" assert "Content-Length" not in response.headers + + +def test_gzip_ignored_for_responses_with_encoding_set(test_client_factory): + def homepage(request): + async def generator(bytes, count): + for index in range(count): + yield bytes + + streaming = generator(bytes=b"x" * 400, count=10) + return StreamingResponse( + streaming, status_code=200, headers={"Content-Encoding": "br"} + ) + + app = Starlette( + routes=[Route("/", endpoint=homepage)], + middleware=[Middleware(GZipMiddleware)], + ) + + client = test_client_factory(app) + response = client.get("/", headers={"accept-encoding": "gzip, br"}) + assert response.status_code == 200 + assert response.text == "x" * 4000 + assert response.headers["Content-Encoding"] == "br" + assert "Content-Length" not in response.headers