Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change headers to a dict that parses comma-separated values #7679

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
25 changes: 15 additions & 10 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ from cpython.mem cimport PyMem_Free, PyMem_Malloc
from libc.limits cimport ULLONG_MAX
from libc.string cimport memcpy

from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiDictProxy
from yarl import URL as _URL

from aiohttp import hdrs
Expand All @@ -31,7 +30,10 @@ from .http_exceptions import (
PayloadEncodingError,
TransferEncodingError,
)
from .http_parser import DeflateBuffer as _DeflateBuffer
from .http_parser import (
DeflateBuffer as _DeflateBuffer,
HeadersDictProxy as _HeadersDictProxy,
)
from .http_writer import (
HttpVersion as _HttpVersion,
HttpVersion10 as _HttpVersion10,
Expand Down Expand Up @@ -59,8 +61,7 @@ __all__ = ('HttpRequestParser', 'HttpResponseParser',

cdef object URL = _URL
cdef object URL_build = URL.build
cdef object CIMultiDict = _CIMultiDict
cdef object CIMultiDictProxy = _CIMultiDictProxy
cdef object HeadersDictProxy = _HeadersDictProxy
cdef object HttpVersion = _HttpVersion
cdef object HttpVersion10 = _HttpVersion10
cdef object HttpVersion11 = _HttpVersion11
Expand Down Expand Up @@ -111,7 +112,7 @@ cdef class RawRequestMessage:
cdef readonly str method
cdef readonly str path
cdef readonly object version # HttpVersion
cdef readonly object headers # CIMultiDict
cdef readonly object headers # HeadersDictProxy
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
Expand Down Expand Up @@ -211,7 +212,7 @@ cdef class RawResponseMessage:
cdef readonly object version # HttpVersion
cdef readonly int code
cdef readonly str reason
cdef readonly object headers # CIMultiDict
cdef readonly object headers # HeadersDictProxy
cdef readonly object raw_headers # tuple
cdef readonly object should_close
cdef readonly object compression
Expand Down Expand Up @@ -383,8 +384,6 @@ cdef class HttpParser:
name = find_header(raw_name)
value = raw_value.decode('utf-8', 'surrogateescape')

self._headers.add(name, value)

if name is CONTENT_ENCODING:
self._content_encoding = value

Expand All @@ -393,6 +392,12 @@ cdef class HttpParser:
self._has_value = False
self._raw_headers.append((raw_name, raw_value))

name = name.title()
if name in self._headers:
self._headers[name] += ", " + value
else:
self._headers[name] = value

cdef _on_header_field(self, char* at, size_t length):
cdef Py_ssize_t size
cdef char *buf
Expand Down Expand Up @@ -423,7 +428,7 @@ cdef class HttpParser:
chunked = self._cparser.flags & cparser.F_CHUNKED

raw_headers = tuple(self._raw_headers)
headers = CIMultiDictProxy(self._headers)
headers = HeadersDictProxy(self._headers)

if upgrade or self._cparser.method == cparser.HTTP_CONNECT:
self._upgraded = True
Expand Down Expand Up @@ -665,7 +670,7 @@ cdef int cb_on_message_begin(cparser.llhttp_t* parser) except -1:
cdef HttpParser pyparser = <HttpParser>parser.data

pyparser._started = True
pyparser._headers = CIMultiDict()
pyparser._headers = {}
pyparser._raw_headers = []
PyByteArray_Resize(pyparser._buf, 0)
pyparser._path = None
Expand Down
76 changes: 73 additions & 3 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import asyncio
import re
import string
from collections.abc import Mapping
from contextlib import suppress
from enum import IntEnum
from typing import (
Any,
ClassVar,
Dict,
Final,
Generic,
List,
Expand Down Expand Up @@ -65,9 +67,56 @@
METHRE: Final[Pattern[str]] = re.compile(r"[!#$%&'*+\-.^_`|~0-9A-Za-z]+")
VERSRE: Final[Pattern[str]] = re.compile(r"HTTP/(\d).(\d)")
HDRRE: Final[Pattern[bytes]] = re.compile(rb"[\x00-\x1F\x7F()<>@,;:\[\]={} \t\"\\]")
QUOTEHDRRE = re.compile(r'(".*?(?:[^\\]"))[ \t]*(?:,|$)')
HEXDIGIT = re.compile(rb"[0-9a-fA-F]+")


class HeadersDictProxy(Mapping[str, str]):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Have you compared performance compared to inheriting multidict and overriding the behaviors of storing the data and outputting the combined values?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation details are not a concern yet, so will come back to that. First, I want to get consensus that this is the correct approach and should be changed in v4, given that it is likely to cause backwards-compatibility breakages for atleast a small proportion of users.

def __init__(self, d: Mapping[str, str]):
self._d = d

def __getitem__(self, key: str):
return self._d.__getitem__(key.title())

def __contains__(self, key: str):
return self._d.__contains__(key.title())

def __iter__(self):
return self._d.__iter__()

def __len__(self):
return self._d.__len__()

def has_key(self, key: str):
return self._d.has_key(key.title())

def get(self, key: str, default=None):
return self._d.get(key.title(), default)

def getall(self, key: str) -> Tuple[str]:
return self._split_on_commas(self._d.get(key, ""))

def _split_on_commas(self, val: str) -> Tuple[str]:
values = []
while val:
quoted = re.match(QUOTEHDRRE, val)
if quoted:
values.append(quoted.group(1)[1:-1])
val = val[len(quoted.group()) :].lstrip()
else:
try:
h, val = val.split(",", maxsplit=1)
except ValueError:
h = val
val = ""
val = val.lstrip()
h = h.rstrip()
if h:
values.append(h)

return tuple(values)


class RawRequestMessage(NamedTuple):
method: str
path: str
Expand Down Expand Up @@ -123,7 +172,7 @@ def __init__(
def parse_headers(
self, lines: List[bytes]
) -> Tuple["CIMultiDictProxy[str]", RawHeaders]:
headers: CIMultiDict[str] = CIMultiDict()
headers = {}
raw_headers = []

lines_idx = 1
Expand Down Expand Up @@ -205,10 +254,31 @@ def parse_headers(
if "\n" in value or "\r" in value or "\x00" in value:
raise InvalidHeader(bvalue)

headers.add(name, value)
raw_headers.append((bname, bvalue))
name = name.title()
if name in headers:
# https://www.rfc-editor.org/rfc/rfc9110.html#name-field-order
# https://www.rfc-editor.org/rfc/rfc9110.html#section-5.5-8
# https://www.rfc-editor.org/rfc/rfc9110.html#section-10.2.2-13.1
# https://www.rfc-editor.org/rfc/rfc9110.html#name-collected-abnf
if name in {
"Content-Location",
"Date",
"From",
"If-Modified-Since",
"If-Range",
"If-Unmodified-Since",
"Last-Modified",
"Location",
"Referer",
"Retry-After",
}:
raise BadHttpMessage(f"Duplicate '{name}' header found.")
headers[name] += ", " + value
else:
headers[name] = value

return (CIMultiDictProxy(headers), tuple(raw_headers))
return (HeadersDictProxy(headers), tuple(raw_headers))


class HttpParser(abc.ABC, Generic[_MsgT]):
Expand Down
35 changes: 35 additions & 0 deletions tests/test_http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,41 @@ def test_cve_2023_37276(parser: Any) -> None:
parser.feed_data(text)


@pytest.mark.parametrize(
("hdr_vals", "expected"),
(
(
('"http://example.com/a.html,foo", apples',),
("http://example.com/a.html,foo", "apples"),
),
(("bananas, apples",), ("bananas", "apples")),
(("bananas", "apples"), ("bananas", "apples")),
(
('"http://example.com/a.html,foo", "apples"',),
("http://example.com/a.html,foo", "apples"),
),
(
('"Sat, 04 May 1996", "Wed, 14 Sep 2005"',),
("Sat, 04 May 1996", "Wed, 14 Sep 2005"),
),
(("foo,bar,baz",), ("foo", "bar", "baz")),
(('"applebanna, this',), ('"applebanna', "this")),
(('fooo", "bar"',), ('fooo"', "bar")),
((" spam , eggs ",), ("spam", "eggs")),
((" spam ", " eggs "), ("spam", "eggs")),
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: Add tests for escaped quotes (e.g. "foo\"bar"), maybe also escaped backslash, if that's valid (e.g. "foo\\" or "foo\\\"").

((", , ",), ()),
(("",), ()),
),
)
def test_list_headers(parser: Any, hdr_vals: List[str], expected: List[str]) -> None:
headers = "\r\n".join(f"Foo: {v}" for v in hdr_vals)
text = f"POST / HTTP/1.1\r\n{headers}\r\n\r\n".encode()
messages, upgrade, tail = parser.feed_data(text)
msg = messages[0][0]

assert msg.headers.getall("Foo") == expected


@pytest.mark.parametrize(
"hdr",
(
Expand Down