Skip to content

Commit

Permalink
Merge pull request #377 from twisted/bye-cgi-355
Browse files Browse the repository at this point in the history
Eliminate use of `cgi`
  • Loading branch information
glyph committed Jan 6, 2024
2 parents de8e854 + c2210ca commit 8ccd453
Show file tree
Hide file tree
Showing 5 changed files with 100 additions and 28 deletions.
1 change: 1 addition & 0 deletions changelog.d/355.bugfix.rst
@@ -0,0 +1 @@
:mod:`treq.content.text_content()` no longer generates deprecation warnings due to use of the ``cgi`` module.
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -35,6 +35,7 @@
"Twisted[tls] >= 22.10.0", # For #11635
"attrs",
"typing_extensions >= 3.10.0",
"multipart",
],
extras_require={
"dev": [
Expand Down
31 changes: 24 additions & 7 deletions src/treq/content.py
@@ -1,7 +1,7 @@
import cgi
import json
from typing import Any, Callable, List, Optional, cast
from typing import Any, Callable, FrozenSet, List, Optional, cast

import multipart # type: ignore
from twisted.internet.defer import Deferred, succeed
from twisted.internet.protocol import Protocol, connectionDone
from twisted.python.failure import Failure
Expand All @@ -11,21 +11,38 @@
from twisted.web.iweb import IResponse


"""Characters that are valid in a charset name per RFC 2978.
See https://www.rfc-editor.org/errata/eid5433
"""
_MIME_CHARSET_CHARS: FrozenSet[str] = frozenset(
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" # ALPHA
"0123456789" # DIGIT
"!#$%&+-^_`~" # symbols
)


def _encoding_from_headers(headers: Headers) -> Optional[str]:
content_types = headers.getRawHeaders("content-type")
if content_types is None:
return None

# This seems to be the choice browsers make when encountering multiple
# content-type headers.
content_type, params = cgi.parse_header(content_types[-1])
media_type, params = multipart.parse_options_header(content_types[-1])

charset = params.get("charset")
if charset:
return charset.strip("'\"")

if content_type == "application/json":
return "UTF-8"
assert isinstance(charset, str) # for MyPy
charset = charset.strip("'\"").lower()
if not charset:
return None
if not set(charset).issubset(_MIME_CHARSET_CHARS):
return None
return charset

if media_type == "application/json":
return "utf-8"

return None

Expand Down
60 changes: 58 additions & 2 deletions src/treq/test/test_content.py
@@ -1,4 +1,6 @@
import unittest
from unittest import mock
from typing import Optional

from twisted.python.failure import Failure

Expand All @@ -11,6 +13,7 @@
from twisted.web.server import NOT_DONE_YET

from treq import collect, content, json_content, text_content
from treq.content import _encoding_from_headers
from treq.client import _BufferedResponse
from treq.testing import StubTreq

Expand Down Expand Up @@ -267,6 +270,59 @@ def error(data):
# being closed.
stub.flush()
self.assertEqual(len(resource.request_finishes), 1)
self.assertIsInstance(
resource.request_finishes[0].value, ConnectionDone
self.assertIsInstance(resource.request_finishes[0].value, ConnectionDone)


class EncodingFromHeadersTests(unittest.TestCase):
def _encodingFromContentType(self, content_type: str) -> Optional[str]:
"""
Invoke `_encoding_from_headers()` for a header value.
:param content_type: A Content-Type header value.
:returns: The result of `_encoding_from_headers()`
"""
h = Headers({"Content-Type": [content_type]})
return _encoding_from_headers(h)

def test_rfcExamples(self):
"""
The examples from RFC 9110 § 8.3.1 are normalized to
canonical (lowercase) form.
"""
for example in [
"text/html;charset=utf-8",
'Text/HTML;Charset="utf-8"',
'text/html; charset="utf-8"',
"text/html;charset=UTF-8",
]:
self.assertEqual("utf-8", self._encodingFromContentType(example))

def test_multipleParams(self):
"""The charset parameter is extracted even if mixed with other params."""
for example in [
"a/b;c=d;charSet=ascii",
"a/b;c=d;charset=ascii; e=f",
"a/b;c=d; charsEt=ascii;e=f",
"a/b;c=d; charset=ascii; e=f",
]:
self.assertEqual("ascii", self._encodingFromContentType(example))

def test_quotedString(self):
"""Any quotes that surround the value of the charset param are removed."""
self.assertEqual(
"ascii", self._encodingFromContentType("foo/bar; charset='ASCII'")
)
self.assertEqual(
"shift_jis", self._encodingFromContentType('a/b; charset="Shift_JIS"')
)

def test_noCharset(self):
"""None is returned when no valid charset parameter is found."""
for example in [
"application/octet-stream",
"text/plain;charset=",
"text/plain;charset=''",
"text/plain;charset=\"'\"",
"text/plain;charset=🙃",
]:
self.assertIsNone(self._encodingFromContentType(example))
35 changes: 16 additions & 19 deletions src/treq/test/test_multipart.py
@@ -1,12 +1,11 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.

import cgi
import sys
from typing import cast, AnyStr

from io import BytesIO

from multipart import MultipartParser # type: ignore
from twisted.trial import unittest
from zope.interface.verify import verifyObject

Expand Down Expand Up @@ -588,9 +587,10 @@ def test_newLinesInParams(self):
--heyDavid--
""".encode("utf-8")), output)

def test_worksWithCgi(self):
def test_worksWithMultipart(self):
"""
Make sure the stuff we generated actually parsed by python cgi
Make sure the stuff we generated can actually be parsed by the
`multipart` module.
"""
output = self.getOutput(
MultiPartProducer([
Expand All @@ -612,23 +612,20 @@ def test_worksWithCgi(self):
)
)

form = cgi.parse_multipart(BytesIO(output), {
"boundary": b"heyDavid",
"CONTENT-LENGTH": str(len(output)),
})
form = MultipartParser(
stream=BytesIO(output),
boundary=b"heyDavid",
content_length=len(output),
)

# Since Python 3.7, the value for a non-file field is now a list
# of strings, not bytes.
if sys.version_info >= (3, 7):
self.assertEqual(set(['just a string\r\n', 'another string']),
set(form['cfield']))
else:
self.assertEqual(set([b'just a string\r\n', b'another string']),
set(form['cfield']))
self.assertEqual(
[b'just a string\r\n', b'another string'],
[f.raw for f in form.get_all('cfield')],
)

self.assertEqual(set([b'my lovely bytes2']), set(form['efield']))
self.assertEqual(set([b'my lovely bytes219']), set(form['xfield']))
self.assertEqual(set([b'my lovely bytes22']), set(form['afield']))
self.assertEqual(b'my lovely bytes2', form.get('efield').raw)
self.assertEqual(b'my lovely bytes219', form.get('xfield').raw)
self.assertEqual(b'my lovely bytes22', form.get('afield').raw)


class LengthConsumerTestCase(unittest.TestCase):
Expand Down

0 comments on commit 8ccd453

Please sign in to comment.