Skip to content

Commit

Permalink
#11635 Allow mixed types in Headers types (#11636)
Browse files Browse the repository at this point in the history
  • Loading branch information
twm committed Oct 3, 2022
2 parents f0be792 + e251fd4 commit 3314a7c
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 15 deletions.
44 changes: 29 additions & 15 deletions src/twisted/web/http_headers.py
Expand Up @@ -18,7 +18,6 @@
Tuple,
TypeVar,
Union,
cast,
overload,
)

Expand Down Expand Up @@ -85,7 +84,7 @@ class Headers:
def __init__(
self,
rawHeaders: Optional[Mapping[AnyStr, Sequence[AnyStr]]] = None,
):
) -> None:
self._rawHeaders: Dict[bytes, List[bytes]] = {}
if rawHeaders is not None:
for name, values in rawHeaders.items():
Expand All @@ -111,7 +110,7 @@ def __cmp__(self, other):
)
return NotImplemented

def _encodeName(self, name: AnyStr) -> bytes:
def _encodeName(self, name: Union[str, bytes]) -> bytes:
"""
Encode the name of a header (eg 'Content-Type') to an ISO-8859-1 encoded
bytestring if required.
Expand Down Expand Up @@ -152,7 +151,21 @@ def removeHeader(self, name: AnyStr) -> None:
"""
self._rawHeaders.pop(self._encodeName(name), None)

def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:
@overload
def setRawHeaders(self, name: Union[str, bytes], values: Sequence[bytes]) -> None:
...

@overload
def setRawHeaders(self, name: Union[str, bytes], values: Sequence[str]) -> None:
...

@overload
def setRawHeaders(
self, name: Union[str, bytes], values: Sequence[Union[str, bytes]]
) -> None:
...

def setRawHeaders(self, name: Union[str, bytes], values: object) -> None:
"""
Sets the raw representation of the given header.
Expand All @@ -161,9 +174,8 @@ def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:
@param values: A list of strings each one being a header value of
the given name.
@raise TypeError: Raised if C{values} is not a L{list} of L{bytes}
or L{str} strings, or if C{name} is not a L{bytes} or
L{str} string.
@raise TypeError: Raised if C{values} is not a sequence of L{bytes}
or L{str}, or if C{name} is not L{bytes} or L{str}.
@return: L{None}
"""
Expand All @@ -175,7 +187,7 @@ def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:

if not isinstance(name, (bytes, str)):
raise TypeError(
"Header name is an instance of %r, " "not bytes or str" % (type(name),)
f"Header name is an instance of {type(name)!r}, not bytes or str"
)

for count, value in enumerate(values):
Expand All @@ -200,7 +212,7 @@ def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None:

self._rawHeaders[_name] = encodedValues

def addRawHeader(self, name: AnyStr, value: AnyStr) -> None:
def addRawHeader(self, name: Union[str, bytes], value: Union[str, bytes]) -> None:
"""
Add a new raw value for the given header.
Expand All @@ -210,7 +222,7 @@ def addRawHeader(self, name: AnyStr, value: AnyStr) -> None:
"""
if not isinstance(name, (bytes, str)):
raise TypeError(
"Header name is an instance of %r, " "not bytes or str" % (type(name),)
f"Header name is an instance of {type(name)!r}, not bytes or str"
)

if not isinstance(value, (bytes, str)):
Expand All @@ -219,11 +231,13 @@ def addRawHeader(self, name: AnyStr, value: AnyStr) -> None:
"bytes or str" % (type(value),)
)

# We secretly know getRawHeaders is really returning a list
values = cast(List[AnyStr], self.getRawHeaders(name, default=[]))
values.append(value)

self.setRawHeaders(name, values)
self._rawHeaders.setdefault(
_sanitizeLinearWhitespace(self._encodeName(name)), []
).append(
_sanitizeLinearWhitespace(
value.encode("utf8") if isinstance(value, str) else value
)
)

@overload
def getRawHeaders(self, name: AnyStr) -> Optional[Sequence[AnyStr]]:
Expand Down
1 change: 1 addition & 0 deletions src/twisted/web/newsfragments/11635.bugfix
@@ -0,0 +1 @@
The typing of the twisted.web.http_headers.Headers methods addRawHeader() and setRawHeaders() now allow mixing str and bytes, matching the runtime behavior.
33 changes: 33 additions & 0 deletions src/twisted/web/test/test_http_headers.py
Expand Up @@ -668,3 +668,36 @@ def test_copy(self):
# Verify that the orignal does not have it
self.assertEqual(h.getRawHeaders("test\u00E1"), ["foo\u2603", "bar"])
self.assertEqual(h.getRawHeaders(b"test\xe1"), [b"foo\xe2\x98\x83", b"bar"])


class MixedHeadersTests(TestCase):
"""
Tests for L{Headers}, mixing L{bytes} and L{str} arguments for methods
where that is permitted.
"""

def test_addRawHeader(self) -> None:
"""
L{Headers.addRawHeader} accepts mixed L{str} and L{bytes}.
"""
h = Headers()
h.addRawHeader(b"bytes", "str")
h.addRawHeader("str", b"bytes")

self.assertEqual(h.getRawHeaders(b"Bytes"), [b"str"])
self.assertEqual(h.getRawHeaders("Str"), ["bytes"])

def test_setRawHeaders(self) -> None:
"""
L{Headers.setRawHeaders} accepts mixed L{str} and L{bytes}.
"""
h = Headers()
h.setRawHeaders(b"bytes", [b"bytes"])
h.setRawHeaders("str", ["str"])
h.setRawHeaders("mixed-str", [b"bytes", "str"])
h.setRawHeaders(b"mixed-bytes", ["str", b"bytes"])

self.assertEqual(h.getRawHeaders(b"Bytes"), [b"bytes"])
self.assertEqual(h.getRawHeaders("Str"), ["str"])
self.assertEqual(h.getRawHeaders("Mixed-Str"), ["bytes", "str"])
self.assertEqual(h.getRawHeaders(b"Mixed-Bytes"), [b"str", b"bytes"])

0 comments on commit 3314a7c

Please sign in to comment.