From d8fcbfeedebf79ac81c4eec247dc2610074997b4 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 6 Jan 2022 13:24:48 +0100 Subject: [PATCH 1/5] Add content-length header by default --- starlette/responses.py | 6 +++--- starlette/staticfiles.py | 2 +- tests/test_responses.py | 5 +++++ 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 1f9c43a21..8905acf18 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -70,8 +70,8 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: populate_content_length = b"content-length" not in keys populate_content_type = b"content-type" not in keys - body = getattr(self, "body", b"") - if body and populate_content_length: + if populate_content_length: + body = getattr(self, "body", b"") content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1"))) @@ -289,7 +289,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) etag = hashlib.md5(etag_base.encode()).hexdigest() - self.headers.setdefault("content-length", content_length) + self.headers["content-length"] = content_length self.headers.setdefault("last-modified", last_modified) self.headers.setdefault("etag", etag) diff --git a/starlette/staticfiles.py b/starlette/staticfiles.py index 76e435310..bd4d8bced 100644 --- a/starlette/staticfiles.py +++ b/starlette/staticfiles.py @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None: def get_path(self, scope: Scope) -> str: """ Given the ASGI scope, return the `path` string to serve up, - with OS specific path seperators, and any '..', '.' components removed. + with OS specific path separators, and any '..', '.' components removed. """ return os.path.normpath(os.path.join(*scope["path"].split("/"))) diff --git a/tests/test_responses.py b/tests/test_responses.py index baba549ba..5852d2ed7 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -309,3 +309,8 @@ def test_head_method(test_client_factory): client = test_client_factory(app) response = client.head("/") assert response.text == "" + + +def test_empty_response(): + response = Response() + assert response.headers["Content-Length"] == "0" From 9e5239dc8d9e133f90aceddf7c544a42f6549796 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 6 Jan 2022 13:34:52 +0100 Subject: [PATCH 2/5] Add test for #1099 --- tests/test_responses.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_responses.py b/tests/test_responses.py index 5852d2ed7..a87fde11c 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -13,6 +13,7 @@ Response, StreamingResponse, ) +from starlette.testclient import TestClient def test_text_response(test_client_factory): @@ -73,6 +74,20 @@ async def app(scope, receive, send): assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/" +def test_redirect_response_content_length_header(test_client_factory): + async def app(scope, receive, send): + if scope["path"] == "/": + response = Response("hello", media_type="text/plain") # pragma: nocover + else: + response = RedirectResponse("/") + await response(scope, receive, send) + + client: TestClient = test_client_factory(app) + response = client.request("GET", "/redirect", allow_redirects=False) + assert response.url == "http://testserver/redirect" + assert response.headers["content-length"] == "0" + + def test_streaming_response(test_client_factory): filled_by_bg_task = "" From 799b0162328120272d25a8c00a86c15d97fdc6b9 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 7 Jan 2022 10:56:37 +0100 Subject: [PATCH 3/5] Revert changes and add tests --- starlette/responses.py | 6 +++--- tests/test_responses.py | 43 ++++++++++++++++++++++++++++++++++++++--- 2 files changed, 43 insertions(+), 6 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index 25bea421d..ffde4b97d 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -70,8 +70,8 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: populate_content_length = b"content-length" not in keys populate_content_type = b"content-type" not in keys - if populate_content_length: - body = getattr(self, "body", b"") + body = getattr(self, "body", b"") + if body and populate_content_length: content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1"))) @@ -289,7 +289,7 @@ def set_stat_headers(self, stat_result: os.stat_result) -> None: etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size) etag = md5_hexdigest(etag_base.encode(), usedforsecurity=False) - self.headers["content-length"] = content_length + self.headers.setdefault("content-length", content_length) self.headers.setdefault("last-modified", last_modified) self.headers.setdefault("etag", etag) diff --git a/tests/test_responses.py b/tests/test_responses.py index a87fde11c..150fe4795 100644 --- a/tests/test_responses.py +++ b/tests/test_responses.py @@ -326,6 +326,43 @@ def test_head_method(test_client_factory): assert response.text == "" -def test_empty_response(): - response = Response() - assert response.headers["Content-Length"] == "0" +def test_empty_response(test_client_factory): + app = Response() + client: TestClient = test_client_factory(app) + response = client.get("/") + assert response.headers["content-length"] == "0" + + +def test_non_empty_response(test_client_factory): + app = Response(content="hi") + client: TestClient = test_client_factory(app) + response = client.get("/") + assert response.headers["content-length"] == "2" + + +def test_file_response_known_size(tmpdir, test_client_factory): + path = os.path.join(tmpdir, "xyz") + content = b"" * 1000 + with open(path, "wb") as file: + file.write(content) + + app = FileResponse(path=path, filename="example.png") + client: TestClient = test_client_factory(app) + response = client.get("/") + assert response.headers["content-length"] == str(len(content)) + + +def test_streaming_response_unknown_size(test_client_factory): + app = StreamingResponse(content=iter(["hello", "world"])) + client: TestClient = test_client_factory(app) + response = client.get("/") + assert "content-length" not in response.headers + + +def test_streaming_response_known_size(test_client_factory): + app = StreamingResponse( + content=iter(["hello", "world"]), headers={"content-length": "10"} + ) + client: TestClient = test_client_factory(app) + response = client.get("/") + assert response.headers["content-length"] == "10" From b04a7b483e21885194ddc7b8b1e7a208d9c2c0e9 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 7 Jan 2022 11:21:25 +0100 Subject: [PATCH 4/5] Check if is StreamingResponse or FileResponse before adding content-length headers --- starlette/responses.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index ffde4b97d..b133c0f39 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -70,8 +70,10 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: populate_content_length = b"content-length" not in keys populate_content_type = b"content-type" not in keys - body = getattr(self, "body", b"") - if body and populate_content_length: + if populate_content_length and not isinstance( + self, (StreamingResponse, FileResponse) + ): + body = getattr(self, "body", b"") content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1"))) From d0f48a7c7966d71cc3308ccc8adc96b37e4f30ed Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 7 Jan 2022 11:50:45 +0100 Subject: [PATCH 5/5] Change conditional logic to check if body is present --- starlette/responses.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/starlette/responses.py b/starlette/responses.py index b133c0f39..da765cfa9 100644 --- a/starlette/responses.py +++ b/starlette/responses.py @@ -70,10 +70,8 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None: populate_content_length = b"content-length" not in keys populate_content_type = b"content-type" not in keys - if populate_content_length and not isinstance( - self, (StreamingResponse, FileResponse) - ): - body = getattr(self, "body", b"") + body = getattr(self, "body", None) + if body is not None and populate_content_length: content_length = str(len(body)) raw_headers.append((b"content-length", content_length.encode("latin-1")))