Improve typings for multipart (#3905)
Tomasz Trębski authored and asvetlov committed Jul 19, 2019
1 parent 64a8698 commit d7b08ad
Showing 3 changed files with 91 additions and 56 deletions.
2 changes: 2 additions & 0 deletions CHANGES/3621.bugfix
@@ -0,0 +1,2 @@
Improve typing annotations for multipart.py along with changes required
by mypy in files that reference multipart.py.
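
For context, a minimal client-side sketch (not part of this commit) of the usage pattern the tightened annotations are meant to check; the URL is hypothetical, and only the isinstance() narrowing reflects what mypy now requires once next() returns Optional[Union[MultipartReader, BodyPartReader]]:

    import asyncio

    import aiohttp
    from aiohttp.multipart import BodyPartReader, MultipartReader


    async def read_parts() -> None:
        async with aiohttp.ClientSession() as session:
            async with session.get('http://example.com/multipart') as resp:
                # from_response() is now annotated to return MultipartResponseWrapper.
                wrapper = MultipartReader.from_response(resp)
                async for part in wrapper:
                    # Each item is Union[MultipartReader, BodyPartReader], so the
                    # type is narrowed before calling part-specific methods.
                    if isinstance(part, BodyPartReader):
                        data = await part.read(decode=True)
                        print(part.name, len(data))


    asyncio.run(read_parts())
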
77 changes: 51 additions & 26 deletions aiohttp/multipart.py
@@ -19,7 +19,6 @@
Tuple,
Type,
Union,
cast,
)
from urllib.parse import parse_qsl, unquote, urlencode

@@ -195,21 +194,26 @@ def content_disposition_filename(params: Mapping[str, str],


class MultipartResponseWrapper:
"""Wrapper around the MultipartBodyReader.
"""Wrapper around the MultipartReader.
It takes care of the underlying connection and closes it when needed.
"""

def __init__(self, resp: 'ClientResponse', stream: Any) -> None:
# TODO: add strong annotation to stream
def __init__(
self,
resp: 'ClientResponse',
stream: 'MultipartReader',
) -> None:
self.resp = resp
self.stream = stream

def __aiter__(self) -> 'MultipartResponseWrapper':
return self

async def __anext__(self) -> Any:
async def __anext__(
self,
) -> Union['MultipartReader', 'BodyPartReader']:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
@@ -219,7 +223,9 @@ def at_eof(self) -> bool:
"""Returns True when all response data had been read."""
return self.resp.content.at_eof()

async def next(self) -> Any:
async def next(
self,
) -> Optional[Union['MultipartReader', 'BodyPartReader']]:
"""Emits next multipart reader object."""
item = await self.stream.next()
if self.stream.at_eof():
Expand All @@ -240,7 +246,7 @@ class BodyPartReader:
def __init__(
self,
boundary: bytes,
headers: Mapping[str, Optional[str]],
headers: 'CIMultiDictProxy[str]',
content: StreamReader,
*,
_newline: bytes = b'\r\n'
@@ -262,19 +268,19 @@ def __init__(
def __aiter__(self) -> 'BodyPartReader':
return self

async def __anext__(self) -> Any:
async def __anext__(self) -> bytes:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

async def next(self) -> Any:
async def next(self) -> Optional[bytes]:
item = await self.read()
if not item:
return None
return item

async def read(self, *, decode: bool=False) -> Any:
async def read(self, *, decode: bool=False) -> bytes:
"""Reads body part data.
decode: Decodes data following by encoding
@@ -429,7 +435,9 @@ async def text(self, *, encoding: Optional[str]=None) -> str:
encoding = encoding or self.get_charset(default='utf-8')
return data.decode(encoding)

async def json(self, *, encoding: Optional[str]=None) -> Any:
async def json(self,
*,
encoding: Optional[str]=None) -> Optional[Dict[str, Any]]:
"""Like read(), but assumes that body parts contains JSON data."""
data = await self.read(decode=True)
if not data:
@@ -468,7 +476,7 @@ def decode(self, data: bytes) -> bytes:
return data

def _decode_content(self, data: bytes) -> bytes:
encoding = cast(str, self.headers[CONTENT_ENCODING]).lower()
encoding = self.headers.get(CONTENT_ENCODING, '').lower()

if encoding == 'deflate':
return zlib.decompress(data, -zlib.MAX_WBITS)
@@ -480,7 +488,7 @@ def _decode_content(self, data: bytes) -> bytes:
raise RuntimeError('unknown content encoding: {}'.format(encoding))

def _decode_content_transfer(self, data: bytes) -> bytes:
encoding = cast(str, self.headers[CONTENT_TRANSFER_ENCODING]).lower()
encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, '').lower()

if encoding == 'base64':
return base64.b64decode(data)
@@ -564,22 +572,27 @@ def __init__(
self._boundary = ('--' + self._get_boundary()).encode()
self._newline = _newline
self._content = content
self._last_part = None
self._last_part = None # type: Optional[Union['MultipartReader', BodyPartReader]] # noqa
self._at_eof = False
self._at_bof = True
self._unread = [] # type: List[bytes]

def __aiter__(self) -> 'MultipartReader':
return self

async def __anext__(self) -> Any:
async def __anext__(
self,
) -> Union['MultipartReader', BodyPartReader]:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

@classmethod
def from_response(cls, response: 'ClientResponse') -> Any:
def from_response(
cls,
response: 'ClientResponse',
) -> MultipartResponseWrapper:
"""Constructs reader instance from HTTP response.
:param response: :class:`~aiohttp.client.ClientResponse` instance
@@ -594,19 +607,21 @@ def at_eof(self) -> bool:
"""
return self._at_eof

async def next(self) -> Any:
async def next(
self,
) -> Optional[Union['MultipartReader', BodyPartReader]]:
"""Emits the next multipart body part."""
# So, if we're at BOF, we need to skip till the boundary.
if self._at_eof:
return
return None
await self._maybe_release_last_part()
if self._at_bof:
await self._read_until_first_boundary()
self._at_bof = False
else:
await self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return
return None
self._last_part = await self.fetch_next_part()
return self._last_part

@@ -618,12 +633,17 @@ async def release(self) -> None:
break
await item.release()

async def fetch_next_part(self) -> Any:
async def fetch_next_part(
self,
) -> Union['MultipartReader', BodyPartReader]:
"""Returns the next body part reader."""
headers = await self._read_headers()
return self._get_part_reader(headers)

def _get_part_reader(self, headers: 'CIMultiDictProxy[str]') -> Any:
def _get_part_reader(
self,
headers: 'CIMultiDictProxy[str]',
) -> Union['MultipartReader', BodyPartReader]:
"""Dispatches the response by the `Content-Type` header, returning
suitable reader instance.
@@ -822,7 +842,7 @@ def boundary(self) -> str:
def append(
self,
obj: Any,
headers: Optional['MultiMapping[str]']=None
headers: Optional[MultiMapping[str]]=None
) -> Payload:
if headers is None:
headers = CIMultiDict()
@@ -841,15 +861,20 @@ def append_payload(self, payload: Payload) -> Payload:
def append_payload(self, payload: Payload) -> Payload:
"""Adds a new body part to multipart writer."""
# compression
encoding = payload.headers.get(CONTENT_ENCODING, '').lower() # type: Optional[str] # noqa
encoding = payload.headers.get(
CONTENT_ENCODING,
'',
).lower() # type: Optional[str]
if encoding and encoding not in ('deflate', 'gzip', 'identity'):
raise RuntimeError('unknown content encoding: {}'.format(encoding))
if encoding == 'identity':
encoding = None

# te encoding
te_encoding = payload.headers.get(
CONTENT_TRANSFER_ENCODING, '').lower() # type: Optional[str] # noqa
CONTENT_TRANSFER_ENCODING,
'',
).lower() # type: Optional[str]
if te_encoding not in ('', 'base64', 'quoted-printable', 'binary'):
raise RuntimeError('unknown content transfer encoding: {}'
''.format(te_encoding))
@@ -867,7 +892,7 @@ def append_payload(self, payload: Payload) -> Payload:
def append_json(
self,
obj: Any,
headers: Optional['MultiMapping[str]']=None
headers: Optional[MultiMapping[str]]=None
) -> Payload:
"""Helper to append JSON part."""
if headers is None:
@@ -879,7 +904,7 @@ def append_form(
self,
obj: Union[Sequence[Tuple[str, str]],
Mapping[str, str]],
headers: Optional['MultiMapping[str]']=None
headers: Optional[MultiMapping[str]]=None
) -> Payload:
"""Helper to append form urlencoded part."""
assert isinstance(obj, (Sequence, Mapping))
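
As a companion to the MultipartWriter hunks above, a hedged writer-side usage sketch (not from the diff); the target URL, field names, and header value are made up:

    from multidict import CIMultiDict

    import aiohttp
    from aiohttp import hdrs


    async def send_parts(session: aiohttp.ClientSession) -> None:
        with aiohttp.MultipartWriter('mixed') as writer:
            # append_json()/append_form() accept an optional MultiMapping of headers.
            writer.append_json({'kind': 'metadata'})
            writer.append_form({'field': 'value'})
            writer.append(
                b'raw bytes',
                headers=CIMultiDict({hdrs.CONTENT_TYPE: 'application/octet-stream'}),
            )
        async with session.post('http://example.com/upload', data=writer) as resp:
            resp.raise_for_status()
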
68 changes: 38 additions & 30 deletions aiohttp/web_request.py
@@ -39,7 +39,7 @@
sentinel,
)
from .http_parser import RawRequestMessage
from .multipart import MultipartReader
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
@@ -633,41 +633,49 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
field = await multipart.next()
while field is not None:
size = 0
content_type = field.headers.get(hdrs.CONTENT_TYPE)

if field.filename:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
field_ct = field.headers.get(hdrs.CONTENT_TYPE)

if isinstance(field, BodyPartReader):
if field.filename and field_ct:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(field.name, field.filename,
cast(io.BufferedReader, tmp),
field_ct, field.headers)
out.add(field.name, ff)
else:
# deal with ordinary data
value = await field.read(decode=True)
if field_ct is None or \
field_ct.startswith('text/'):
charset = field.get_charset(default='utf-8')
out.add(field.name, value.decode(charset))
else:
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(field.name, field.filename,
cast(io.BufferedReader, tmp),
content_type, field.headers)
out.add(field.name, ff)
else:
value = await field.read(decode=True)
if content_type is None or \
content_type.startswith('text/'):
charset = field.get_charset(default='utf-8')
value = value.decode(charset)
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
raise ValueError(
'To decode nested multipart you need '
'to use custom reader',
)

field = await multipart.next()
else:
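
For completeness, a hedged server-side sketch (not part of the commit) of a handler consuming the form that the post() method above parses; the route and the 'file' field name are assumptions:

    from aiohttp import web


    async def upload(request: web.Request) -> web.Response:
        data = await request.post()
        field = data['file']
        if isinstance(field, web.FileField):
            # FileField bundles name, filename, the temporary file object,
            # the part's Content-Type, and its headers.
            body = field.file.read()
            text = '{} ({} bytes)'.format(field.filename, len(body))
        else:
            # Ordinary form values arrive as str (or bytes for non-text parts).
            text = str(field)
        return web.Response(text=text)


    app = web.Application()
    app.add_routes([web.post('/upload', upload)])
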
