diff --git a/src/twisted/web/http_headers.py b/src/twisted/web/http_headers.py index 738a1759c6b..f810f4bc2c4 100644 --- a/src/twisted/web/http_headers.py +++ b/src/twisted/web/http_headers.py @@ -18,7 +18,6 @@ Tuple, TypeVar, Union, - cast, overload, ) @@ -85,7 +84,7 @@ class Headers: def __init__( self, rawHeaders: Optional[Mapping[AnyStr, Sequence[AnyStr]]] = None, - ): + ) -> None: self._rawHeaders: Dict[bytes, List[bytes]] = {} if rawHeaders is not None: for name, values in rawHeaders.items(): @@ -111,7 +110,7 @@ def __cmp__(self, other): ) return NotImplemented - def _encodeName(self, name: AnyStr) -> bytes: + def _encodeName(self, name: Union[str, bytes]) -> bytes: """ Encode the name of a header (eg 'Content-Type') to an ISO-8859-1 encoded bytestring if required. @@ -152,7 +151,21 @@ def removeHeader(self, name: AnyStr) -> None: """ self._rawHeaders.pop(self._encodeName(name), None) - def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None: + @overload + def setRawHeaders(self, name: Union[str, bytes], values: Sequence[bytes]) -> None: + ... + + @overload + def setRawHeaders(self, name: Union[str, bytes], values: Sequence[str]) -> None: + ... + + @overload + def setRawHeaders( + self, name: Union[str, bytes], values: Sequence[Union[str, bytes]] + ) -> None: + ... + + def setRawHeaders(self, name: Union[str, bytes], values: object) -> None: """ Sets the raw representation of the given header. @@ -161,9 +174,8 @@ def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None: @param values: A list of strings each one being a header value of the given name. - @raise TypeError: Raised if C{values} is not a L{list} of L{bytes} - or L{str} strings, or if C{name} is not a L{bytes} or - L{str} string. + @raise TypeError: Raised if C{values} is not a sequence of L{bytes} + or L{str}, or if C{name} is not L{bytes} or L{str}. @return: L{None} """ @@ -175,7 +187,7 @@ def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None: if not isinstance(name, (bytes, str)): raise TypeError( - "Header name is an instance of %r, " "not bytes or str" % (type(name),) + f"Header name is an instance of {type(name)!r}, not bytes or str" ) for count, value in enumerate(values): @@ -200,7 +212,7 @@ def setRawHeaders(self, name: AnyStr, values: Sequence[AnyStr]) -> None: self._rawHeaders[_name] = encodedValues - def addRawHeader(self, name: AnyStr, value: AnyStr) -> None: + def addRawHeader(self, name: Union[str, bytes], value: Union[str, bytes]) -> None: """ Add a new raw value for the given header. @@ -210,7 +222,7 @@ def addRawHeader(self, name: AnyStr, value: AnyStr) -> None: """ if not isinstance(name, (bytes, str)): raise TypeError( - "Header name is an instance of %r, " "not bytes or str" % (type(name),) + f"Header name is an instance of {type(name)!r}, not bytes or str" ) if not isinstance(value, (bytes, str)): @@ -219,11 +231,13 @@ def addRawHeader(self, name: AnyStr, value: AnyStr) -> None: "bytes or str" % (type(value),) ) - # We secretly know getRawHeaders is really returning a list - values = cast(List[AnyStr], self.getRawHeaders(name, default=[])) - values.append(value) - - self.setRawHeaders(name, values) + self._rawHeaders.setdefault( + _sanitizeLinearWhitespace(self._encodeName(name)), [] + ).append( + _sanitizeLinearWhitespace( + value.encode("utf8") if isinstance(value, str) else value + ) + ) @overload def getRawHeaders(self, name: AnyStr) -> Optional[Sequence[AnyStr]]: diff --git a/src/twisted/web/newsfragments/11635.bugfix b/src/twisted/web/newsfragments/11635.bugfix new file mode 100644 index 00000000000..a32e51e7618 --- /dev/null +++ b/src/twisted/web/newsfragments/11635.bugfix @@ -0,0 +1 @@ +The typing of the twisted.web.http_headers.Headers methods addRawHeader() and setRawHeaders() now allow mixing str and bytes, matching the runtime behavior. diff --git a/src/twisted/web/test/test_http_headers.py b/src/twisted/web/test/test_http_headers.py index 4f9f71f125f..1e039fd78b6 100644 --- a/src/twisted/web/test/test_http_headers.py +++ b/src/twisted/web/test/test_http_headers.py @@ -668,3 +668,36 @@ def test_copy(self): # Verify that the orignal does not have it self.assertEqual(h.getRawHeaders("test\u00E1"), ["foo\u2603", "bar"]) self.assertEqual(h.getRawHeaders(b"test\xe1"), [b"foo\xe2\x98\x83", b"bar"]) + + +class MixedHeadersTests(TestCase): + """ + Tests for L{Headers}, mixing L{bytes} and L{str} arguments for methods + where that is permitted. + """ + + def test_addRawHeader(self) -> None: + """ + L{Headers.addRawHeader} accepts mixed L{str} and L{bytes}. + """ + h = Headers() + h.addRawHeader(b"bytes", "str") + h.addRawHeader("str", b"bytes") + + self.assertEqual(h.getRawHeaders(b"Bytes"), [b"str"]) + self.assertEqual(h.getRawHeaders("Str"), ["bytes"]) + + def test_setRawHeaders(self) -> None: + """ + L{Headers.setRawHeaders} accepts mixed L{str} and L{bytes}. + """ + h = Headers() + h.setRawHeaders(b"bytes", [b"bytes"]) + h.setRawHeaders("str", ["str"]) + h.setRawHeaders("mixed-str", [b"bytes", "str"]) + h.setRawHeaders(b"mixed-bytes", ["str", b"bytes"]) + + self.assertEqual(h.getRawHeaders(b"Bytes"), [b"bytes"]) + self.assertEqual(h.getRawHeaders("Str"), ["str"]) + self.assertEqual(h.getRawHeaders("Mixed-Str"), ["bytes", "str"]) + self.assertEqual(h.getRawHeaders(b"Mixed-Bytes"), [b"str", b"bytes"])