From 5267b9e9c0a7e1c8790137796ef2739f52507dce Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 5 Sep 2022 18:58:06 -0700 Subject: [PATCH 01/21] Begin the typing --- MANIFEST.in | 2 ++ mypy.ini | 40 ++++++++++++++++++++++++++++++++++++++++ setup.py | 2 +- src/treq/py.typed | 0 tox.ini | 10 ++++++++++ 5 files changed, 53 insertions(+), 1 deletion(-) create mode 100644 mypy.ini create mode 100644 src/treq/py.typed diff --git a/MANIFEST.in b/MANIFEST.in index c68b6d96..d1242850 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,11 +3,13 @@ include *.rst include *.md include LICENSE include .coveragerc +include src/treq/py.typed recursive-include docs * prune docs/_build prune docs/html exclude tox.ini +exclude mypy.ini exclude .github exclude .readthedocs.yml diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 00000000..b8447773 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,40 @@ +[mypy] + +namespace_packages = True +plugins=mypy_zope:plugin + +# Increase our expectations + +check_untyped_defs = True +disallow_incomplete_defs = True +disallow_untyped_defs = True +no_implicit_optional = True +show_column_numbers = True +show_error_codes = True +strict_optional = True +warn_no_return = True +warn_redundant_casts = True +warn_return_any = True +warn_unreachable = True +warn_unused_ignores = True + +# These are too strict for us at the moment + +disallow_any_decorated = False +disallow_any_explicit = False +disallow_any_expr = False +disallow_any_generics = False +disallow_any_unimported = False +disallow_subclassing_any = False +disallow_untyped_calls = False +disallow_untyped_decorators = False +strict_equality = False + +# Disable some checks until the effected modules fully adopt mypy + +[mypy-treq._version] +check_untyped_defs = False + +[mypy-treq.test.local_httpbin.*] +disallow_untyped_defs = False +check_untyped_defs = False diff --git a/setup.py b/setup.py index 83027633..8db45e23 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ "sphinx<7.0.0", # Removal 
of 'style' key breaks RTD. ], }, - package_data={"treq": ["_version"]}, + package_data={"treq": ["py.typed"]}, author="David Reid", author_email="dreid@dreid.org", maintainer="Tom Most", diff --git a/src/treq/py.typed b/src/treq/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/tox.ini b/tox.ini index 2c308623..36d47c53 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,16 @@ commands = {envbindir}/coverage run -p \ {envbindir}/trial {posargs:treq} +[testenv:mypy] +deps = + mypy==0.971 + mypy-zope==0.3.9 +commands = + mypy \ + --cache-dir="{toxworkdir}/mypy_cache" \ + {tty:--pretty:} \ + {posargs:src} + [testenv:flake8] skip_install = True deps = flake8 From 19a95ac079d9d51a51fa270181fd445fa8c3cf2f Mon Sep 17 00:00:00 2001 From: Tom Most Date: Tue, 20 Dec 2022 12:47:24 -0800 Subject: [PATCH 02/21] Ignore per-process Coverage files --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 9b00e7b3..10fe94a5 100644 --- a/.gitignore +++ b/.gitignore @@ -8,7 +8,8 @@ dist /docs/html /.eggs MANIFEST -.coverage +/.coverage +/.coverage.* coverage htmlcov _trial_temp* From a0c3a2931dc09b389c8c4d6edf4ce2a1e7425ea0 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Tue, 20 Dec 2022 21:54:54 -0800 Subject: [PATCH 03/21] Bump MyPy --- tox.ini | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tox.ini b/tox.ini index 36d47c53..5195cf0f 100644 --- a/tox.ini +++ b/tox.ini @@ -27,8 +27,9 @@ commands = [testenv:mypy] deps = - mypy==0.971 - mypy-zope==0.3.9 + mypy==0.981 + mypy-zope==0.3.11 + types-requests commands = mypy \ --cache-dir="{toxworkdir}/mypy_cache" \ From 0d892777263b4308dc66bb24e5dc1d046d57f7b9 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 5 Sep 2022 23:26:43 -0700 Subject: [PATCH 04/21] MyPy clean --- mypy.ini | 40 --- pyproject.toml | 62 ++++ src/treq/__init__.py | 25 +- src/treq/_agentspy.py | 8 +- src/treq/_types.py | 102 ++++++ src/treq/client.py | 298 +++++++++--------- 
src/treq/multipart.py | 202 +++++++----- src/treq/response.py | 25 +- src/treq/test/local_httpbin/child.py | 4 +- .../test/local_httpbin/test/test_child.py | 2 +- src/treq/test/test_agentspy.py | 12 +- src/treq/test/test_api.py | 41 +-- src/treq/testing.py | 9 + 13 files changed, 509 insertions(+), 321 deletions(-) delete mode 100644 mypy.ini create mode 100644 src/treq/_types.py diff --git a/mypy.ini b/mypy.ini deleted file mode 100644 index b8447773..00000000 --- a/mypy.ini +++ /dev/null @@ -1,40 +0,0 @@ -[mypy] - -namespace_packages = True -plugins=mypy_zope:plugin - -# Increase our expectations - -check_untyped_defs = True -disallow_incomplete_defs = True -disallow_untyped_defs = True -no_implicit_optional = True -show_column_numbers = True -show_error_codes = True -strict_optional = True -warn_no_return = True -warn_redundant_casts = True -warn_return_any = True -warn_unreachable = True -warn_unused_ignores = True - -# These are too strict for us at the moment - -disallow_any_decorated = False -disallow_any_explicit = False -disallow_any_expr = False -disallow_any_generics = False -disallow_any_unimported = False -disallow_subclassing_any = False -disallow_untyped_calls = False -disallow_untyped_decorators = False -strict_equality = False - -# Disable some checks until the effected modules fully adopt mypy - -[mypy-treq._version] -check_untyped_defs = False - -[mypy-treq.test.local_httpbin.*] -disallow_untyped_defs = False -check_untyped_defs = False diff --git a/pyproject.toml b/pyproject.toml index f97064fa..76f182ca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,65 @@ filename = "CHANGELOG.rst" directory = "changelog.d" title_format = "{version} ({project_date})" issue_format = "`#{issue} `__" + +[tool.mypy] +namespace_packages = true +plugins = "mypy_zope:plugin" + +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_optional = true +show_column_numbers = true +show_error_codes = true 
+strict_optional = true +warn_no_return = true +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true +warn_unused_ignores = true + +disallow_any_decorated = false +disallow_any_explicit = false +disallow_any_expr = false +disallow_any_generics = false +disallow_any_unimported = false +disallow_subclassing_any = false +disallow_untyped_calls = false +disallow_untyped_decorators = false +strict_equality = false + +[[tool.mypy.overrides]] +module = [ + "treq.api", + "treq.auth", + "treq.client", + "treq.content", + "treq.multipart", + "treq.response", + "treq.testing", + "treq.test.test_api", + "treq.test.test_auth", + "treq.test.test_client", + "treq.test.test_content", + "treq.test.test_multipart", + "treq.test.test_response", + "treq.test.test_testing", + "treq.test.test_treq_integration", + "treq.test.util", +] +disallow_untyped_defs = false +check_untyped_defs = false + +[[tool.mypy.overrides]] +module = [ + "treq.test.local_httpbin.child", + "treq.test.local_httpbin.parent", + "treq.test.local_httpbin.shared", + "treq.test.local_httpbin.test.test_child", + "treq.test.local_httpbin.test.test_parent", + "treq.test.local_httpbin.test.test_shared", +] +disallow_untyped_defs = false +check_untyped_defs = false +ignore_missing_imports = true diff --git a/src/treq/__init__.py b/src/treq/__init__.py index 771fb88e..530c6c6b 100644 --- a/src/treq/__init__.py +++ b/src/treq/__init__.py @@ -1,11 +1,20 @@ -from __future__ import absolute_import, division, print_function +from treq.api import delete, get, head, patch, post, put, request +from treq.content import collect, content, json_content, text_content -from ._version import __version__ +from ._version import __version__ as _version -from treq.api import head, get, post, put, patch, delete, request -from treq.content import collect, content, text_content, json_content +__version__: str = _version.base() -__version__ = __version__.base() - -__all__ = ['head', 'get', 'post', 'put', 'patch', 
'delete', 'request', - 'collect', 'content', 'text_content', 'json_content'] +__all__ = [ + "head", + "get", + "post", + "put", + "patch", + "delete", + "request", + "collect", + "content", + "text_content", + "json_content", +] diff --git a/src/treq/_agentspy.py b/src/treq/_agentspy.py index c7ad7f66..5e939d1c 100644 --- a/src/treq/_agentspy.py +++ b/src/treq/_agentspy.py @@ -25,7 +25,7 @@ class RequestRecord: uri = attr.ib() # type: bytes headers = attr.ib() # type: Optional[Headers] bodyProducer = attr.ib() # type: Optional[IBodyProducer] - deferred = attr.ib() # type: Deferred + deferred = attr.ib() # type: Deferred[IResponse] @implementer(IAgent) @@ -38,7 +38,7 @@ class _AgentSpy: A function called with each :class:`RequestRecord` """ - _callback = attr.ib() # type: Callable[Tuple[RequestRecord], None] + _callback: Callable[[RequestRecord], None] = attr.ib() def request(self, method, uri, headers=None, bodyProducer=None): # type: (bytes, bytes, Optional[Headers], Optional[IBodyProducer]) -> Deferred[IResponse] # noqa @@ -63,7 +63,7 @@ def request(self, method, uri, headers=None, bodyProducer=None): " Is the implementation marked with @implementer(IBodyProducer)?" 
).format(bodyProducer) ) - d = Deferred() + d: Deferred[IResponse] = Deferred() record = RequestRecord(method, uri, headers, bodyProducer, d) self._callback(record) return d @@ -87,6 +87,6 @@ def agent_spy(): - A list of calls made to the agent's :meth:`~twisted.web.iweb.IAgent.request()` method """ - records = [] + records: list[RequestRecord] = [] agent = _AgentSpy(records.append) return agent, records diff --git a/src/treq/_types.py b/src/treq/_types.py new file mode 100644 index 00000000..698147be --- /dev/null +++ b/src/treq/_types.py @@ -0,0 +1,102 @@ +import io +from http.cookiejar import CookieJar +from typing import Any, Iterable, Mapping, Union + +from hyperlink import DecodedURL, EncodedURL +from twisted.internet.interfaces import (IReactorPluggableNameResolver, + IReactorTCP, IReactorTime) +from twisted.web.http_headers import Headers +from twisted.web.iweb import IBodyProducer + + +class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): + """ + The kind of reactor treq needs for type-checking purposes. + + This is an approximation of the actual requirement, which comes from the + `twisted.internet.endpoints.HostnameEndpoint` used by the `Agent` + implementation: + + > Provider of IReactorTCP, IReactorTime and either + > IReactorPluggableNameResolver or IReactorPluggableResolver. + + We don't model the `IReactorPluggableResolver` option because it is + deprecated. + """ + + +_S = Union[bytes, str] + +_URLType = Union[ + str, + bytes, + EncodedURL, + DecodedURL, +] + +_ParamsType = Union[ + Mapping[str, Union[str, tuple[str, ...], list[str]]], + list[tuple[str, str]], +] + +_HeadersType = Union[ + Headers, + dict[_S, _S], + dict[_S, list[_S]], +] + +_CookiesType = Union[ + CookieJar, + Mapping[str, str], +] + +_WholeBody = Union[ + bytes, + io.BytesIO, + io.BufferedReader, + IBodyProducer, +] +""" +Types that define the entire HTTP request body, including those coercible to +`IBodyProducer`. 
+""" + +# Concrete types are used here because the handling of the *data* parameter +# does lots of isinstance checks. +_BodyFields = Union[ + dict[str, str], + list[tuple[str, str]], +] +""" +Types that will be URL- or multipart-encoded before being sent as part of the +HTTP request body. +""" + +_DataType = Union[_WholeBody, _BodyFields] +""" +Values accepted for the *data* parameter + +Note that this is a simplification. Only `_BodyFields` may be supplied if the +*files* parameter is passed. +""" + +_FileValue = Union[ + str, + bytes, + tuple[str, str, IBodyProducer], +] +""" +Either a scalar string, or a file to upload as (filename, content type, +IBodyProducer) +""" + +_FilesType = Union[ + Mapping[str, _FileValue], + Iterable[tuple[str, _FileValue]], +] +""" +Values accepted for the *files* parameter. +""" + +# Soon... 🤞 https://github.com/python/mypy/issues/731 +_JSONType = Any diff --git a/src/treq/client.py b/src/treq/client.py index 1b09fb0b..336c73aa 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -1,49 +1,48 @@ import io import mimetypes import uuid -import warnings -from collections.abc import Mapping -from http.cookiejar import CookieJar, Cookie -from urllib.parse import quote_plus, urlencode as _urlencode +from collections import abc +from http.cookiejar import Cookie, CookieJar +from json import dumps as json_dumps +from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Union +from urllib.parse import quote_plus +from urllib.parse import urlencode as _urlencode -from twisted.internet.interfaces import IProtocol +from hyperlink import DecodedURL, EncodedURL +from requests.cookies import merge_cookies from twisted.internet.defer import Deferred -from twisted.python.components import proxyForInterface +from twisted.internet.interfaces import IProtocol +from twisted.python.components import proxyForInterface, registerAdapter from twisted.python.filepath import FilePath -from hyperlink import DecodedURL, EncodedURL - +from 
twisted.web.client import (BrowserLikeRedirectAgent, ContentDecoderAgent, + CookieAgent, FileBodyProducer, GzipDecoder, + IAgent, RedirectAgent) from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse -from twisted.web.client import ( - FileBodyProducer, - RedirectAgent, - BrowserLikeRedirectAgent, - ContentDecoderAgent, - GzipDecoder, - CookieAgent -) - -from twisted.python.components import registerAdapter -from json import dumps as json_dumps - -from treq.auth import add_auth from treq import multipart +from treq._types import (_CookiesType, _DataType, _FilesType, _FileValue, + _HeadersType, _ITreqReactor, _JSONType, _ParamsType, + _URLType) +from treq.auth import add_auth from treq.response import _Response -from requests.cookies import merge_cookies -_NOTHING = object() +class _Nothing: + """Type of the sentinel `_NOTHING`""" + + +_NOTHING = _Nothing() -def urlencode(query, doseq): +def urlencode(query: _ParamsType, doseq: bool) -> bytes: s = _urlencode(query, doseq) - if not isinstance(s, bytes): - s = s.encode("ascii") - return s + return s.encode("ascii") -def _scoped_cookiejar_from_dict(url_object, cookie_dict): +def _scoped_cookiejar_from_dict( + url_object: EncodedURL, cookie_dict: Optional[Mapping[str, str]] +) -> CookieJar: """ Create a CookieJar from a dictionary whose cookies are all scoped to the given URL's origin. @@ -55,14 +54,14 @@ def _scoped_cookiejar_from_dict(url_object, cookie_dict): if cookie_dict is None: return cookie_jar for k, v in cookie_dict.items(): - secure = url_object.scheme == 'https' + secure = url_object.scheme == "https" port_specified = not ( (url_object.scheme == "https" and url_object.port == 443) or (url_object.scheme == "http" and url_object.port == 80) ) port = str(url_object.port) if port_specified else None domain = url_object.host - netscape_domain = domain if '.' in domain else domain + '.local' + netscape_domain = domain if "." 
in domain else domain + ".local" cookie_jar.set_cookie( Cookie( @@ -71,11 +70,9 @@ def _scoped_cookiejar_from_dict(url_object, cookie_dict): port=port, secure=secure, port_specified=port_specified, - # Contents name=k, value=v, - # Constant/always-the-same stuff version=0, path="/", @@ -87,28 +84,28 @@ def _scoped_cookiejar_from_dict(url_object, cookie_dict): path_specified=False, domain_specified=False, domain_initial_dot=False, - rest=[], + rest={}, ) ) return cookie_jar -class _BodyBufferingProtocol(proxyForInterface(IProtocol)): +class _BodyBufferingProtocol(proxyForInterface(IProtocol)): # type: ignore def __init__(self, original, buffer, finished): self.original = original self.buffer = buffer self.finished = finished - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self.buffer.append(data) self.original.dataReceived(data) - def connectionLost(self, reason): + def connectionLost(self, reason: Exception) -> None: self.original.connectionLost(reason) self.finished.errback(reason) -class _BufferedResponse(proxyForInterface(IResponse)): +class _BufferedResponse(proxyForInterface(IResponse)): # type: ignore def __init__(self, original): self.original = original self._buffer = [] @@ -130,11 +127,7 @@ def deliverBody(self, protocol): self._waiting = Deferred() self._waiting.addBoth(self._deliverWaiting) self.original.deliverBody( - _BodyBufferingProtocol( - protocol, - self._buffer, - self._waiting - ) + _BodyBufferingProtocol(protocol, self._buffer, self._waiting) ) elif self._finished: for segment in self._buffer: @@ -145,79 +138,89 @@ def deliverBody(self, protocol): class HTTPClient: - def __init__(self, agent, cookiejar=None, - data_to_body_producer=IBodyProducer): + def __init__( + self, + agent: IAgent, + cookiejar: Optional[CookieJar] = None, + data_to_body_producer: Callable[[Any], IBodyProducer] = IBodyProducer, + ) -> None: self._agent = agent if cookiejar is None: cookiejar = CookieJar() self._cookiejar = cookiejar 
self._data_to_body_producer = data_to_body_producer - def get(self, url, **kwargs): + def get(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: """ See :func:`treq.get()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('GET', url, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("GET", url, **kwargs) - def put(self, url, data=None, **kwargs): + def put( + self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any + ) -> Deferred[_Response]: """ See :func:`treq.put()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('PUT', url, data=data, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("PUT", url, data=data, **kwargs) - def patch(self, url, data=None, **kwargs): + def patch( + self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any + ) -> Deferred[_Response]: """ See :func:`treq.patch()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('PATCH', url, data=data, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("PATCH", url, data=data, **kwargs) - def post(self, url, data=None, **kwargs): + def post( + self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any + ) -> Deferred[_Response]: """ See :func:`treq.post()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('POST', url, data=data, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("POST", url, data=data, **kwargs) - def head(self, url, **kwargs): + def head(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: """ See :func:`treq.head()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('HEAD', url, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("HEAD", url, **kwargs) - def delete(self, url, **kwargs): + def delete(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: """ See :func:`treq.delete()`. 
""" - kwargs.setdefault('_stacklevel', 3) - return self.request('DELETE', url, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("DELETE", url, **kwargs) def request( self, - method, - url, + method: str, + url: _URLType, *, - params=None, - headers=None, - data=None, - files=None, - json=_NOTHING, - auth=None, - cookies=None, - allow_redirects=True, - browser_like_redirects=False, - unbuffered=False, - reactor=None, - timeout=None, - _stacklevel=2, - ): + params: Optional[_ParamsType] = None, + headers: Optional[_HeadersType] = None, + data: Optional[_DataType] = None, + files: Optional[_FilesType] = None, + json: Union[_JSONType, _Nothing] = _NOTHING, + auth: Optional[tuple[Union[str, bytes], Union[str, bytes]]] = None, + cookies: Optional[_CookiesType] = None, + allow_redirects: bool = True, + browser_like_redirects: bool = False, + unbuffered: bool = False, + reactor: Optional[_ITreqReactor] = None, + timeout: Optional[float] = None, + _stacklevel: int = 2, + ) -> Deferred[_Response]: """ See :func:`treq.request()`. """ - method = method.encode('ascii').upper() + method_: bytes = method.encode("ascii").upper() if isinstance(url, DecodedURL): parsed_url = url.encoded_url @@ -228,7 +231,7 @@ def request( # bytes in the path and querystring. parsed_url = EncodedURL.from_text(url) else: - parsed_url = EncodedURL.from_text(url.decode('ascii')) + parsed_url = EncodedURL.from_text(url.decode("ascii")) # Join parameters provided in the URL # and the ones passed as argument. 
@@ -237,20 +240,21 @@ def request( query=parsed_url.query + tuple(_coerced_query_params(params)) ) - url = parsed_url.to_uri().to_text().encode('ascii') + url = parsed_url.to_uri().to_text().encode("ascii") headers = self._request_headers(headers, _stacklevel + 1) - bodyProducer, contentType = self._request_body(data, files, json, - stacklevel=_stacklevel + 1) + bodyProducer, contentType = self._request_body( + data, files, json, stacklevel=_stacklevel + 1 + ) if contentType is not None: - headers.setRawHeaders(b'Content-Type', [contentType]) + headers.setRawHeaders(b"Content-Type", [contentType]) if not isinstance(cookies, CookieJar): cookies = _scoped_cookiejar_from_dict(parsed_url, cookies) cookies = merge_cookies(self._cookiejar, cookies) - wrapped_agent = CookieAgent(self._agent, cookies) + wrapped_agent: IAgent = CookieAgent(self._agent, cookies) if allow_redirects: if browser_like_redirects: @@ -258,18 +262,19 @@ def request( else: wrapped_agent = RedirectAgent(wrapped_agent) - wrapped_agent = ContentDecoderAgent(wrapped_agent, - [(b'gzip', GzipDecoder)]) + wrapped_agent = ContentDecoderAgent(wrapped_agent, [(b"gzip", GzipDecoder)]) if auth: wrapped_agent = add_auth(wrapped_agent, auth) d = wrapped_agent.request( - method, url, headers=headers, - bodyProducer=bodyProducer) + method_, url, headers=headers, bodyProducer=bodyProducer + ) if reactor is None: - from twisted.internet import reactor + from twisted.internet import reactor # type: ignore + assert reactor is not None + if timeout: delayedCall = reactor.callLater(timeout, d.cancel) @@ -285,12 +290,11 @@ def gotResult(result): return d.addCallback(_Response, cookies) - def _request_headers(self, headers, stacklevel): + def _request_headers( + self, headers: Optional[_HeadersType], stacklevel: int + ) -> Headers: """ Convert the *headers* argument to a :class:`Headers` instance - - :returns: - :class:`twisted.web.http_headers.Headers` """ if isinstance(headers, dict): h = Headers({}) @@ -300,14 +304,10 
@@ def _request_headers(self, headers, stacklevel): elif isinstance(v, list): h.setRawHeaders(k, v) else: - warnings.warn( - ( - "The value of headers key {!r} has non-string type {}" - " and will be dropped." - " This will raise TypeError in the next treq release." - ).format(k, type(v)), - DeprecationWarning, - stacklevel=stacklevel, + raise TypeError( + "The value of headers key {!r} has non-string type {}.".format( + k, type(v) + ) ) return h if isinstance(headers, Headers): @@ -315,18 +315,20 @@ def _request_headers(self, headers, stacklevel): if headers is None: return Headers({}) - warnings.warn( + raise TypeError( ( "headers must be a dict, twisted.web.http_headers.Headers, or None," - " but found {}, which will be ignored." - " This will raise TypeError in the next treq release." - ).format(type(headers)), - DeprecationWarning, - stacklevel=stacklevel, + " but found {}." + ).format(type(headers)) ) - return Headers({}) - def _request_body(self, data, files, json, stacklevel): + def _request_body( + self, + data: Optional[_DataType], + files: Optional[_FilesType], + json: Union[_JSONType, _Nothing], + stacklevel: int, + ) -> tuple[Optional[IBodyProducer], Optional[bytes]]: """ Here we choose a right producer based on the parameters passed in. @@ -354,38 +356,47 @@ def _request_body(self, data, files, json, stacklevel): JSON-encodable data, or the sentinel `_NOTHING`. The sentinel is necessary because ``None`` is a valid JSON value. """ - if json is not _NOTHING and (files or data): - warnings.warn( - ( - "Argument 'json' will be ignored because '{}' was also passed." - " This will raise TypeError in the next treq release." 
- ).format("data" if data else "files"), - DeprecationWarning, - stacklevel=stacklevel, + if json is not _NOTHING: + if files or data: + raise TypeError( + "Argument 'json' cannot be combined with '{}'.".format( + "data" if data else "files" + ) + ) + return ( + self._data_to_body_producer( + json_dumps(json, separators=(",", ":")).encode("utf-8"), + ), + b"application/json; charset=UTF-8", ) if files: # If the files keyword is present we will issue a # multipart/form-data request as it suits better for cases # with files and/or large objects. - files = list(_convert_files(files)) - boundary = str(uuid.uuid4()).encode('ascii') + fields: list[tuple[str, _FileValue]] = [] if data: - data = _convert_params(data) - else: - data = [] + for field in _convert_params(data): + fields.append(field) + for field in _convert_files(files): + fields.append(field) + boundary = str(uuid.uuid4()).encode("ascii") return ( - multipart.MultiPartProducer(data + files, boundary=boundary), - b'multipart/form-data; boundary=' + boundary, + multipart.MultiPartProducer(fields, boundary=boundary), + b"multipart/form-data; boundary=" + boundary, ) # Otherwise stick to x-www-form-urlencoded format # as it's generally faster for smaller requests. if isinstance(data, (dict, list, tuple)): return ( + # FIXME: The use of doseq here is not permitted in the types, and + # sequence values aren't supported in the files codepath. It is + # maintained here for backwards compatibility. See + # https://github.com/twisted/treq/issues/360. 
self._data_to_body_producer(urlencode(data, doseq=True)), - b'application/x-www-form-urlencoded', + b"application/x-www-form-urlencoded", ) elif data: return ( @@ -393,22 +404,13 @@ def _request_body(self, data, files, json, stacklevel): None, ) - if json is not _NOTHING: - return ( - self._data_to_body_producer( - json_dumps(json, separators=(u',', u':')).encode('utf-8'), - ), - b'application/json; charset=UTF-8', - ) - return None, None -def _convert_params(params): - if hasattr(params, "iteritems"): - return list(sorted(params.iteritems())) - elif hasattr(params, "items"): - return list(sorted(params.items())) +def _convert_params(params: _DataType) -> Iterable[tuple[str, str]]: + items_method = getattr(params, "items", None) + if items_method: + return list(sorted(items_method())) elif isinstance(params, (tuple, list)): return list(params) else: @@ -466,8 +468,7 @@ def _convert_files(files): yield (param, (file_name, content_type, IBodyProducer(fobj))) -def _query_quote(v): - # (Any) -> Text +def _query_quote(v: Any) -> str: """ Percent-encode a querystring name or value. @@ -485,7 +486,7 @@ def _query_quote(v): return q -def _coerced_query_params(params): +def _coerced_query_params(params: _ParamsType) -> Iterator[tuple[str, str]]: """ Carefully coerce *params* in the same way as `urllib.parse.urlencode()` @@ -501,10 +502,9 @@ def _coerced_query_params(params): :returns: A generator that yields two-tuples containing percent-encoded text strings. 
- :rtype: - Iterator[Tuple[Text, Text]] """ - if isinstance(params, Mapping): + items: Iterable[tuple[str, Union[str, tuple[str, ...], list[str]]]] + if isinstance(params, abc.Mapping): items = params.items() else: items = params @@ -518,20 +518,20 @@ def _coerced_query_params(params): yield key_quoted, _query_quote(value) -def _from_bytes(orig_bytes): +def _from_bytes(orig_bytes: bytes) -> IBodyProducer: return FileBodyProducer(io.BytesIO(orig_bytes)) -def _from_file(orig_file): +def _from_file(orig_file: Union[io.BytesIO, io.BufferedReader]) -> IBodyProducer: return FileBodyProducer(orig_file) -def _guess_content_type(filename): +def _guess_content_type(filename: str) -> Optional[str]: if filename: guessed = mimetypes.guess_type(filename)[0] else: guessed = None - return guessed or 'application/octet-stream' + return guessed or "application/octet-stream" registerAdapter(_from_bytes, bytes, IBodyProducer) diff --git a/src/treq/multipart.py b/src/treq/multipart.py index 5309a95c..40fd74e1 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -1,18 +1,31 @@ # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. 
-from uuid import uuid4 -from io import BytesIO from contextlib import closing +from io import BytesIO +from typing import Any, Iterable, Literal, Mapping, Optional, Union, cast +from uuid import uuid4 -from twisted.internet import defer, task +from twisted.internet import task +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.python.failure import Failure from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer - +from typing_extensions import TypeAlias from zope.interface import implementer +from treq._types import _S, _FilesType, _FileValue + CRLF = b"\r\n" +_Consumer: TypeAlias = "Union[IConsumer, _LengthConsumer]" +_UnknownLength = Literal["'twisted.web.iweb.UNKNOWN_LENGTH'"] +_Length: TypeAlias = Union[int, _UnknownLength] +_FieldValue = Union[bytes, tuple[str, str, IBodyProducer]] +_Field: TypeAlias = tuple[str, _FieldValue] + + @implementer(IBodyProducer) class MultiPartProducer: """ @@ -46,22 +59,31 @@ class MultiPartProducer: schedule all reads. 
:ivar boundary: The generated boundary used in form-data encoding - :type boundary: `bytes` """ - def __init__(self, fields, boundary=None, cooperator=task): - self._fields = list(_sorted_by_type(_converted(fields))) - self._currentProducer = None + length: _Length + boundary: bytes + _currentProducer: Optional[IBodyProducer] = None + _task: Optional[task.CooperativeTask] = None + + def __init__( + self, + fields: _FilesType, + boundary: Optional[Union[str, bytes]] = None, + cooperator: task.Cooperator = cast(task.Cooperator, task), + ) -> None: + self._fields = _sorted_by_type(_converted(fields)) self._cooperate = cooperator.cooperate - self.boundary = boundary or uuid4().hex - - if isinstance(self.boundary, str): - self.boundary = self.boundary.encode('ascii') + if not boundary: + boundary = uuid4().hex.encode("ascii") + if isinstance(boundary, str): + boundary = boundary.encode("ascii") + self.boundary = boundary self.length = self._calculateLength() - def startProducing(self, consumer): + def startProducing(self, consumer: IConsumer) -> Deferred[None]: """ Start a cooperative task which will read bytes from the input file and write them to `consumer`. Return a `Deferred` which fires after all @@ -69,29 +91,34 @@ def startProducing(self, consumer): :param consumer: Any `IConsumer` provider """ - self._task = self._cooperate(self._writeLoop(consumer)) - d = self._task.whenDone() + self._task = self._cooperate(self._writeLoop(consumer)) # type: ignore + # whenDone returns the iterator that was passed to cooperate, so who + # cares what type it has? It's an edge signal; we ignore its value. 
+ d: Deferred[Any] = self._task.whenDone() - def maybeStopped(reason): + def maybeStopped(reason: Failure) -> Deferred: reason.trap(task.TaskStopped) - return defer.Deferred() - d.addCallbacks(lambda ignored: None, maybeStopped) + return Deferred() + + d = cast(Deferred[None], d.addCallbacks(lambda ignored: None, maybeStopped)) return d - def stopProducing(self): + def stopProducing(self) -> None: """ Permanently stop writing bytes from the file to the consumer by stopping the underlying `CooperativeTask`. """ + assert self._task is not None if self._currentProducer: self._currentProducer.stopProducing() self._task.stop() - def pauseProducing(self): + def pauseProducing(self) -> None: """ Temporarily suspend copying bytes from the input file to the consumer by pausing the `CooperativeTask` which drives that activity. """ + assert self._task is not None if self._currentProducer: # Having a current producer means that we are in # the paused state because we've returned @@ -103,18 +130,19 @@ def pauseProducing(self): else: self._task.pause() - def resumeProducing(self): + def resumeProducing(self) -> None: """ Undo the effects of a previous `pauseProducing` and resume copying bytes to the consumer by resuming the `CooperativeTask` which drives the write activity. """ + assert self._task is not None if self._currentProducer: self._currentProducer.resumeProducing() else: self._task.resume() - def _calculateLength(self): + def _calculateLength(self) -> _Length: """ Determine how many bytes the overall form post would consume. 
The easiest way is to calculate is to generate of `fObj` @@ -126,7 +154,7 @@ def _calculateLength(self): pass return consumer.length - def _getBoundary(self, final=False): + def _getBoundary(self, final: bool = False) -> bytes: """ Returns a boundary line, either final (the one that ends the form data request or a regular, the one that separates the boundaries) @@ -136,7 +164,7 @@ def _getBoundary(self, final=False): f = b"--" if final else b"" return b"--" + self.boundary + f - def _writeLoop(self, consumer): + def _writeLoop(self, consumer: _Consumer) -> Iterable[Optional[Deferred]]: """ Return an iterator which generates the multipart/form-data request including the encoded objects @@ -158,30 +186,36 @@ def _writeLoop(self, consumer): # but with CRLF characters before it and after the line. # This is very important. # proper boundary is "CRLF--boundary-valueCRLF" - consumer.write( - (CRLF if index != 0 else b"") + self._getBoundary() + CRLF) + consumer.write((CRLF if index != 0 else b"") + self._getBoundary() + CRLF) yield self._writeField(name, value, consumer) consumer.write(CRLF + self._getBoundary(final=True) + CRLF) - def _writeField(self, name, value, consumer): - if isinstance(value, str): + def _writeField( + self, name: str, value: _FieldValue, consumer: _Consumer + ) -> Optional[Deferred]: + if isinstance(value, bytes): self._writeString(name, value, consumer) - elif isinstance(value, tuple): + return None + else: filename, content_type, producer = value - return self._writeFile( - name, filename, content_type, producer, consumer) + return self._writeFile(name, filename, content_type, producer, consumer) - def _writeString(self, name, value, consumer): + def _writeString(self, name: str, value: bytes, consumer: _Consumer) -> None: cdisp = _Header(b"Content-Disposition", b"form-data") cdisp.add_param(b"name", name) consumer.write(bytes(cdisp) + CRLF + CRLF) - - encoded = value.encode("utf-8") - consumer.write(encoded) + consumer.write(value) 
self._currentProducer = None - def _writeFile(self, name, filename, content_type, producer, consumer): + def _writeFile( + self, + name: str, + filename: str, + content_type: str, + producer: IBodyProducer, + consumer: _Consumer, + ) -> Optional[Deferred[None]]: cdisp = _Header(b"Content-Disposition", b"form-data") cdisp.add_param(b"name", name) if filename: @@ -190,12 +224,12 @@ def _writeFile(self, name, filename, content_type, producer, consumer): consumer.write(bytes(cdisp) + CRLF) consumer.write(bytes(_Header(b"Content-Type", content_type)) + CRLF) if producer.length != UNKNOWN_LENGTH: - consumer.write( - bytes(_Header(b"Content-Length", producer.length)) + CRLF) + consumer.write(bytes(_Header(b"Content-Length", producer.length)) + CRLF) consumer.write(CRLF) if isinstance(consumer, _LengthConsumer): consumer.write(producer.length) + return None else: self._currentProducer = producer @@ -204,24 +238,21 @@ def unset(val): return val d = producer.startProducing(consumer) - d.addCallback(unset) - return d + return cast(Deferred[None], d.addCallback(unset)) -def _escape(value): +def _escape(value: Union[str, bytes]) -> str: """ This function prevents header values from corrupting the request, a newline in the file name parameter makes form-data request unreadable for majority of parsers. """ - if not isinstance(value, (bytes, str)): - value = str(value) if isinstance(value, bytes): - value = value.decode('utf-8') - return value.replace(u"\r", u"").replace(u"\n", u"").replace(u'"', u'\\"') + value = value.decode("utf-8") + return value.replace("\r", "").replace("\n", "").replace('"', '\\"') -def _enforce_unicode(value): +def _enforce_unicode(value: Any) -> str: """ This function enforces the strings passed to be unicode, so we won't need to guess what's the encoding of the binary strings passed in. 
@@ -238,38 +269,46 @@ def _enforce_unicode(value): return value.decode("utf-8") except UnicodeDecodeError: raise ValueError( - "Supplied raw bytes that are not ascii/utf-8." - " When supplying raw string make sure it's ascii or utf-8" - ", or work with unicode if you are not sure") + "Supplied raw bytes that are not ASCII/UTF-8." + " When supplying raw string make sure it's ASCII or UTF-8" + ", or work with unicode if you are not sure" + ) else: - raise ValueError( - "Unsupported field type: %s" % (value.__class__.__name__,)) + raise ValueError("Unsupported field type: %s" % (value.__class__.__name__,)) -def _converted(fields): - if hasattr(fields, "iteritems"): - fields = fields.iteritems() - elif hasattr(fields, "items"): - fields = fields.items() +def _converted(fields: _FilesType) -> Iterable[_Field]: + """ + Convert + """ + fields_: Iterable[tuple[str, _FileValue]] + if hasattr(fields, "items"): + assert isinstance(fields, Mapping) + fields_ = fields.items() + else: + fields_ = cast(Iterable[tuple[str, _FileValue]], fields) - for name, value in fields: + for name, value in fields_: name = _enforce_unicode(name) if isinstance(value, (tuple, list)): if len(value) != 3: - raise ValueError( - "Expected tuple: (filename, content type, producer)") + raise ValueError("Expected tuple: (filename, content type, producer)") filename, content_type, producer = value filename = _enforce_unicode(filename) if filename else None yield name, (filename, content_type, producer) - elif isinstance(value, (bytes, str)): - yield name, _enforce_unicode(value) + elif isinstance(value, str): + yield name, value.encode("utf-8") + + elif isinstance(value, bytes): + yield name, value else: raise ValueError( - "Unsupported value, expected string, unicode " - "or tuple (filename, content type, IBodyProducer)") + "Unsupported value, expected str, bytes, " + "or tuple (filename, content type, IBodyProducer)" + ) class _LengthConsumer: @@ -284,21 +323,23 @@ class _LengthConsumer: """ - def 
__init__(self): + length: _Length + + def __init__(self) -> None: self.length = 0 - def write(self, value): + def write(self, value: bytes) -> None: # this means that we have encountered # unknown length producer # so we need to stop attempts calculating - if self.length is UNKNOWN_LENGTH: + if self.length == UNKNOWN_LENGTH: return + assert isinstance(self.length, int) - if value is UNKNOWN_LENGTH: - self.length = value - elif isinstance(value, int): - self.length += value + if value == UNKNOWN_LENGTH: + self.length = cast(_UnknownLength, UNKNOWN_LENGTH) else: + assert isinstance(value, bytes) self.length += len(value) @@ -311,15 +352,21 @@ class because it encodes unicode fields using =? bla bla ?= that, everyone wants utf-8 raw bytes. """ - def __init__(self, name, value, params=None): + + def __init__( + self, + name: bytes, + value: _S, + params: Optional[list[tuple[_S, _S]]] = None, + ): self.name = name self.value = value self.params = params or [] - def add_param(self, name, value): + def add_param(self, name: _S, value: _S) -> None: self.params.append((name, value)) - def __bytes__(self): + def __bytes__(self) -> bytes: with closing(BytesIO()) as h: h.write(self.name + b": " + _escape(self.value).encode("us-ascii")) if self.params: @@ -327,24 +374,23 @@ def __bytes__(self): h.write(b"; ") h.write(_escape(name).encode("us-ascii")) h.write(b"=") - h.write(b'"' + _escape(val).encode('utf-8') + b'"') + h.write(b'"' + _escape(val).encode("utf-8") + b'"') h.seek(0) return h.read() - def __str__(self): - return self.__bytes__() - -def _sorted_by_type(fields): +def _sorted_by_type(fields: Iterable[_Field]) -> list[_Field]: """Sorts params so that strings are placed before files. That makes a request more readable, as generally files are bigger. It also provides deterministic order of fields what is easier for testing. 
""" + def key(p): key, val = p if isinstance(val, (bytes, str)): return (0, key) else: return (1, key) + return sorted(fields, key=key) diff --git a/src/treq/response.py b/src/treq/response.py index d13c3edb..8b87b326 100644 --- a/src/treq/response.py +++ b/src/treq/response.py @@ -1,13 +1,12 @@ -from twisted.python.components import proxyForInterface -from twisted.web.iweb import IResponse, UNKNOWN_LENGTH -from twisted.python import reflect - from requests.cookies import cookiejar_from_dict +from twisted.python import reflect +from twisted.python.components import proxyForInterface +from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from treq.content import collect, content, json_content, text_content -class _Response(proxyForInterface(IResponse)): +class _Response(proxyForInterface(IResponse)): # type: ignore """ A wrapper for :class:`twisted.web.iweb.IResponse` which manages cookies and adds a few convenience methods. @@ -23,14 +22,15 @@ def __repr__(self): status code, Content-Type header, and body size, if available. """ if self.original.length == UNKNOWN_LENGTH: - size = 'unknown size' + size = "unknown size" else: - size = '{:,d} bytes'.format(self.original.length) + size = "{:,d} bytes".format(self.original.length) # Display non-ascii bits of the content-type header as backslash # escapes. - content_type_bytes = b', '.join( - self.original.headers.getRawHeaders(b'content-type', ())) - content_type = repr(content_type_bytes).lstrip('b')[1:-1] + content_type_bytes = b", ".join( + self.original.headers.getRawHeaders(b"content-type", ()) + ) + content_type = repr(content_type_bytes).lstrip("b")[1:-1] return "<{} {} '{:.40s}' {}>".format( reflect.qual(self.__class__), self.original.code, @@ -71,7 +71,7 @@ def json(self, **kwargs): """ return json_content(self.original, **kwargs) - def text(self, encoding='ISO-8859-1'): + def text(self, encoding="ISO-8859-1"): """ Read the entire body all at once as text, per :func:`treq.text_content()`. 
@@ -93,8 +93,7 @@ def history(self): history = [] while response.previousResponse is not None: - history.append(_Response(response.previousResponse, - self._cookiejar)) + history.append(_Response(response.previousResponse, self._cookiejar)) response = response.previousResponse history.reverse() diff --git a/src/treq/test/local_httpbin/child.py b/src/treq/test/local_httpbin/child.py index 4a6914e9..f6cb2216 100644 --- a/src/treq/test/local_httpbin/child.py +++ b/src/treq/test/local_httpbin/child.py @@ -8,7 +8,7 @@ import datetime import sys -import httpbin +import httpbin # type: ignore from twisted.internet.defer import Deferred, inlineCallbacks from twisted.internet.endpoints import TCP4ServerEndpoint, SSL4ServerEndpoint @@ -16,7 +16,7 @@ from twisted.internet.ssl import (Certificate, CertificateOptions) -from OpenSSL.crypto import PKey, X509 +from OpenSSL.crypto import PKey, X509 # type: ignore from twisted.python.threadpool import ThreadPool from twisted.web.server import Site diff --git a/src/treq/test/local_httpbin/test/test_child.py b/src/treq/test/local_httpbin/test/test_child.py index 2a08d4e3..406d8dc0 100644 --- a/src/treq/test/local_httpbin/test/test_child.py +++ b/src/treq/test/local_httpbin/test/test_child.py @@ -23,7 +23,7 @@ from twisted.web.server import Site from twisted.web.resource import Resource -from service_identity.cryptography import verify_certificate_hostname +from service_identity.cryptography import verify_certificate_hostname # type: ignore from .. import child, shared diff --git a/src/treq/test/test_agentspy.py b/src/treq/test/test_agentspy.py index 6103f6ae..4450dd89 100644 --- a/src/treq/test/test_agentspy.py +++ b/src/treq/test/test_agentspy.py @@ -15,7 +15,7 @@ class APISpyTests(SynchronousTestCase): The agent_spy API provides an agent that records each request made to it. """ - def test_provides_iagent(self): + def test_provides_iagent(self) -> None: """ The agent returned by agent_spy() provides the IAgent interface. 
""" @@ -23,7 +23,7 @@ def test_provides_iagent(self): self.assertTrue(IAgent.providedBy(agent)) - def test_records(self): + def test_records(self) -> None: """ Each request made with the agent is recorded. """ @@ -43,7 +43,7 @@ def test_records(self): ], ) - def test_record_attributes(self): + def test_record_attributes(self) -> None: """ Each parameter passed to `request` is available as an attribute of the RequestRecord. Additionally, the deferred returned by the call is @@ -62,15 +62,15 @@ def test_record_attributes(self): self.assertIs(rr.bodyProducer, body) self.assertIs(rr.deferred, deferred) - def test_type_validation(self): + def test_type_validation(self) -> None: """ The request method enforces correctness by raising TypeError when passed parameters of the wrong type. """ agent, _ = agent_spy() - self.assertRaises(TypeError, agent.request, u"method not bytes", b"uri") - self.assertRaises(TypeError, agent.request, b"method", u"uri not bytes") + self.assertRaises(TypeError, agent.request, "method not bytes", b"uri") + self.assertRaises(TypeError, agent.request, b"method", "uri not bytes") self.assertRaises( TypeError, agent.request, b"method", b"uri", {"not": "headers"} ) diff --git a/src/treq/test/test_api.py b/src/treq/test/test_api.py index 2d9a1a17..bcb6d803 100644 --- a/src/treq/test/test_api.py +++ b/src/treq/test/test_api.py @@ -1,13 +1,14 @@ from __future__ import absolute_import, division -from twisted.web.iweb import IAgent -from twisted.web.client import HTTPConnectionPool -from twisted.trial.unittest import TestCase from twisted.internet import defer +from twisted.trial.unittest import TestCase +from twisted.web.client import HTTPConnectionPool +from twisted.web.iweb import IAgent from zope.interface import implementer import treq -from treq.api import default_reactor, default_pool, set_global_pool, get_global_pool +from treq.api import (default_pool, default_reactor, get_global_pool, + set_global_pool) try: from twisted.internet.testing import 
MemoryReactorClock @@ -32,7 +33,7 @@ def getConnection(self, key, endpoint): class TreqAPITests(TestCase): - def test_default_pool(self): + def test_default_pool(self) -> None: """ The module-level API uses the global connection pool by default. """ @@ -44,7 +45,7 @@ def test_default_pool(self): self.assertEqual(pool.requests, 1) self.failureResultOf(d, TabError) - def test_cached_pool(self): + def test_cached_pool(self) -> None: """ The first use of the module-level API populates the global connection pool, which is used for all subsequent requests. @@ -61,7 +62,7 @@ def test_cached_pool(self): self.assertEqual(pool.requests, 6) - def test_custom_pool(self): + def test_custom_pool(self) -> None: """ `treq.post()` accepts a *pool* argument to use for the request. The global pool is unaffected. @@ -74,7 +75,7 @@ def test_custom_pool(self): self.failureResultOf(d, TabError) self.assertIsNot(pool, get_global_pool()) - def test_custom_agent(self): + def test_custom_agent(self) -> None: """ A custom Agent is used if specified. """ @@ -93,7 +94,7 @@ def request(self, method, uri, headers=None, bodyProducer=None): self.assertNoResult(d) self.assertEqual(1, custom_agent.requests) - def test_request_invalid_param(self): + def test_request_invalid_param(self) -> None: """ `treq.request()` warns that it ignores unknown keyword arguments, but this is deprecated. @@ -111,7 +112,7 @@ def test_request_invalid_param(self): self.assertIn("invalid", str(c.exception)) - def test_post_json_with_data(self): + def test_post_json_with_data(self) -> None: """ `treq.post()` warns that mixing *data* and *json* is deprecated. @@ -143,7 +144,7 @@ class DefaultReactorTests(TestCase): Test `treq.api.default_reactor()` """ - def test_passes_reactor(self): + def test_passes_reactor(self) -> None: """ `default_reactor()` returns any reactor passed. 
""" @@ -151,7 +152,7 @@ def test_passes_reactor(self): self.assertIs(default_reactor(reactor), reactor) - def test_uses_default_reactor(self): + def test_uses_default_reactor(self) -> None: """ `default_reactor()` returns the global reactor when passed ``None``. """ @@ -165,11 +166,11 @@ class DefaultPoolTests(TestCase): Test `treq.api.default_pool`. """ - def setUp(self): + def setUp(self) -> None: set_global_pool(None) self.reactor = MemoryReactorClock() - def test_persistent_false(self): + def test_persistent_false(self) -> None: """ When *persistent=False* is passed a non-persistent pool is created. """ @@ -178,7 +179,7 @@ def test_persistent_false(self): self.assertTrue(isinstance(pool, HTTPConnectionPool)) self.assertFalse(pool.persistent) - def test_persistent_false_not_stored(self): + def test_persistent_false_not_stored(self) -> None: """ When *persistent=False* is passed the resulting pool is not stored as the global pool. @@ -187,7 +188,7 @@ def test_persistent_false_not_stored(self): self.assertIsNot(pool, get_global_pool()) - def test_persistent_false_new(self): + def test_persistent_false_new(self) -> None: """ When *persistent=False* is passed a new pool is returned each time. """ @@ -196,7 +197,7 @@ def test_persistent_false_new(self): self.assertIsNot(pool1, pool2) - def test_pool_none_persistent_none(self): + def test_pool_none_persistent_none(self) -> None: """ When *persistent=None* is passed a _persistent_ pool is created for backwards compatibility. @@ -205,7 +206,7 @@ def test_pool_none_persistent_none(self): self.assertTrue(pool.persistent) - def test_pool_none_persistent_true(self): + def test_pool_none_persistent_true(self) -> None: """ When *persistent=True* is passed a persistent pool is created and stored as the global pool. 
@@ -215,7 +216,7 @@ def test_pool_none_persistent_true(self): self.assertTrue(isinstance(pool, HTTPConnectionPool)) self.assertTrue(pool.persistent) - def test_cached_global_pool(self): + def test_cached_global_pool(self) -> None: """ When *persistent=True* or *persistent=None* is passed the pool created is cached as the global pool. @@ -225,7 +226,7 @@ def test_cached_global_pool(self): self.assertEqual(pool1, pool2) - def test_specified_pool(self): + def test_specified_pool(self) -> None: """ When the user passes a pool it is returned directly. The *persistent* argument is ignored. It is not cached as the global pool. diff --git a/src/treq/testing.py b/src/treq/testing.py index 33d8c94d..627b29b6 100644 --- a/src/treq/testing.py +++ b/src/treq/testing.py @@ -206,6 +206,15 @@ def startProducing(self, consumer): consumer.write(self.body) return succeed(None) + def stopProducing(self): + raise NotImplementedError() + + def pauseProducing(self): + raise NotImplementedError() + + def resumeProducing(self): + raise NotImplementedError() + def _reject_files(f): """ From 2277b903de568ff43c812a260639a99c48027c4a Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 20:46:28 -0700 Subject: [PATCH 05/21] Really clean --- src/treq/_agentspy.py | 22 +++++++++++++--------- src/treq/auth.py | 5 +++-- src/treq/test/test_api.py | 25 +++++++------------------ 3 files changed, 23 insertions(+), 29 deletions(-) diff --git a/src/treq/_agentspy.py b/src/treq/_agentspy.py index 5e939d1c..9ae2aa5f 100644 --- a/src/treq/_agentspy.py +++ b/src/treq/_agentspy.py @@ -21,11 +21,11 @@ class RequestRecord: :ivar deferred: The :class:`Deferred` returned by :meth:`IAgent.request` """ - method = attr.ib() # type: bytes - uri = attr.ib() # type: bytes - headers = attr.ib() # type: Optional[Headers] - bodyProducer = attr.ib() # type: Optional[IBodyProducer] - deferred = attr.ib() # type: Deferred[IResponse] + method: bytes = attr.field() + uri: bytes = attr.field() + headers: 
Optional[Headers] = attr.field() + bodyProducer: Optional[IBodyProducer] = attr.field() + deferred: Deferred[IResponse] = attr.field() @implementer(IAgent) @@ -40,8 +40,13 @@ class _AgentSpy: _callback: Callable[[RequestRecord], None] = attr.ib() - def request(self, method, uri, headers=None, bodyProducer=None): - # type: (bytes, bytes, Optional[Headers], Optional[IBodyProducer]) -> Deferred[IResponse] # noqa + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> Deferred[IResponse]: if not isinstance(method, bytes): raise TypeError( "method must be bytes, not {!r} of type {}".format(method, type(method)) @@ -69,8 +74,7 @@ def request(self, method, uri, headers=None, bodyProducer=None): return d -def agent_spy(): - # type: () -> Tuple[IAgent, List[RequestRecord]] +def agent_spy() -> Tuple[IAgent, List[RequestRecord]]: """ Record HTTP requests made with an agent diff --git a/src/treq/auth.py b/src/treq/auth.py index 3a778cea..dc8bb15e 100644 --- a/src/treq/auth.py +++ b/src/treq/auth.py @@ -46,8 +46,9 @@ def request(self, method, uri, headers=None, bodyProducer=None): method, uri, headers=requestHeaders, bodyProducer=bodyProducer) -def add_basic_auth(agent, username, password): - # type: (IAgent, Union[str, bytes], Union[str, bytes]) -> IAgent +def add_basic_auth( + agent: IAgent, username: Union[str, bytes], password: Union[str, bytes] +) -> IAgent: """ Wrap an agent to add HTTP basic authentication diff --git a/src/treq/test/test_api.py b/src/treq/test/test_api.py index bcb6d803..e11531cf 100644 --- a/src/treq/test/test_api.py +++ b/src/treq/test/test_api.py @@ -96,11 +96,8 @@ def request(self, method, uri, headers=None, bodyProducer=None): def test_request_invalid_param(self) -> None: """ - `treq.request()` warns that it ignores unknown keyword arguments, but - this is deprecated. - - This test verifies that stacklevel is set appropriately when issuing - the warning. 
+ `treq.request()` raises `TypeError` when it receives unknown keyword + arguments. """ with self.assertRaises(TypeError) as c: treq.request( @@ -114,28 +111,20 @@ def test_request_invalid_param(self) -> None: def test_post_json_with_data(self) -> None: """ - `treq.post()` warns that mixing *data* and *json* is deprecated. - - This test verifies that stacklevel is set appropriately when issuing - the warning. + `treq.post()` raises TypeError when the *data* and *json* arguments + are mixed. """ - self.failureResultOf( + with self.assertRaises(TypeError) as c: treq.post( "https://test.example/", data={"hello": "world"}, json={"goodnight": "moon"}, pool=SyntacticAbominationHTTPConnectionPool(), ) - ) - [w] = self.flushWarnings([self.test_post_json_with_data]) - self.assertEqual(DeprecationWarning, w["category"]) self.assertEqual( - ( - "Argument 'json' will be ignored because 'data' was also passed." - " This will raise TypeError in the next treq release." - ), - w["message"], + "Argument 'json' cannot be combined with 'data'.", + str(c.exception), ) From 1288960612a0bc8434749984c0bb9bb6b63561a9 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 20:50:00 -0700 Subject: [PATCH 06/21] Update to MyPy 1.0.1 --- src/treq/multipart.py | 2 +- tox.ini | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/treq/multipart.py b/src/treq/multipart.py index 40fd74e1..08f40cf7 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -286,7 +286,7 @@ def _converted(fields: _FilesType) -> Iterable[_Field]: assert isinstance(fields, Mapping) fields_ = fields.items() else: - fields_ = cast(Iterable[tuple[str, _FileValue]], fields) + fields_ = fields for name, value in fields_: name = _enforce_unicode(name) diff --git a/tox.ini b/tox.ini index 5195cf0f..7da8d8f6 100644 --- a/tox.ini +++ b/tox.ini @@ -27,8 +27,8 @@ commands = [testenv:mypy] deps = - mypy==0.981 - mypy-zope==0.3.11 + mypy==1.0.1 + mypy-zope==0.9.1 types-requests commands = mypy \ 
From c48ea18fb7207e2d3b14d29b87f0a53d78c6934f Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 21:37:09 -0700 Subject: [PATCH 07/21] Type some multipart tests --- src/treq/multipart.py | 13 +++++-- src/treq/test/test_multipart.py | 66 ++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 32 deletions(-) diff --git a/src/treq/multipart.py b/src/treq/multipart.py index 08f40cf7..c4b881c8 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -224,7 +224,9 @@ def _writeFile( consumer.write(bytes(cdisp) + CRLF) consumer.write(bytes(_Header(b"Content-Type", content_type)) + CRLF) if producer.length != UNKNOWN_LENGTH: - consumer.write(bytes(_Header(b"Content-Length", producer.length)) + CRLF) + consumer.write( + bytes(_Header(b"Content-Length", str(producer.length))) + CRLF + ) consumer.write(CRLF) if isinstance(consumer, _LengthConsumer): @@ -279,7 +281,8 @@ def _enforce_unicode(value: Any) -> str: def _converted(fields: _FilesType) -> Iterable[_Field]: """ - Convert + Convert any of the multitude of formats we accept for the *fields* + parameter into the form we work with internally. """ fields_: Iterable[tuple[str, _FileValue]] if hasattr(fields, "items"): @@ -289,6 +292,8 @@ def _converted(fields: _FilesType) -> Iterable[_Field]: fields_ = fields for name, value in fields_: + # NOTE: While `name` is typed as `str` we still support UTF-8 `bytes` here + # for backward compatibility, thus this call to decode. 
name = _enforce_unicode(name) if isinstance(value, (tuple, list)): @@ -328,7 +333,7 @@ class _LengthConsumer: def __init__(self) -> None: self.length = 0 - def write(self, value: bytes) -> None: + def write(self, value: Union[bytes, _Length]) -> None: # this means that we have encountered # unknown length producer # so we need to stop attempts calculating @@ -338,6 +343,8 @@ def write(self, value: bytes) -> None: if value == UNKNOWN_LENGTH: self.length = cast(_UnknownLength, UNKNOWN_LENGTH) + elif isinstance(value, int): + self.length += value else: assert isinstance(value, bytes) self.length += len(value) diff --git a/src/treq/test/test_multipart.py b/src/treq/test/test_multipart.py index 7736fbd9..b703d0dc 100644 --- a/src/treq/test/test_multipart.py +++ b/src/treq/test/test_multipart.py @@ -3,6 +3,7 @@ import cgi import sys +from typing import cast, AnyStr from io import BytesIO @@ -10,6 +11,7 @@ from zope.interface.verify import verifyObject from twisted.internet import task +from twisted.internet.testing import StringTransport from twisted.web.client import FileBodyProducer from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer @@ -57,7 +59,7 @@ def getOutput(self, producer, with_producer=False): else: return output.getvalue() - def newLines(self, value): + def newLines(self, value: AnyStr) -> AnyStr: if isinstance(value, str): return value.replace(u"\n", u"\r\n") @@ -72,7 +74,7 @@ def test_interface(self): verifyObject( IBodyProducer, MultiPartProducer({}))) - def test_unknownLength(self): + def test_unknownLength(self) -> None: """ If the L{MultiPartProducer} is constructed with a file-like object passed as a parameter without either a C{seek} or C{tell} method, @@ -93,14 +95,14 @@ def tell(self): """ producer = MultiPartProducer( - {"f": ("name", None, FileBodyProducer(CantTell()))}) + {"f": ("name", "application/octet-stream", FileBodyProducer(CantTell()))}) self.assertEqual(UNKNOWN_LENGTH, producer.length) producer = MultiPartProducer( - {"f": 
("name", None, FileBodyProducer(CantSeek()))}) + {"f": ("name", "application/octet-stream", FileBodyProducer(CantSeek()))}) self.assertEqual(UNKNOWN_LENGTH, producer.length) - def test_knownLengthOnFile(self): + def test_knownLengthOnFile(self) -> None: """ If the L{MultiPartProducer} is constructed with a file-like object with both C{seek} and C{tell} methods, its C{length} attribute is set to the @@ -110,7 +112,7 @@ def test_knownLengthOnFile(self): inputFile = BytesIO(inputBytes) inputFile.seek(5) producer = MultiPartProducer({ - "field": ('file name', None, FileBodyProducer( + "field": ('file name', "application/octet-stream", FileBodyProducer( inputFile, cooperator=self.cooperator))}) # Make sure we are generous enough not to alter seek position: @@ -119,33 +121,39 @@ def test_knownLengthOnFile(self): # Total length is hard to calculate manually # as it contains a lot of headers parameters, newlines and boundaries # let's assert for now that it's no less than the input parameter - self.assertTrue(producer.length > len(inputBytes)) + self.assertNotEqual(producer.length, UNKNOWN_LENGTH) + self.assertTrue(cast(int, producer.length) > len(inputBytes)) # Calculating length should not touch producers self.assertTrue(producer._currentProducer is None) - def test_defaultCooperator(self): + def test_defaultCooperator(self) -> None: """ If no L{Cooperator} instance is passed to L{MultiPartProducer}, the global cooperator is used. """ producer = MultiPartProducer({ - "field": ('file name', None, FileBodyProducer( + "field": ("file name", "application/octet-stream", FileBodyProducer( BytesIO(b"yo"), cooperator=self.cooperator)) }) self.assertEqual(task.cooperate, producer._cooperate) - def test_startProducing(self): + def test_startProducing(self) -> None: """ L{MultiPartProducer.startProducing} starts writing bytes from the input file to the given L{IConsumer} and returns a L{Deferred} which fires when they have all been written. 
""" - consumer = output = BytesIO() + consumer = output = StringTransport() + + # We historically accepted bytes for field names and continue to allow + # it for compatibility, but the types don't permit it because it makes + # them even more complicated and awful. So here we verify that that works. + field = cast(str, b"field") producer = MultiPartProducer({ - b"field": ('file name', "text/hello-world", FileBodyProducer( + field: ("file name", "text/hello-world", FileBodyProducer( BytesIO(b"Hello, World"), cooperator=self.cooperator)) }, cooperator=self.cooperator, boundary=b"heyDavid") @@ -165,16 +173,16 @@ def test_startProducing(self): Hello, World --heyDavid-- -"""), output.getvalue()) +"""), output.value()) self.assertEqual(None, self.successResultOf(complete)) - def test_inputClosedAtEOF(self): + def test_inputClosedAtEOF(self) -> None: """ When L{MultiPartProducer} reaches end-of-file on the input file given to it, the input file is closed. """ inputFile = BytesIO(b"hello, world!") - consumer = BytesIO() + consumer = StringTransport() producer = MultiPartProducer({ "field": ( @@ -192,7 +200,7 @@ def test_inputClosedAtEOF(self): self.assertTrue(inputFile.closed) - def test_failedReadWhileProducing(self): + def test_failedReadWhileProducing(self) -> None: """ If a read from the input file fails while producing bytes to the consumer, the L{Deferred} returned by @@ -212,7 +220,7 @@ def read(self, count): cooperator=self.cooperator)) }, cooperator=self.cooperator, boundary=b"heyDavid") - complete = producer.startProducing(BytesIO()) + complete = producer.startProducing(StringTransport()) while self._scheduled: self._scheduled.pop(0)() @@ -244,13 +252,13 @@ def test_stopProducing(self): self._scheduled.pop(0)() self.assertNoResult(complete) - def test_pauseProducing(self): + def test_pauseProducing(self) -> None: """ L{MultiPartProducer.pauseProducing} temporarily suspends writing bytes from the input file to the given L{IConsumer}. 
""" inputFile = BytesIO(b"hello, world!") - consumer = output = BytesIO() + consumer = output = StringTransport() producer = MultiPartProducer({ "field": ( @@ -263,7 +271,7 @@ def test_pauseProducing(self): complete = producer.startProducing(consumer) self._scheduled.pop(0)() - currentValue = output.getvalue() + currentValue = output.value() self.assertTrue(currentValue) producer.pauseProducing() @@ -274,17 +282,17 @@ def test_pauseProducing(self): self._scheduled.pop(0)() # Since the producer is paused, no new data should be here. - self.assertEqual(output.getvalue(), currentValue) + self.assertEqual(output.value(), currentValue) self.assertNoResult(complete) - def test_resumeProducing(self): + def test_resumeProducing(self) -> None: """ L{MultoPartProducer.resumeProducing} re-commences writing bytes from the input file to the given L{IConsumer} after it was previously paused with L{MultiPartProducer.pauseProducing}. """ inputFile = BytesIO(b"hello, world!") - consumer = output = BytesIO() + consumer = output = StringTransport() producer = MultiPartProducer({ "field": ( @@ -297,15 +305,15 @@ def test_resumeProducing(self): producer.startProducing(consumer) self._scheduled.pop(0)() - currentValue = output.getvalue() + currentValue = output.value() self.assertTrue(currentValue) producer.pauseProducing() producer.resumeProducing() self._scheduled.pop(0)() # make sure we started producing new data after resume - self.assertTrue(len(currentValue) < len(output.getvalue())) + self.assertTrue(len(currentValue) < len(output.value())) - def test_unicodeString(self): + def test_unicodeString(self) -> None: """ Make sure unicode string is passed properly """ @@ -325,7 +333,7 @@ def test_unicodeString(self): self.assertEqual(producer.length, len(expected)) self.assertEqual(expected, output) - def test_failOnByteStrings(self): + def test_failOnByteStrings(self) -> None: """ If byte string is passed as a param and we don't know the encoding, fail early to prevent corrupted form 
posts @@ -337,7 +345,7 @@ def test_failOnByteStrings(self): }, cooperator=self.cooperator, boundary=b"heyDavid") - def test_failOnUnknownParams(self): + def test_failOnUnknownParams(self) -> None: """ If byte string is passed as a param and we don't know the encoding, fail early to prevent corrupted form posts @@ -366,7 +374,7 @@ def test_failOnUnknownParams(self): }, cooperator=self.cooperator, boundary=b"heyDavid") - def test_twoFields(self): + def test_twoFields(self) -> None: """ Make sure multiple fields are rendered properly. """ From 80f778123265ddc36f8bff9bec57978947ea5143 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:05:12 -0700 Subject: [PATCH 08/21] Drop deprecated behaviors --- changelog.d/297.removal.rst | 1 + changelog.d/302.removal.rst | 1 + src/treq/test/test_client.py | 83 +++++++++++------------------------- 3 files changed, 26 insertions(+), 59 deletions(-) create mode 100644 changelog.d/297.removal.rst create mode 100644 changelog.d/302.removal.rst diff --git a/changelog.d/297.removal.rst b/changelog.d/297.removal.rst new file mode 100644 index 00000000..f0cc78a1 --- /dev/null +++ b/changelog.d/297.removal.rst @@ -0,0 +1 @@ +Mixing the *json* argument with *files* or *data* now raises `TypeError`. diff --git a/changelog.d/302.removal.rst b/changelog.d/302.removal.rst new file mode 100644 index 00000000..81b8af11 --- /dev/null +++ b/changelog.d/302.removal.rst @@ -0,0 +1 @@ +Passing non-string (`str` or `bytes`) values as part of a dict to the *headers* argument now results in a `TypeError`, as does passing any collection other than a `dict` or `Headers` instance. 
diff --git a/src/treq/test/test_client.py b/src/treq/test/test_client.py index 52897ce6..0a650042 100644 --- a/src/treq/test/test_client.py +++ b/src/treq/test/test_client.py @@ -541,46 +541,28 @@ def test_request_unsupported_params_combination(self): def test_request_json_with_data(self): """ Passing `HTTPClient.request()` both *data* and *json* parameters is - invalid because *json* is ignored. This behavior is deprecated. + invalid because they conflict, producing `TypeError`. """ - self.client.request( - "POST", - "http://example.com/", - data=BytesIO(b"..."), - json=None, # NB: None is a valid value. It encodes to b'null'. - ) - - [w] = self.flushWarnings([self.test_request_json_with_data]) - self.assertEqual(DeprecationWarning, w["category"]) - self.assertEqual( - ( - "Argument 'json' will be ignored because 'data' was also passed." - " This will raise TypeError in the next treq release." - ), - w['message'], - ) + with self.assertRaises(TypeError): + self.client.request( + "POST", + "http://example.com/", + data=BytesIO(b"..."), + json=None, # NB: None is a valid value. It encodes to b'null'. + ) def test_request_json_with_files(self): """ Passing `HTTPClient.request()` both *files* and *json* parameters is - invalid because *json* is ignored. This behavior is deprecated. + invalid because they conflict, producing `TypeError`. """ - self.client.request( - "POST", - "http://example.com/", - files={"f1": ("foo.txt", "text/plain", BytesIO(b"...\n"))}, - json=["this is ignored"], - ) - - [w] = self.flushWarnings([self.test_request_json_with_files]) - self.assertEqual(DeprecationWarning, w["category"]) - self.assertEqual( - ( - "Argument 'json' will be ignored because 'files' was also passed." - " This will raise TypeError in the next treq release." 
- ), - w['message'], - ) + with self.assertRaises(TypeError): + self.client.request( + "POST", + "http://example.com/", + files={"f1": ("foo.txt", "text/plain", BytesIO(b"...\n"))}, + json=["this is ignored"], + ) def test_request_dict_headers(self): self.client.request('GET', 'http://example.com/', headers={ @@ -621,37 +603,20 @@ def test_request_headers_invalid_type(self): `HTTPClient.request()` warns that headers of an unexpected type are invalid and that this behavior is deprecated. """ - self.client.request('GET', 'http://example.com', headers=[]) - - [w] = self.flushWarnings([self.test_request_headers_invalid_type]) - self.assertEqual(DeprecationWarning, w['category']) - self.assertIn( - "headers must be a dict, twisted.web.http_headers.Headers, or None,", - w['message'], - ) + with self.assertRaises(TypeError): + self.client.request('GET', 'http://example.com', headers=[]) def test_request_dict_headers_invalid_values(self): """ `HTTPClient.request()` warns that non-string header values are dropped and that this behavior is deprecated. 
""" - self.client.request('GET', 'http://example.com', headers=OrderedDict([ - ('none', None), - ('one', 1), - ('ok', 'string'), - ])) - - [w1, w2] = self.flushWarnings([self.test_request_dict_headers_invalid_values]) - self.assertEqual(DeprecationWarning, w1['category']) - self.assertEqual(DeprecationWarning, w2['category']) - self.assertIn( - "The value of headers key 'none' has non-string type", - w1['message'], - ) - self.assertIn( - "The value of headers key 'one' has non-string type", - w2['message'], - ) + with self.assertRaises(TypeError): + self.client.request('GET', 'http://example.com', headers=OrderedDict([ + ('none', None), + ('one', 1), + ('ok', 'string'), + ])) def test_request_invalid_param(self): """ From d777c00df6352ce9f7466b42c07f137dc109d727 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:21:52 -0700 Subject: [PATCH 09/21] Allow bytes to pass through I think that this behavior made more sense in the Python 2 days, but now it's just a hazard. --- src/treq/test/test_multipart.py | 27 ++++++++++++++++++--------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/src/treq/test/test_multipart.py b/src/treq/test/test_multipart.py index b703d0dc..5cbaade5 100644 --- a/src/treq/test/test_multipart.py +++ b/src/treq/test/test_multipart.py @@ -333,17 +333,26 @@ def test_unicodeString(self) -> None: self.assertEqual(producer.length, len(expected)) self.assertEqual(expected, output) - def test_failOnByteStrings(self) -> None: + def test_bytesPassThrough(self) -> None: """ - If byte string is passed as a param and we don't know - the encoding, fail early to prevent corrupted form posts + If byte string is passed as a param it is passed through + unchanged. 
""" - self.assertRaises( - ValueError, - MultiPartProducer, { - "afield": u"это моя строчечка".encode("utf-32"), - }, - cooperator=self.cooperator, boundary=b"heyDavid") + output, producer = self.getOutput( + MultiPartProducer({ + "bfield": b'\x00\x01\x02\x03', + }, cooperator=self.cooperator, boundary=b"heyDavid"), + with_producer=True) + + expected = ( + b"--heyDavid\r\n" + b'Content-Disposition: form-data; name="bfield"\r\n' + b'\r\n' + b'\x00\x01\x02\x03\r\n' + b'--heyDavid--\r\n' + ) + self.assertEqual(producer.length, len(expected)) + self.assertEqual(expected, output) def test_failOnUnknownParams(self) -> None: """ From da9a4643564d7bfe2e0da5d5a2c50423bf07a24d Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:29:12 -0700 Subject: [PATCH 10/21] Run MyPy in CI --- .github/workflows/ci.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 4fbc305f..eb59e291 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,7 +29,7 @@ jobs: - run: python -m pip install 'tox<4' - - run: tox -q -p all -e flake8,towncrier,twine,check-manifest + - run: tox -q -p all -e flake8,towncrier,twine,check-manifest,mypy docs: runs-on: ubuntu-20.04 From bbd077572d2773c4e2e8cd0814deccc442d615c3 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:33:43 -0700 Subject: [PATCH 11/21] Add change fragment --- changelog.d/366.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog.d/366.feature.rst diff --git a/changelog.d/366.feature.rst b/changelog.d/366.feature.rst new file mode 100644 index 00000000..7fc34540 --- /dev/null +++ b/changelog.d/366.feature.rst @@ -0,0 +1 @@ +treq now ships type annotations. 
From 9636b5224388bd20fac271e6e7cccc03f645d672 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:38:58 -0700 Subject: [PATCH 12/21] Run MyPy on Python 3.8 --- .github/workflows/ci.yaml | 4 +++- tox.ini | 1 + 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index eb59e291..025bf087 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -30,6 +30,8 @@ jobs: - run: python -m pip install 'tox<4' - run: tox -q -p all -e flake8,towncrier,twine,check-manifest,mypy + env: + TOX_PARALLEL_NO_SPINNER: 1 docs: runs-on: ubuntu-20.04 @@ -40,7 +42,7 @@ jobs: - uses: actions/setup-python@v2 with: - python-version: "3.8" + python-version: "3.11" - uses: actions/cache@v2 with: diff --git a/tox.ini b/tox.ini index 7da8d8f6..4561f6e2 100644 --- a/tox.ini +++ b/tox.ini @@ -26,6 +26,7 @@ commands = {envbindir}/trial {posargs:treq} [testenv:mypy] +basepython = python3.8 deps = mypy==1.0.1 mypy-zope==0.9.1 From 9e51b02cb8ec0adc5674f5376d0e7994e8fef115 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:44:56 -0700 Subject: [PATCH 13/21] Backport annotations to 3.8 --- src/treq/_agentspy.py | 2 +- src/treq/_types.py | 20 +++++++++++--------- src/treq/client.py | 14 +++++++------- src/treq/multipart.py | 12 ++++++------ 4 files changed, 25 insertions(+), 23 deletions(-) diff --git a/src/treq/_agentspy.py b/src/treq/_agentspy.py index 9ae2aa5f..65cbafcd 100644 --- a/src/treq/_agentspy.py +++ b/src/treq/_agentspy.py @@ -91,6 +91,6 @@ def agent_spy() -> Tuple[IAgent, List[RequestRecord]]: - A list of calls made to the agent's :meth:`~twisted.web.iweb.IAgent.request()` method """ - records: list[RequestRecord] = [] + records: List[RequestRecord] = [] agent = _AgentSpy(records.append) return agent, records diff --git a/src/treq/_types.py b/src/treq/_types.py index 698147be..b758f5e9 100644 --- a/src/treq/_types.py +++ b/src/treq/_types.py @@ -1,6 +1,8 @@ +# Copyright (c) The treq 
Authors. +# See LICENSE for details. import io from http.cookiejar import CookieJar -from typing import Any, Iterable, Mapping, Union +from typing import Any, Dict, Iterable, List, Mapping, Tuple, Union from hyperlink import DecodedURL, EncodedURL from twisted.internet.interfaces import (IReactorPluggableNameResolver, @@ -35,14 +37,14 @@ class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): ] _ParamsType = Union[ - Mapping[str, Union[str, tuple[str, ...], list[str]]], - list[tuple[str, str]], + Mapping[str, Union[str, Tuple[str, ...], List[str]]], + List[Tuple[str, str]], ] _HeadersType = Union[ Headers, - dict[_S, _S], - dict[_S, list[_S]], + Dict[_S, _S], + Dict[_S, List[_S]], ] _CookiesType = Union[ @@ -64,8 +66,8 @@ class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): # Concrete types are used here because the handling of the *data* parameter # does lots of isinstance checks. _BodyFields = Union[ - dict[str, str], - list[tuple[str, str]], + Dict[str, str], + List[Tuple[str, str]], ] """ Types that will be URL- or multipart-encoded before being sent as part of the @@ -83,7 +85,7 @@ class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): _FileValue = Union[ str, bytes, - tuple[str, str, IBodyProducer], + Tuple[str, str, IBodyProducer], ] """ Either a scalar string, or a file to upload as (filename, content type, @@ -92,7 +94,7 @@ class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): _FilesType = Union[ Mapping[str, _FileValue], - Iterable[tuple[str, _FileValue]], + Iterable[Tuple[str, _FileValue]], ] """ Values accepted for the *files* parameter. 
diff --git a/src/treq/client.py b/src/treq/client.py index 336c73aa..2b6ea17d 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -4,7 +4,7 @@ from collections import abc from http.cookiejar import Cookie, CookieJar from json import dumps as json_dumps -from typing import Any, Callable, Iterable, Iterator, Mapping, Optional, Union +from typing import Any, Callable, Iterable, Iterator, List, Mapping, Optional, Tuple, Union from urllib.parse import quote_plus from urllib.parse import urlencode as _urlencode @@ -208,7 +208,7 @@ def request( data: Optional[_DataType] = None, files: Optional[_FilesType] = None, json: Union[_JSONType, _Nothing] = _NOTHING, - auth: Optional[tuple[Union[str, bytes], Union[str, bytes]]] = None, + auth: Optional[Tuple[Union[str, bytes], Union[str, bytes]]] = None, cookies: Optional[_CookiesType] = None, allow_redirects: bool = True, browser_like_redirects: bool = False, @@ -328,7 +328,7 @@ def _request_body( files: Optional[_FilesType], json: Union[_JSONType, _Nothing], stacklevel: int, - ) -> tuple[Optional[IBodyProducer], Optional[bytes]]: + ) -> Tuple[Optional[IBodyProducer], Optional[bytes]]: """ Here we choose a right producer based on the parameters passed in. @@ -374,7 +374,7 @@ def _request_body( # If the files keyword is present we will issue a # multipart/form-data request as it suits better for cases # with files and/or large objects. 
- fields: list[tuple[str, _FileValue]] = [] + fields: List[Tuple[str, _FileValue]] = [] if data: for field in _convert_params(data): fields.append(field) @@ -407,7 +407,7 @@ def _request_body( return None, None -def _convert_params(params: _DataType) -> Iterable[tuple[str, str]]: +def _convert_params(params: _DataType) -> Iterable[Tuple[str, str]]: items_method = getattr(params, "items", None) if items_method: return list(sorted(items_method())) @@ -486,7 +486,7 @@ def _query_quote(v: Any) -> str: return q -def _coerced_query_params(params: _ParamsType) -> Iterator[tuple[str, str]]: +def _coerced_query_params(params: _ParamsType) -> Iterator[Tuple[str, str]]: """ Carefully coerce *params* in the same way as `urllib.parse.urlencode()` @@ -503,7 +503,7 @@ def _coerced_query_params(params: _ParamsType) -> Iterator[tuple[str, str]]: A generator that yields two-tuples containing percent-encoded text strings. """ - items: Iterable[tuple[str, Union[str, tuple[str, ...], list[str]]]] + items: Iterable[Tuple[str, Union[str, Tuple[str, ...], List[str]]]] if isinstance(params, abc.Mapping): items = params.items() else: diff --git a/src/treq/multipart.py b/src/treq/multipart.py index c4b881c8..e103103f 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -3,7 +3,7 @@ from contextlib import closing from io import BytesIO -from typing import Any, Iterable, Literal, Mapping, Optional, Union, cast +from typing import Any, Iterable, List, Literal, Mapping, Optional, Tuple, Union, cast from uuid import uuid4 from twisted.internet import task @@ -22,8 +22,8 @@ _Consumer: TypeAlias = "Union[IConsumer, _LengthConsumer]" _UnknownLength = Literal["'twisted.web.iweb.UNKNOWN_LENGTH'"] _Length: TypeAlias = Union[int, _UnknownLength] -_FieldValue = Union[bytes, tuple[str, str, IBodyProducer]] -_Field: TypeAlias = tuple[str, _FieldValue] +_FieldValue = Union[bytes, Tuple[str, str, IBodyProducer]] +_Field: TypeAlias = Tuple[str, _FieldValue] @implementer(IBodyProducer) @@ -284,7 
+284,7 @@ def _converted(fields: _FilesType) -> Iterable[_Field]: Convert any of the multitude of formats we accept for the *fields* parameter into the form we work with internally. """ - fields_: Iterable[tuple[str, _FileValue]] + fields_: Iterable[Tuple[str, _FileValue]] if hasattr(fields, "items"): assert isinstance(fields, Mapping) fields_ = fields.items() @@ -364,7 +364,7 @@ def __init__( self, name: bytes, value: _S, - params: Optional[list[tuple[_S, _S]]] = None, + params: Optional[List[Tuple[_S, _S]]] = None, ): self.name = name self.value = value @@ -386,7 +386,7 @@ def __bytes__(self) -> bytes: return h.read() -def _sorted_by_type(fields: Iterable[_Field]) -> list[_Field]: +def _sorted_by_type(fields: Iterable[_Field]) -> List[_Field]: """Sorts params so that strings are placed before files. That makes a request more readable, as generally files are bigger. From e386f03e2ea30dd731f9b9f9c730b4172ffd226e Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:46:49 -0700 Subject: [PATCH 14/21] Nevermind mypy.ini, long live pyproject.toml --- MANIFEST.in | 1 - 1 file changed, 1 deletion(-) diff --git a/MANIFEST.in b/MANIFEST.in index d1242850..6bfd3473 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -9,7 +9,6 @@ prune docs/_build prune docs/html exclude tox.ini -exclude mypy.ini exclude .github exclude .readthedocs.yml From a0c764a4f0625429e8fc0b87cd14d1e2e1560d14 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Mon, 17 Apr 2023 22:52:10 -0700 Subject: [PATCH 15/21] Fix lint --- src/treq/client.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/treq/client.py b/src/treq/client.py index 2b6ea17d..88929b50 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -4,7 +4,8 @@ from collections import abc from http.cookiejar import Cookie, CookieJar from json import dumps as json_dumps -from typing import Any, Callable, Iterable, Iterator, List, Mapping, Optional, Tuple, Union +from typing import (Any, Callable, Iterable, Iterator, 
List, Mapping, + Optional, Tuple, Union) from urllib.parse import quote_plus from urllib.parse import urlencode as _urlencode From c09df9f97f4a14210322cbb569214efd0fe79e16 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Fri, 28 Apr 2023 22:52:52 -0700 Subject: [PATCH 16/21] Literally fix it? --- src/treq/multipart.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/treq/multipart.py b/src/treq/multipart.py index e103103f..8bb10219 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -3,7 +3,7 @@ from contextlib import closing from io import BytesIO -from typing import Any, Iterable, List, Literal, Mapping, Optional, Tuple, Union, cast +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union, cast from uuid import uuid4 from twisted.internet import task @@ -11,7 +11,7 @@ from twisted.internet.interfaces import IConsumer from twisted.python.failure import Failure from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, Literal from zope.interface import implementer from treq._types import _S, _FilesType, _FileValue From a5799b6deb9c36a7f9c940ddf21b97d340e8649a Mon Sep 17 00:00:00 2001 From: Tom Most Date: Sun, 30 Apr 2023 16:01:52 -0700 Subject: [PATCH 17/21] Quote Deferred annotations This is necessary on Python 3.7 because type is not subscriptable. 
--- src/treq/_agentspy.py | 6 +++--- src/treq/client.py | 14 +++++++------- src/treq/multipart.py | 12 ++++++------ 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/treq/_agentspy.py b/src/treq/_agentspy.py index 65cbafcd..17c80d26 100644 --- a/src/treq/_agentspy.py +++ b/src/treq/_agentspy.py @@ -25,7 +25,7 @@ class RequestRecord: uri: bytes = attr.field() headers: Optional[Headers] = attr.field() bodyProducer: Optional[IBodyProducer] = attr.field() - deferred: Deferred[IResponse] = attr.field() + deferred: "Deferred[IResponse]" = attr.field() @implementer(IAgent) @@ -46,7 +46,7 @@ def request( uri: bytes, headers: Optional[Headers] = None, bodyProducer: Optional[IBodyProducer] = None, - ) -> Deferred[IResponse]: + ) -> "Deferred[IResponse]": if not isinstance(method, bytes): raise TypeError( "method must be bytes, not {!r} of type {}".format(method, type(method)) @@ -68,7 +68,7 @@ def request( " Is the implementation marked with @implementer(IBodyProducer)?" ).format(bodyProducer) ) - d: Deferred[IResponse] = Deferred() + d: "Deferred[IResponse]" = Deferred() record = RequestRecord(method, uri, headers, bodyProducer, d) self._callback(record) return d diff --git a/src/treq/client.py b/src/treq/client.py index 88929b50..5a029372 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -151,7 +151,7 @@ def __init__( self._cookiejar = cookiejar self._data_to_body_producer = data_to_body_producer - def get(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: + def get(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": """ See :func:`treq.get()`. """ @@ -160,7 +160,7 @@ def get(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: def put( self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any - ) -> Deferred[_Response]: + ) -> "Deferred[_Response]": """ See :func:`treq.put()`. 
""" @@ -169,7 +169,7 @@ def put( def patch( self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any - ) -> Deferred[_Response]: + ) -> "Deferred[_Response]": """ See :func:`treq.patch()`. """ @@ -178,21 +178,21 @@ def patch( def post( self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any - ) -> Deferred[_Response]: + ) -> "Deferred[_Response]": """ See :func:`treq.post()`. """ kwargs.setdefault("_stacklevel", 3) return self.request("POST", url, data=data, **kwargs) - def head(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: + def head(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": """ See :func:`treq.head()`. """ kwargs.setdefault("_stacklevel", 3) return self.request("HEAD", url, **kwargs) - def delete(self, url: _URLType, **kwargs: Any) -> Deferred[_Response]: + def delete(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": """ See :func:`treq.delete()`. """ @@ -217,7 +217,7 @@ def request( reactor: Optional[_ITreqReactor] = None, timeout: Optional[float] = None, _stacklevel: int = 2, - ) -> Deferred[_Response]: + ) -> "Deferred[_Response]": """ See :func:`treq.request()`. """ diff --git a/src/treq/multipart.py b/src/treq/multipart.py index 8bb10219..56becbfa 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -83,7 +83,7 @@ def __init__( self.length = self._calculateLength() - def startProducing(self, consumer: IConsumer) -> Deferred[None]: + def startProducing(self, consumer: IConsumer) -> "Deferred[None]": """ Start a cooperative task which will read bytes from the input file and write them to `consumer`. Return a `Deferred` which fires after all @@ -94,13 +94,13 @@ def startProducing(self, consumer: IConsumer) -> Deferred[None]: self._task = self._cooperate(self._writeLoop(consumer)) # type: ignore # whenDone returns the iterator that was passed to cooperate, so who # cares what type it has? It's an edge signal; we ignore its value. 
- d: Deferred[Any] = self._task.whenDone() + d: "Deferred[Any]" = self._task.whenDone() - def maybeStopped(reason: Failure) -> Deferred: + def maybeStopped(reason: Failure) -> "Deferred[None]": reason.trap(task.TaskStopped) return Deferred() - d = cast(Deferred[None], d.addCallbacks(lambda ignored: None, maybeStopped)) + d = cast("Deferred[None]", d.addCallbacks(lambda ignored: None, maybeStopped)) return d def stopProducing(self) -> None: @@ -215,7 +215,7 @@ def _writeFile( content_type: str, producer: IBodyProducer, consumer: _Consumer, - ) -> Optional[Deferred[None]]: + ) -> "Optional[Deferred[None]]": cdisp = _Header(b"Content-Disposition", b"form-data") cdisp.add_param(b"name", name) if filename: @@ -240,7 +240,7 @@ def unset(val): return val d = producer.startProducing(consumer) - return cast(Deferred[None], d.addCallback(unset)) + return cast("Deferred[None]", d.addCallback(unset)) def _escape(value: Union[str, bytes]) -> str: From f6c633e224508aa1e3b35d779cc09af1f92ee579 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Sun, 30 Apr 2023 16:10:33 -0700 Subject: [PATCH 18/21] Update package metadata --- setup.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 8db45e23..5da6f219 100644 --- a/setup.py +++ b/setup.py @@ -7,7 +7,6 @@ "Operating System :: OS Independent", "Framework :: Twisted", "Programming Language :: Python", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", @@ -27,13 +26,14 @@ package_dir={"": "src"}, setup_requires=["incremental"], use_incremental=True, - python_requires=">=3.6", + python_requires=">=3.7", install_requires=[ "incremental", "requests >= 2.1.0", "hyperlink >= 21.0.0", "Twisted[tls] >= 18.7.0", "attrs", + "typing_extensions >= 3.10.0", ], extras_require={ "dev": [ From b83d3d27bfb77f496b46ec1217da428cf64e6c2a Mon Sep 17 00:00:00 2001 From: Tom Most Date: Sun, 30 Apr 2023 
16:31:08 -0700 Subject: [PATCH 19/21] Up Twisted dep to 22.10.0 For https://github.com/twisted/twisted/issues/11635. --- setup.py | 2 +- tox.ini | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 5da6f219..37cebd15 100644 --- a/setup.py +++ b/setup.py @@ -31,7 +31,7 @@ "incremental", "requests >= 2.1.0", "hyperlink >= 21.0.0", - "Twisted[tls] >= 18.7.0", + "Twisted[tls] >= 22.10.0", # For #11635 "attrs", "typing_extensions >= 3.10.0", ], diff --git a/tox.ini b/tox.ini index 4561f6e2..d23af3cd 100644 --- a/tox.ini +++ b/tox.ini @@ -11,7 +11,7 @@ extras = dev deps = coverage - twisted_lowest: Twisted==18.7.0 + twisted_lowest: Twisted==22.10.0 twisted_latest: Twisted twisted_trunk: https://github.com/twisted/twisted/archive/trunk.zip setenv = From 1e2ba5b04658d46273194e2b99b090cc84fb47e4 Mon Sep 17 00:00:00 2001 From: Tom Most Date: Sun, 30 Apr 2023 20:07:49 -0700 Subject: [PATCH 20/21] Remove temporary noqas --- src/treq/_agentspy.py | 4 ++-- src/treq/auth.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/treq/_agentspy.py b/src/treq/_agentspy.py index 17c80d26..4ee80b20 100644 --- a/src/treq/_agentspy.py +++ b/src/treq/_agentspy.py @@ -1,11 +1,11 @@ # Copyright (c) The treq Authors. # See LICENSE for details. 
-from typing import Callable, List, Optional, Tuple # noqa +from typing import Callable, List, Optional, Tuple import attr from twisted.internet.defer import Deferred from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IResponse # noqa +from twisted.web.iweb import IAgent, IBodyProducer, IResponse from zope.interface import implementer diff --git a/src/treq/auth.py b/src/treq/auth.py index dc8bb15e..ffae1ff4 100644 --- a/src/treq/auth.py +++ b/src/treq/auth.py @@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function import binascii -from typing import Union # noqa +from typing import Union from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent From 86eca52b1a198a3fd0d186a1a0ec57aca6a4845f Mon Sep 17 00:00:00 2001 From: Tom Most Date: Sun, 30 Apr 2023 20:25:57 -0700 Subject: [PATCH 21/21] Fully type-annotate treq.content --- pyproject.toml | 7 ++++- src/treq/content.py | 68 +++++++++++++++++++++++++++------------------ 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 76f182ca..0f6d8d48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,12 +41,17 @@ disallow_untyped_calls = false disallow_untyped_decorators = false strict_equality = false +[[tool.mypy.overrides]] +module = [ + "treq.content", +] +disallow_untyped_defs = true + [[tool.mypy.overrides]] module = [ "treq.api", "treq.auth", "treq.client", - "treq.content", "treq.multipart", "treq.response", "treq.testing", diff --git a/src/treq/content.py b/src/treq/content.py index 8982ce9f..27dba507 100644 --- a/src/treq/content.py +++ b/src/treq/content.py @@ -1,15 +1,18 @@ import cgi import json +from typing import Any, Callable, List, Optional, cast from twisted.internet.defer import Deferred, succeed +from twisted.internet.protocol import Protocol, connectionDone from twisted.python.failure import Failure -from twisted.internet.protocol import Protocol from 
twisted.web.client import ResponseDone from twisted.web.http import PotentialDataLoss +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse -def _encoding_from_headers(headers): - content_types = headers.getRawHeaders(u'content-type') +def _encoding_from_headers(headers: Headers) -> Optional[str]: + content_types = headers.getRawHeaders("content-type") if content_types is None: return None @@ -17,27 +20,36 @@ def _encoding_from_headers(headers): # content-type headers. content_type, params = cgi.parse_header(content_types[-1]) - if 'charset' in params: - return params.get('charset').strip("'\"") + charset = params.get("charset") + if charset: + return charset.strip("'\"") - if content_type == 'application/json': - return 'UTF-8' + if content_type == "application/json": + return "UTF-8" + + return None class _BodyCollector(Protocol): - def __init__(self, finished, collector): + finished: "Optional[Deferred[None]]" + + def __init__( + self, finished: "Deferred[None]", collector: Callable[[bytes], None] + ) -> None: self.finished = finished self.collector = collector - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: try: self.collector(data) except BaseException: - self.transport.loseConnection() - self.finished.errback(Failure()) + if self.transport: + self.transport.loseConnection() + if self.finished: + self.finished.errback(Failure()) self.finished = None - def connectionLost(self, reason): + def connectionLost(self, reason: Failure = connectionDone) -> None: if self.finished is None: return if reason.check(ResponseDone): @@ -49,7 +61,9 @@ def connectionLost(self, reason): self.finished.errback(reason) -def collect(response, collector): +def collect( + response: IResponse, collector: Callable[[bytes], None] +) -> "Deferred[None]": """ Incrementally collect the body of the response. 
@@ -69,12 +83,12 @@ def collect(response, collector): if response.length == 0: return succeed(None) - d = Deferred() + d: "Deferred[None]" = Deferred() response.deliverBody(_BodyCollector(d, collector)) return d -def content(response): +def content(response: IResponse) -> "Deferred[bytes]": """ Read the contents of an HTTP response. @@ -85,13 +99,15 @@ def content(response): :rtype: Deferred that fires with the content as a str. """ - _content = [] + _content: List[bytes] = [] d = collect(response, _content.append) - d.addCallback(lambda _: b''.join(_content)) - return d + return cast( + "Deferred[bytes]", + d.addCallback(lambda _: b"".join(_content)), + ) -def json_content(response, **kwargs): +def json_content(response: IResponse, **kwargs: Any) -> "Deferred[Any]": """ Read the contents of an HTTP response and attempt to decode it as JSON. @@ -105,13 +121,11 @@ def json_content(response, **kwargs): :rtype: Deferred that fires with the decoded JSON. """ # RFC7159 (8.1): Default JSON character encoding is UTF-8 - d = text_content(response, encoding='utf-8') + d = text_content(response, encoding="utf-8") + return d.addCallback(lambda text: json.loads(text, **kwargs)) - d.addCallback(lambda text: json.loads(text, **kwargs)) - return d - -def text_content(response, encoding='ISO-8859-1'): +def text_content(response: IResponse, encoding: str = "ISO-8859-1") -> "Deferred[str]": """ Read the contents of an HTTP response and decode it with an appropriate charset, which may be guessed from the ``Content-Type`` header. @@ -122,7 +136,8 @@ def text_content(response, encoding='ISO-8859-1'): :rtype: Deferred that fires with a unicode string. """ - def _decode_content(c): + + def _decode_content(c: bytes) -> str: e = _encoding_from_headers(response.headers) @@ -132,5 +147,4 @@ def _decode_content(c): return c.decode(encoding) d = content(response) - d.addCallback(_decode_content) - return d + return cast("Deferred[str]", d.addCallback(_decode_content))