Skip to content

Commit

Permalink
feat: add headers attribute to UploadFile (#1382)
Browse files Browse the repository at this point in the history
This preserves the multipart field headers that may have been included in the original request

Co-authored-by: Tom Christie <tom@tomchristie.com>
  • Loading branch information
adriangb and tomchristie committed Jan 7, 2022
1 parent 4633427 commit f1c5049
Show file tree
Hide file tree
Showing 4 changed files with 70 additions and 2 deletions.
3 changes: 2 additions & 1 deletion docs/requests.md
Expand Up @@ -122,7 +122,7 @@ multidict, containing both file uploads and text input. File upload items are re
* `filename`: A `str` with the original file name that was uploaded (e.g. `myimage.jpg`).
* `content_type`: A `str` with the content type (MIME type / media type) (e.g. `image/jpeg`).
* `file`: A <a href="https://docs.python.org/3/library/tempfile.html#tempfile.SpooledTemporaryFile" target="_blank">`SpooledTemporaryFile`</a> (a <a href="https://docs.python.org/3/glossary.html#term-file-like-object" target="_blank">file-like</a> object). This is the actual Python file that you can pass directly to other functions or libraries that expect a "file-like" object.

* `headers`: A `Headers` object. Often this will only be the `Content-Type` header, but if additional headers were included in the multipart field they will be included here. Note that these headers have no relationship with the headers in `Request.headers`.

`UploadFile` has the following `async` methods. They all call the corresponding file methods underneath (using the internal `SpooledTemporaryFile`).

Expand All @@ -142,6 +142,7 @@ filename = form["upload_file"].filename
contents = await form["upload_file"].read()
```


#### Application

The originating Starlette application can be accessed via `request.app`.
Expand Down
9 changes: 8 additions & 1 deletion starlette/datastructures.py
Expand Up @@ -415,15 +415,22 @@ class UploadFile:
"""

spool_max_size = 1024 * 1024
headers: "Headers"

def __init__(
self, filename: str, file: typing.IO = None, content_type: str = ""
self,
filename: str,
file: typing.IO = None,
content_type: str = "",
*,
headers: "typing.Optional[Headers]" = None,
) -> None:
self.filename = filename
self.content_type = content_type
if file is None:
file = tempfile.SpooledTemporaryFile(max_size=self.spool_max_size)
self.file = file
self.headers = headers or Headers()

@property
def _in_memory(self) -> bool:
Expand Down
4 changes: 4 additions & 0 deletions starlette/formparsers.py
Expand Up @@ -184,6 +184,7 @@ async def parse(self) -> FormData:
file: typing.Optional[UploadFile] = None

items: typing.List[typing.Tuple[str, typing.Union[str, UploadFile]]] = []
item_headers: typing.List[typing.Tuple[bytes, bytes]] = []

# Feed the parser with data from the request.
async for chunk in self.stream:
Expand All @@ -195,6 +196,7 @@ async def parse(self) -> FormData:
content_disposition = None
content_type = b""
data = b""
item_headers = []
elif message_type == MultiPartMessage.HEADER_FIELD:
header_field += message_bytes
elif message_type == MultiPartMessage.HEADER_VALUE:
Expand All @@ -205,6 +207,7 @@ async def parse(self) -> FormData:
content_disposition = header_value
elif field == b"content-type":
content_type = header_value
item_headers.append((field, header_value))
header_field = b""
header_value = b""
elif message_type == MultiPartMessage.HEADERS_FINISHED:
Expand All @@ -215,6 +218,7 @@ async def parse(self) -> FormData:
file = UploadFile(
filename=filename,
content_type=content_type.decode("latin-1"),
headers=Headers(raw=item_headers),
)
else:
file = None
Expand Down
56 changes: 56 additions & 0 deletions tests/test_formparsers.py
Expand Up @@ -56,6 +56,26 @@ async def multi_items_app(scope, receive, send):
await response(scope, receive, send)


async def app_with_headers(scope, receive, send):
request = Request(scope, receive)
data = await request.form()
output = {}
for key, value in data.items():
if isinstance(value, UploadFile):
content = await value.read()
output[key] = {
"filename": value.filename,
"content": content.decode(),
"content_type": value.content_type,
"headers": list(value.headers.items()),
}
else:
output[key] = value
await request.close()
response = JSONResponse(output)
await response(scope, receive, send)


async def app_read_body(scope, receive, send):
request = Request(scope, receive)
# Read bytes, to force request.stream() to return the already parsed body
Expand Down Expand Up @@ -137,6 +157,42 @@ def test_multipart_request_multiple_files(tmpdir, test_client_factory):
}


def test_multipart_request_multiple_files_with_headers(tmpdir, test_client_factory):
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
file.write(b"<file1 content>")

path2 = os.path.join(tmpdir, "test2.txt")
with open(path2, "wb") as file:
file.write(b"<file2 content>")

client = test_client_factory(app_with_headers)
with open(path1, "rb") as f1, open(path2, "rb") as f2:
response = client.post(
"/",
files=[
("test1", (None, f1)),
("test2", ("test2.txt", f2, "text/plain", {"x-custom": "f2"})),
],
)
assert response.json() == {
"test1": "<file1 content>",
"test2": {
"filename": "test2.txt",
"content": "<file2 content>",
"content_type": "text/plain",
"headers": [
[
"content-disposition",
'form-data; name="test2"; filename="test2.txt"',
],
["content-type", "text/plain"],
["x-custom", "f2"],
],
},
}


def test_multi_items(tmpdir, test_client_factory):
path1 = os.path.join(tmpdir, "test1.txt")
with open(path1, "wb") as file:
Expand Down

0 comments on commit f1c5049

Please sign in to comment.