diff --git a/starlette/datastructures.py b/starlette/datastructures.py index ebef2ebdf..eee3834e0 100644 --- a/starlette/datastructures.py +++ b/starlette/datastructures.py @@ -1,4 +1,3 @@ -import tempfile import typing from collections.abc import Sequence from shlex import shlex @@ -428,28 +427,24 @@ class UploadFile: An uploaded file included as part of the request data. """ - spool_max_size = 1024 * 1024 - file: typing.BinaryIO - headers: "Headers" - def __init__( self, - filename: str, - file: typing.Optional[typing.BinaryIO] = None, - content_type: str = "", + file: typing.BinaryIO, *, + filename: typing.Optional[str] = None, headers: "typing.Optional[Headers]" = None, ) -> None: self.filename = filename - self.content_type = content_type - if file is None: - self.file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size) # type: ignore[assignment] # noqa: E501 - else: - self.file = file + self.file = file self.headers = headers or Headers() + @property + def content_type(self) -> typing.Optional[str]: + return self.headers.get("content-type", None) + @property def _in_memory(self) -> bool: + # check for SpooledTemporaryFile._rolled rolled_to_disk = getattr(self.file, "_rolled", True) return not rolled_to_disk diff --git a/starlette/formparsers.py b/starlette/formparsers.py index 53538c814..739befae8 100644 --- a/starlette/formparsers.py +++ b/starlette/formparsers.py @@ -1,5 +1,6 @@ import typing from enum import Enum +from tempfile import SpooledTemporaryFile from urllib.parse import unquote_plus from starlette.datastructures import FormData, Headers, UploadFile @@ -116,6 +117,8 @@ async def parse(self) -> FormData: class MultiPartParser: + max_file_size = 1024 * 1024 + def __init__( self, headers: Headers, stream: typing.AsyncGenerator[bytes, None] ) -> None: @@ -160,7 +163,7 @@ def on_end(self) -> None: async def parse(self) -> FormData: # Parse the Content-Type header to get the multipart boundary. - content_type, params = parse_options_header(self.headers["Content-Type"]) + _, params = parse_options_header(self.headers["Content-Type"]) charset = params.get(b"charset", "utf-8") if type(charset) == bytes: charset = charset.decode("latin-1") @@ -186,7 +189,6 @@ async def parse(self) -> FormData: header_field = b"" header_value = b"" content_disposition = None - content_type = b"" field_name = "" data = b"" file: typing.Optional[UploadFile] = None @@ -202,7 +204,6 @@ async def parse(self) -> FormData: for message_type, message_bytes in messages: if message_type == MultiPartMessage.PART_BEGIN: content_disposition = None - content_type = b"" data = b"" item_headers = [] elif message_type == MultiPartMessage.HEADER_FIELD: @@ -213,8 +214,6 @@ async def parse(self) -> FormData: field = header_field.lower() if field == b"content-disposition": content_disposition = header_value - elif field == b"content-type": - content_type = header_value item_headers.append((field, header_value)) header_field = b"" header_value = b"" @@ -229,9 +228,10 @@ async def parse(self) -> FormData: ) if b"filename" in options: filename = _user_safe_decode(options[b"filename"], charset) + tempfile = SpooledTemporaryFile(max_size=self.max_file_size) file = UploadFile( + file=tempfile, # type: ignore[arg-type] filename=filename, - content_type=content_type.decode("latin-1"), headers=Headers(raw=item_headers), ) else: diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index b87b26e22..16f9da4a5 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,4 +1,6 @@ import io +from tempfile import SpooledTemporaryFile +from typing import BinaryIO import pytest @@ -269,20 +271,6 @@ def test_queryparams(): assert QueryParams(q) == q -class BigUploadFile(UploadFile): - spool_max_size = 1024 - - -@pytest.mark.anyio -async def test_upload_file(): - big_file = BigUploadFile("big-file") - await big_file.write(b"big-data" * 512) - await big_file.write(b"big-data") - await big_file.seek(0) - assert await big_file.read(1024) == b"big-data" * 128 - await big_file.close() - - @pytest.mark.anyio async def test_upload_file_file_input(): """Test passing file/stream into the UploadFile constructor""" @@ -295,6 +283,28 @@ async def test_upload_file_file_input(): assert await file.read() == b"data and more data!" +@pytest.mark.anyio +@pytest.mark.parametrize("max_size", [1, 1024], ids=["rolled", "unrolled"]) +async def test_uploadfile_rolling(max_size: int) -> None: + """Test that we can r/w to a SpooledTemporaryFile + managed by UploadFile before and after it rolls to disk + """ + stream: BinaryIO = SpooledTemporaryFile( # type: ignore[assignment] + max_size=max_size + ) + file = UploadFile(filename="file", file=stream) + assert await file.read() == b"" + await file.write(b"data") + assert await file.read() == b"" + await file.seek(0) + assert await file.read() == b"data" + await file.write(b" more") + assert await file.read() == b"" + await file.seek(0) + assert await file.read() == b"data more" + await file.close() + + def test_formdata(): stream = io.BytesIO(b"data") upload = UploadFile(filename="file", file=stream)