From 1130c30b7201b4d93f218767d29f920f4e483b02 Mon Sep 17 00:00:00 2001 From: Andrew Svetlov Date: Sun, 22 Nov 2020 13:42:54 +0200 Subject: [PATCH] Make type hints for http parser stricter (#5267) --- CHANGES/5267.feature | 1 + aiohttp/client_exceptions.py | 3 ++- aiohttp/client_proto.py | 4 ++-- aiohttp/http_parser.py | 28 ++++++++++++++++------------ aiohttp/web_protocol.py | 3 ++- 5 files changed, 23 insertions(+), 16 deletions(-) create mode 100644 CHANGES/5267.feature diff --git a/CHANGES/5267.feature b/CHANGES/5267.feature new file mode 100644 index 00000000000..63dd2ffc518 --- /dev/null +++ b/CHANGES/5267.feature @@ -0,0 +1 @@ +Make type hints for http parser stricter diff --git a/aiohttp/client_exceptions.py b/aiohttp/client_exceptions.py index 1a43e6c0ef6..12fffff64ab 100644 --- a/aiohttp/client_exceptions.py +++ b/aiohttp/client_exceptions.py @@ -3,6 +3,7 @@ import asyncio from typing import TYPE_CHECKING, Any, Optional, Tuple, Union +from .http_parser import RawResponseMessage from .typedefs import LooseHeaders try: @@ -192,7 +193,7 @@ class ServerConnectionError(ClientConnectionError): class ServerDisconnectedError(ServerConnectionError): """Server disconnected.""" - def __init__(self, message: Optional[str] = None) -> None: + def __init__(self, message: Union[RawResponseMessage, str, None] = None) -> None: if message is None: message = "Server disconnected" diff --git a/aiohttp/client_proto.py b/aiohttp/client_proto.py index d0a0f1056ab..dc6a6732d34 100644 --- a/aiohttp/client_proto.py +++ b/aiohttp/client_proto.py @@ -23,7 +23,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None: self._should_close = False - self._payload = None + self._payload: Optional[StreamReader] = None self._skip_payload = False self._payload_parser = None @@ -230,7 +230,7 @@ def data_received(self, data: bytes) -> None: self._upgraded = upgraded - payload = None + payload: Optional[StreamReader] = None for message, payload in messages: if message.should_close: self._should_close = True diff --git a/aiohttp/http_parser.py b/aiohttp/http_parser.py index fc7fca5dd87..87755d48a37 100644 --- a/aiohttp/http_parser.py +++ b/aiohttp/http_parser.py @@ -4,8 +4,9 @@ import re import string import zlib +from contextlib import suppress from enum import IntEnum -from typing import Any, List, Optional, Tuple, Type, Union +from typing import Any, Generic, List, Optional, Tuple, Type, TypeVar, Union from multidict import CIMultiDict, CIMultiDictProxy, istr from yarl import URL @@ -88,6 +89,9 @@ ) +_MsgT = TypeVar("_MsgT", RawRequestMessage, RawResponseMessage) + + class ParseState(IntEnum): PARSE_NONE = 0 @@ -198,7 +202,7 @@ def parse_headers( return (CIMultiDictProxy(headers), tuple(raw_headers)) -class HttpParser(abc.ABC): +class HttpParser(abc.ABC, Generic[_MsgT]): def __init__( self, protocol: BaseProtocol, @@ -239,10 +243,10 @@ def __init__( self._headers_parser = HeadersParser(max_line_size, max_headers, max_field_size) @abc.abstractmethod - def parse_message(self, lines: List[bytes]) -> Any: + def parse_message(self, lines: List[bytes]) -> _MsgT: pass - def feed_eof(self) -> Any: + def feed_eof(self) -> Optional[_MsgT]: if self._payload_parser is not None: self._payload_parser.feed_eof() self._payload_parser = None @@ -254,10 +258,9 @@ def feed_eof(self) -> Any: if self._lines: if self._lines[-1] != "\r\n": self._lines.append(b"") - try: + with suppress(Exception): return self.parse_message(self._lines) - except Exception: - return None + return None def feed_data( self, @@ -267,7 +270,7 @@ def feed_data( CONTENT_LENGTH: istr = hdrs.CONTENT_LENGTH, METH_CONNECT: str = hdrs.METH_CONNECT, SEC_WEBSOCKET_KEY1: istr = hdrs.SEC_WEBSOCKET_KEY1, - ) -> Tuple[List[Any], bool, bytes]: + ) -> Tuple[List[Tuple[_MsgT, StreamReader]], bool, bytes]: messages = [] @@ -346,6 +349,7 @@ def feed_data( if not payload_parser.done: self._payload_parser = payload_parser elif method == METH_CONNECT: + assert isinstance(msg, RawRequestMessage) payload = StreamReader( self.protocol, timer=self.timer, @@ -480,13 +484,13 @@ def set_upgraded(self, val: bool) -> None: self._upgraded = val -class HttpRequestParser(HttpParser): +class HttpRequestParser(HttpParser[RawRequestMessage]): """Read request status line. Exception .http_exceptions.BadStatusLine could be raised in case of any errors in status line. Returns RawRequestMessage. """ - def parse_message(self, lines: List[bytes]) -> Any: + def parse_message(self, lines: List[bytes]) -> RawRequestMessage: # request line line = lines[0].decode("utf-8", "surrogateescape") try: @@ -543,13 +547,13 @@ def parse_message(self, lines: List[bytes]) -> Any: ) -class HttpResponseParser(HttpParser): +class HttpResponseParser(HttpParser[RawResponseMessage]): """Read response status line and headers. BadStatusLine could be raised in case of any errors in status line. Returns RawResponseMessage""" - def parse_message(self, lines: List[bytes]) -> Any: + def parse_message(self, lines: List[bytes]) -> RawResponseMessage: line = lines[0].decode("utf-8", "surrogateescape") try: version, status = line.split(None, 1) diff --git a/aiohttp/web_protocol.py b/aiohttp/web_protocol.py index 8e54ed50758..544a903097d 100644 --- a/aiohttp/web_protocol.py +++ b/aiohttp/web_protocol.py @@ -11,6 +11,7 @@ Any, Awaitable, Callable, + Deque, Optional, Tuple, Type, @@ -200,7 +201,7 @@ def __init__( self._keepalive_timeout = keepalive_timeout self._lingering_time = float(lingering_time) - self._messages = deque() # type: Any # Python 3.5 has no typing.Deque + self._messages: Deque[Tuple[RawRequestMessage, StreamReader]] = deque() self._message_tail = b"" self._waiter = None # type: Optional[asyncio.Future[None]]