Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't omit Content-Length header for Content-Length: 0 cases #1395

Merged
merged 7 commits into from Jan 7, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 2 additions & 2 deletions starlette/responses.py
Expand Up @@ -70,8 +70,8 @@ def init_headers(self, headers: typing.Mapping[str, str] = None) -> None:
populate_content_length = b"content-length" not in keys
populate_content_type = b"content-type" not in keys

body = getattr(self, "body", b"")
if body and populate_content_length:
body = getattr(self, "body", None)
if body is not None and populate_content_length:
content_length = str(len(body))
raw_headers.append((b"content-length", content_length.encode("latin-1")))

Expand Down
2 changes: 1 addition & 1 deletion starlette/staticfiles.py
Expand Up @@ -100,7 +100,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
def get_path(self, scope: Scope) -> str:
"""
Given the ASGI scope, return the `path` string to serve up,
with OS specific path seperators, and any '..', '.' components removed.
with OS specific path separators, and any '..', '.' components removed.
"""
return os.path.normpath(os.path.join(*scope["path"].split("/")))

Expand Down
57 changes: 57 additions & 0 deletions tests/test_responses.py
Expand Up @@ -13,6 +13,7 @@
Response,
StreamingResponse,
)
from starlette.testclient import TestClient


def test_text_response(test_client_factory):
Expand Down Expand Up @@ -73,6 +74,20 @@ async def app(scope, receive, send):
assert response.url == "http://testserver/I%20%E2%99%A5%20Starlette/"


def test_redirect_response_content_length_header(test_client_factory):
async def app(scope, receive, send):
if scope["path"] == "/":
response = Response("hello", media_type="text/plain") # pragma: nocover
else:
response = RedirectResponse("/")
await response(scope, receive, send)

client: TestClient = test_client_factory(app)
response = client.request("GET", "/redirect", allow_redirects=False)
assert response.url == "http://testserver/redirect"
assert response.headers["content-length"] == "0"


def test_streaming_response(test_client_factory):
filled_by_bg_task = ""

Expand Down Expand Up @@ -309,3 +324,45 @@ def test_head_method(test_client_factory):
client = test_client_factory(app)
response = client.head("/")
assert response.text == ""


def test_empty_response(test_client_factory):
app = Response()
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "0"


def test_non_empty_response(test_client_factory):
app = Response(content="hi")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "2"


def test_file_response_known_size(tmpdir, test_client_factory):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't realise at the time, but we could just name this one test_file_response.

Copy link
Sponsor Member Author

@Kludex Kludex Jan 7, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's one above already (that also checks the content-length being present), maybe it's just redundant.

path = os.path.join(tmpdir, "xyz")
content = b"<file content>" * 1000
with open(path, "wb") as file:
file.write(content)

app = FileResponse(path=path, filename="example.png")
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == str(len(content))


def test_streaming_response_unknown_size(test_client_factory):
app = StreamingResponse(content=iter(["hello", "world"]))
client: TestClient = test_client_factory(app)
response = client.get("/")
assert "content-length" not in response.headers


def test_streaming_response_known_size(test_client_factory):
app = StreamingResponse(
content=iter(["hello", "world"]), headers={"content-length": "10"}
)
client: TestClient = test_client_factory(app)
response = client.get("/")
assert response.headers["content-length"] == "10"