Skip to content

Commit

Permalink
[3.8] Make type hints for http parser stricter (#5267).
Browse files Browse the repository at this point in the history
(cherry picked from commit a6c7f15)

Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
  • Loading branch information
asvetlov committed Nov 22, 2020
1 parent 91ec66d commit c60fc85
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 17 deletions.
1 change: 1 addition & 0 deletions CHANGES/5267.feature
@@ -0,0 +1 @@
Make type hints for http parser stricter
3 changes: 2 additions & 1 deletion aiohttp/client_exceptions.py
Expand Up @@ -4,6 +4,7 @@
import warnings
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

from .http_parser import RawResponseMessage
from .typedefs import LooseHeaders

try:
Expand Down Expand Up @@ -225,7 +226,7 @@ class ServerConnectionError(ClientConnectionError):
class ServerDisconnectedError(ServerConnectionError):
"""Server disconnected."""

def __init__(self, message: Optional[str] = None) -> None:
def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None:
if message is None:
message = "Server disconnected"

Expand Down
4 changes: 2 additions & 2 deletions aiohttp/client_proto.py
Expand Up @@ -23,7 +23,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:

self._should_close = False

self._payload = None
self._payload: Optional[StreamReader] = None
self._skip_payload = False
self._payload_parser = None

Expand Down Expand Up @@ -223,7 +223,7 @@ def data_received(self, data: bytes) -> None:

self._upgraded = upgraded

payload = None
payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True
Expand Down
28 changes: 16 additions & 12 deletions aiohttp/http_parser.py
Expand Up @@ -4,8 +4,9 @@
import re
import string
import zlib
from contextlib import suppress
from enum import IntEnum
from typing import Any, List, Optional, Tuple, Type, Union
from typing import Generic, List, Optional, Tuple, Type, TypeVar, Union

from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL
Expand Down Expand Up @@ -88,6 +89,9 @@
)


_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage)


class ParseState(IntEnum):

PARSE_NONE = 0
Expand Down Expand Up @@ -198,7 +202,7 @@ def parse_headers(
return (CIMultiDictProxy(headers), tuple(raw_headers))


class HttpParser(abc.ABC):
class HttpParser(abc.ABC, Generic[_MsgT]):
def __init__(
self,
protocol: Optional[BaseProtocol] = None,
Expand Down Expand Up @@ -239,10 +243,10 @@ def __init__(
self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size)

@abc.abstractmethod
def parse_message(self, lines: List[bytes]) -> Any:
def parse_message(self, lines: List[bytes]) -> _MsgT:
pass

def feed_eof(self) -> Any:
def feed_eof(self) -> Optional[_MsgT]:
if self._payload_parser is not None:
self._payload_parser.feed_eof()
self._payload_parser = None
Expand All @@ -254,10 +258,9 @@ def feed_eof(self) -> Any:
if self._lines:
if self._lines[-1] != "\r\n":
self._lines.append(b"")
try:
with suppress(Exception):
return self.parse_message(self._lines)
except Exception:
return None
return None

def feed_data(
self,
Expand All @@ -267,7 +270,7 @@ def feed_data(
CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
METH_CONNECT: str = hdrs.METH_CONNECT,
SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1,
) -> Tuple[List[Any], bool, bytes]:
) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]:

messages = []

Expand Down Expand Up @@ -346,6 +349,7 @@ def feed_data(
if not payload_parser.done:
self._payload_parser = payload_parser
elif method == METH_CONNECT:
assert isinstance(msg, RawRequestMessage)
payload = StreamReader(
self.protocol,
timer=self.timer,
Expand Down Expand Up @@ -479,13 +483,13 @@ def set_upgraded(self, val: bool) -> None:
self._upgraded = val


class HttpRequestParser(HttpParser):
class HttpRequestParser(HttpParser[RawRequestMessage]):
"""Read request status line. Exception .http_exceptions.BadStatusLine
could be raised in case of any errors in status line.
Returns RawRequestMessage.
"""

def parse_message(self, lines: List[bytes]) -> Any:
def parse_message(self, lines: List[bytes]) -> RawRequestMessage:
# request line
line = lines[0].decode("utf-8", "surrogateescape")
try:
Expand Down Expand Up @@ -542,13 +546,13 @@ def parse_message(self, lines: List[bytes]) -> Any:
)


class HttpResponseParser(HttpParser):
class HttpResponseParser(HttpParser[RawResponseMessage]):
"""Read response status line and headers.
BadStatusLine could be raised in case of any errors in status line.
Returns RawResponseMessage"""

def parse_message(self, lines: List[bytes]) -> Any:
def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
line = lines[0].decode("utf-8", "surrogateescape")
try:
version, status = line.split(None, 1)
Expand Down
14 changes: 12 additions & 2 deletions aiohttp/web_protocol.py
Expand Up @@ -7,7 +7,17 @@
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Deque,
Optional,
Tuple,
Type,
cast,
)

import yarl

Expand Down Expand Up @@ -172,7 +182,7 @@ def __init__(
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)

self._messages = deque() # type: Any # Python 3.5 has no typing.Deque
self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque()
self._message_tail = b""

self._waiter = None # type: Optional[asyncio.Future[None]]
Expand Down

0 comments on commit c60fc85

Please sign in to comment.