diff --git a/changelog.d/355.bugfix.rst b/changelog.d/355.bugfix.rst new file mode 100644 index 00000000..8e427ee0 --- /dev/null +++ b/changelog.d/355.bugfix.rst @@ -0,0 +1 @@ +:func:`treq.content.text_content()` no longer generates deprecation warnings due to use of the ``cgi`` module. diff --git a/setup.py b/setup.py index 52c26c80..36046d44 100644 --- a/setup.py +++ b/setup.py @@ -35,6 +35,7 @@ "Twisted[tls] >= 22.10.0", # For #11635 "attrs", "typing_extensions >= 3.10.0", + "multipart", ], extras_require={ "dev": [ diff --git a/src/treq/content.py b/src/treq/content.py index 27dba507..e3f4aaad 100644 --- a/src/treq/content.py +++ b/src/treq/content.py @@ -1,7 +1,7 @@ -import cgi import json -from typing import Any, Callable, List, Optional, cast +from typing import Any, Callable, FrozenSet, List, Optional, cast +import multipart # type: ignore from twisted.internet.defer import Deferred, succeed from twisted.internet.protocol import Protocol, connectionDone from twisted.python.failure import Failure @@ -11,6 +11,17 @@ from twisted.web.iweb import IResponse +"""Characters that are valid in a charset name per RFC 2978. + +See https://www.rfc-editor.org/errata/eid5433 +""" +_MIME_CHARSET_CHARS: FrozenSet[str] = frozenset( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" # ALPHA + "0123456789" # DIGIT + "!#$%&+-^_`~" # symbols +) + + def _encoding_from_headers(headers: Headers) -> Optional[str]: content_types = headers.getRawHeaders("content-type") if content_types is None: @@ -18,14 +29,20 @@ def _encoding_from_headers(headers: Headers) -> Optional[str]: # This seems to be the choice browsers make when encountering multiple # content-type headers. 
- content_type, params = cgi.parse_header(content_types[-1]) + media_type, params = multipart.parse_options_header(content_types[-1]) charset = params.get("charset") if charset: - return charset.strip("'\"") - - if content_type == "application/json": - return "UTF-8" + assert isinstance(charset, str) # for MyPy + charset = charset.strip("'\"").lower() + if not charset: + return None + if not set(charset).issubset(_MIME_CHARSET_CHARS): + return None + return charset + + if media_type == "application/json": + return "utf-8" return None diff --git a/src/treq/test/test_content.py b/src/treq/test/test_content.py index 0d83ddfe..8158e007 100644 --- a/src/treq/test/test_content.py +++ b/src/treq/test/test_content.py @@ -1,4 +1,6 @@ +import unittest from unittest import mock +from typing import Optional from twisted.python.failure import Failure @@ -11,6 +13,7 @@ from twisted.web.server import NOT_DONE_YET from treq import collect, content, json_content, text_content +from treq.content import _encoding_from_headers from treq.client import _BufferedResponse from treq.testing import StubTreq @@ -267,6 +270,59 @@ def error(data): # being closed. stub.flush() self.assertEqual(len(resource.request_finishes), 1) - self.assertIsInstance( - resource.request_finishes[0].value, ConnectionDone + self.assertIsInstance(resource.request_finishes[0].value, ConnectionDone) + + +class EncodingFromHeadersTests(unittest.TestCase): + def _encodingFromContentType(self, content_type: str) -> Optional[str]: + """ + Invoke `_encoding_from_headers()` for a header value. + + :param content_type: A Content-Type header value. + :returns: The result of `_encoding_from_headers()` + """ + h = Headers({"Content-Type": [content_type]}) + return _encoding_from_headers(h) + + def test_rfcExamples(self): + """ + The examples from RFC 9110 § 8.3.1 are normalized to + canonical (lowercase) form. 
+ """ + for example in [ + "text/html;charset=utf-8", + 'Text/HTML;Charset="utf-8"', + 'text/html; charset="utf-8"', + "text/html;charset=UTF-8", + ]: + self.assertEqual("utf-8", self._encodingFromContentType(example)) + + def test_multipleParams(self): + """The charset parameter is extracted even if mixed with other params.""" + for example in [ + "a/b;c=d;charSet=ascii", + "a/b;c=d;charset=ascii; e=f", + "a/b;c=d; charsEt=ascii;e=f", + "a/b;c=d; charset=ascii; e=f", + ]: + self.assertEqual("ascii", self._encodingFromContentType(example)) + + def test_quotedString(self): + """Any quotes that surround the value of the charset param are removed.""" + self.assertEqual( + "ascii", self._encodingFromContentType("foo/bar; charset='ASCII'") ) + self.assertEqual( + "shift_jis", self._encodingFromContentType('a/b; charset="Shift_JIS"') + ) + + def test_noCharset(self): + """None is returned when no valid charset parameter is found.""" + for example in [ + "application/octet-stream", + "text/plain;charset=", + "text/plain;charset=''", + "text/plain;charset=\"'\"", + "text/plain;charset=🙃", + ]: + self.assertIsNone(self._encodingFromContentType(example)) diff --git a/src/treq/test/test_multipart.py b/src/treq/test/test_multipart.py index 5cbaade5..999f1afd 100644 --- a/src/treq/test/test_multipart.py +++ b/src/treq/test/test_multipart.py @@ -1,12 +1,11 @@ # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. -import cgi -import sys from typing import cast, AnyStr from io import BytesIO +from multipart import MultipartParser # type: ignore from twisted.trial import unittest from zope.interface.verify import verifyObject @@ -588,9 +587,10 @@ def test_newLinesInParams(self): --heyDavid-- """.encode("utf-8")), output) - def test_worksWithCgi(self): + def test_worksWithMultipart(self): """ - Make sure the stuff we generated actually parsed by python cgi + Make sure the stuff we generated can actually be parsed by the + `multipart` module. 
""" output = self.getOutput( MultiPartProducer([ @@ -612,23 +612,20 @@ def test_worksWithCgi(self): ) ) - form = cgi.parse_multipart(BytesIO(output), { - "boundary": b"heyDavid", - "CONTENT-LENGTH": str(len(output)), - }) + form = MultipartParser( + stream=BytesIO(output), + boundary=b"heyDavid", + content_length=len(output), + ) - # Since Python 3.7, the value for a non-file field is now a list - # of strings, not bytes. - if sys.version_info >= (3, 7): - self.assertEqual(set(['just a string\r\n', 'another string']), - set(form['cfield'])) - else: - self.assertEqual(set([b'just a string\r\n', b'another string']), - set(form['cfield'])) + self.assertEqual( + [b'just a string\r\n', b'another string'], + [f.raw for f in form.get_all('cfield')], + ) - self.assertEqual(set([b'my lovely bytes2']), set(form['efield'])) - self.assertEqual(set([b'my lovely bytes219']), set(form['xfield'])) - self.assertEqual(set([b'my lovely bytes22']), set(form['afield'])) + self.assertEqual(b'my lovely bytes2', form.get('efield').raw) + self.assertEqual(b'my lovely bytes219', form.get('xfield').raw) + self.assertEqual(b'my lovely bytes22', form.get('afield').raw) class LengthConsumerTestCase(unittest.TestCase):