Skip to content

Commit

Permalink
Add support for range headers to FileResponse
Browse files Browse the repository at this point in the history
This PR aims to solve #950.

The implementation differs from [baize's
FileResponse](https://github.com/abersheeran/baize/blob/23791841f30ca92775e50a544a8606d1d4deac93/baize/asgi/responses.py#L184),
since that one takes in consideration the "range" request header. The
desgin decision is justified as the Response classes in Starlette are
naive in regards to the request.
  • Loading branch information
Kludex committed Jan 7, 2023
1 parent 048643a commit 2612f83
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 4 deletions.
28 changes: 24 additions & 4 deletions starlette/responses.py
Expand Up @@ -290,9 +290,11 @@ def __init__(
stat_result: typing.Optional[os.stat_result] = None,
method: typing.Optional[str] = None,
content_disposition_type: str = "attachment",
range: typing.Optional[typing.Tuple[int, int]] = None,
) -> None:
self.path = path
self.status_code = status_code
self.status_code = status_code if range is None else 206
self.range = range
self.filename = filename
self.send_header_only = method is not None and method.upper() == "HEAD"
if media_type is None:
Expand All @@ -316,9 +318,17 @@ def __init__(
self.set_stat_headers(stat_result)

def set_stat_headers(self, stat_result: os.stat_result) -> None:
content_length = str(stat_result.st_size)
size = str(stat_result.st_size)
last_modified = formatdate(stat_result.st_mtime, usegmt=True)
etag_base = str(stat_result.st_mtime) + "-" + str(stat_result.st_size)
if self.range is not None:
start, end = self.range
etag_base += f"-{start}/{end}"
content_length = str(end - start + 1)
self.headers.setdefault("accept-ranges", "bytes")
self.headers.setdefault("content-range", f"bytes {start}-{end}/{size}")
else:
content_length = size
etag = md5_hexdigest(etag_base.encode(), usedforsecurity=False)

self.headers.setdefault("content-length", content_length)
Expand All @@ -336,6 +346,8 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
mode = stat_result.st_mode
if not stat.S_ISREG(mode):
raise RuntimeError(f"File at path {self.path} is not a file.")
else:
stat_result = self.stat_result
await send(
{
"type": "http.response.start",
Expand All @@ -347,10 +359,18 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
await send({"type": "http.response.body", "body": b"", "more_body": False})
else:
async with await anyio.open_file(self.path, mode="rb") as file:
if self.range is not None:
start, end = self.range
await file.seek(start)
else:
start, end = 0, stat_result.st_size - 1
remaining_bytes = end - start + 1
more_body = True
while more_body:
chunk = await file.read(self.chunk_size)
more_body = len(chunk) == self.chunk_size
chunk_size = min(remaining_bytes, self.chunk_size)
chunk = await file.read(chunk_size)
remaining_bytes -= len(chunk)
more_body = remaining_bytes > 0 and len(chunk) == chunk_size
await send(
{
"type": "http.response.body",
Expand Down
26 changes: 26 additions & 0 deletions tests/test_responses.py
Expand Up @@ -5,6 +5,7 @@

from starlette import status
from starlette.background import BackgroundTask
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.responses import (
FileResponse,
Expand Down Expand Up @@ -241,6 +242,31 @@ async def app(scope, receive, send):
assert filled_by_bg_task == "6, 7, 8, 9"


def test_file_response_with_range(tmpdir, test_client_factory):
path = os.path.join(tmpdir, "xyz")
content = b"<file content>"
with open(path, "wb") as file:
file.write(content)

async def app(scope, receive, send):
range_header = Headers(scope=scope)["range"]
start, end = (int(v) for v in range_header[len("bytes=") :].split("-"))
response = FileResponse(path=path, filename="example.png", range=(start, end))
await response(scope, receive, send)

client = test_client_factory(app)
response = client.get("/", headers={"range": "bytes=1-12"})
expected_disposition = 'attachment; filename="example.png"'
assert response.status_code == status.HTTP_206_PARTIAL_CONTENT
assert response.content == content[1:13]
assert response.headers["content-type"] == "image/png"
assert response.headers["content-disposition"] == expected_disposition
assert response.headers["content-range"] == "bytes 1-12/14"
assert "content-length" in response.headers
assert "last-modified" in response.headers
assert "etag" in response.headers


def test_file_response_with_directory_raises_error(tmpdir, test_client_factory):
app = FileResponse(path=tmpdir, filename="example.png")
client = test_client_factory(app)
Expand Down

0 comments on commit 2612f83

Please sign in to comment.