Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make type hints for http parser stricter #5267

Merged
merged 2 commits into from Nov 22, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/5267.feature
@@ -0,0 +1 @@
Make type hints for http parser stricter
3 changes: 2 additions & 1 deletion aiohttp/client_exceptions.py
Expand Up @@ -3,6 +3,7 @@
import asyncio
from typing import TYPE_CHECKING, Any, Optional, Tuple, Union

from .http_parser import RawResponseMessage
from .typedefs import LooseHeaders

try:
Expand Down Expand Up @@ -192,7 +193,7 @@ class ServerConnectionError(ClientConnectionError):
class ServerDisconnectedError(ServerConnectionError):
"""Server disconnected."""

def __init__(self, message: Optional[str] = None) -> None:
def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None:
if message is None:
message = "Server disconnected"

Expand Down
4 changes: 2 additions & 2 deletions aiohttp/client_proto.py
Expand Up @@ -23,7 +23,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:

self._should_close = False

self._payload = None
self._payload: Optional[StreamReader] = None
self._skip_payload = False
self._payload_parser = None

Expand Down Expand Up @@ -230,7 +230,7 @@ def data_received(self, data: bytes) -> None:

self._upgraded = upgraded

payload = None
payload: Optional[StreamReader] = None
for message, payload in messages:
if message.should_close:
self._should_close = True
Expand Down
28 changes: 16 additions & 12 deletions aiohttp/http_parser.py
Expand Up @@ -4,8 +4,9 @@
import re
import string
import zlib
from contextlib import suppress
from enum import IntEnum
from typing import Any, List, Optional, Tuple, Type, Union
from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar, Union

from multidict import CIMultiDict, CIMultiDictProxy, istr
from yarl import URL
Expand Down Expand Up @@ -88,6 +89,9 @@
)


_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage)


class ParseState(IntEnum):

PARSE_NONE = 0
Expand Down Expand Up @@ -198,7 +202,7 @@ def parse_headers(
return (CIMultiDictProxy(headers), tuple(raw_headers))


class HttpParser(abc.ABC):
class HttpParser(abc.ABC, Generic[_MsgT]):
def __init__(
self,
protocol: BaseProtocol,
Expand Down Expand Up @@ -239,10 +243,10 @@ def __init__(
self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size)

@abc.abstractmethod
def parse_message(self, lines: List[bytes]) -> Any:
def parse_message(self, lines: List[bytes]) -> _MsgT:
pass

def feed_eof(self) -> Any:
def feed_eof(self) -> Optional[_MsgT]:
if self._payload_parser is not None:
self._payload_parser.feed_eof()
self._payload_parser = None
Expand All @@ -254,10 +258,9 @@ def feed_eof(self) -> Any:
if self._lines:
if self._lines[-1] != "\r\n":
self._lines.append(b"")
try:
with suppress(Exception):
return self.parse_message(self._lines)
except Exception:
return None
return None

def feed_data(
self,
Expand All @@ -267,7 +270,7 @@ def feed_data(
CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH,
METH_CONNECT: str = hdrs.METH_CONNECT,
SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1,
) -> Tuple[List[Any], bool, bytes]:
) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]:

messages = []

Expand Down Expand Up @@ -346,6 +349,7 @@ def feed_data(
if not payload_parser.done:
self._payload_parser = payload_parser
elif method == METH_CONNECT:
assert isinstance(msg, RawRequestMessage)
payload = StreamReader(
self.protocol,
timer=self.timer,
Expand Down Expand Up @@ -480,13 +484,13 @@ def set_upgraded(self, val: bool) -> None:
self._upgraded = val


class HttpRequestParser(HttpParser):
class HttpRequestParser(HttpParser[RawRequestMessage]):
"""Read request status line. Exception .http_exceptions.BadStatusLine
could be raised in case of any errors in status line.
Returns RawRequestMessage.
"""

def parse_message(self, lines: List[bytes]) -> Any:
def parse_message(self, lines: List[bytes]) -> RawRequestMessage:
# request line
line = lines[0].decode("utf-8", "surrogateescape")
try:
Expand Down Expand Up @@ -543,13 +547,13 @@ def parse_message(self, lines: List[bytes]) -> Any:
)


class HttpResponseParser(HttpParser):
class HttpResponseParser(HttpParser[RawResponseMessage]):
"""Read response status line and headers.

BadStatusLine could be raised in case of any errors in status line.
Returns RawResponseMessage"""

def parse_message(self, lines: List[bytes]) -> Any:
def parse_message(self, lines: List[bytes]) -> RawResponseMessage:
line = lines[0].decode("utf-8", "surrogateescape")
try:
version, status = line.split(None, 1)
Expand Down
3 changes: 2 additions & 1 deletion aiohttp/web_protocol.py
Expand Up @@ -11,6 +11,7 @@
Any,
Awaitable,
Callable,
Deque,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -200,7 +201,7 @@ def __init__(
self._keepalive_timeout = keepalive_timeout
self._lingering_time = float(lingering_time)

self._messages = deque() # type: Any # Python 3.5 has no typing.Deque
self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque()
self._message_tail = b""

self._waiter = None # type: Optional[asyncio.Future[None]]
Expand Down