diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index b7d3f856..bdd289da 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -29,7 +29,9 @@ jobs: - run: python -m pip install 'tox<4' - - run: tox -q -p all -e flake8,towncrier,twine,check-manifest + - run: tox -q -p all -e flake8,towncrier,twine,check-manifest,mypy + env: + TOX_PARALLEL_NO_SPINNER: 1 docs: runs-on: ubuntu-20.04 @@ -40,7 +42,7 @@ jobs: - uses: actions/setup-python@v4 with: - python-version: "3.8" + python-version: "3.11" - uses: actions/cache@v3 with: diff --git a/MANIFEST.in b/MANIFEST.in index c68b6d96..6bfd3473 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -3,6 +3,7 @@ include *.rst include *.md include LICENSE include .coveragerc +include src/treq/py.typed recursive-include docs * prune docs/_build prune docs/html diff --git a/changelog.d/297.removal.rst b/changelog.d/297.removal.rst new file mode 100644 index 00000000..f0cc78a1 --- /dev/null +++ b/changelog.d/297.removal.rst @@ -0,0 +1 @@ +Mixing the *json* argument with *files* or *data* now raises `TypeError`. diff --git a/changelog.d/302.removal.rst b/changelog.d/302.removal.rst new file mode 100644 index 00000000..81b8af11 --- /dev/null +++ b/changelog.d/302.removal.rst @@ -0,0 +1 @@ +Passing non-string (`str` or `bytes`) values as part of a dict to the *headers* argument now results in a `TypeError`, as does passing any collection other than a `dict` or `Headers` instance. diff --git a/changelog.d/366.feature.rst b/changelog.d/366.feature.rst new file mode 100644 index 00000000..7fc34540 --- /dev/null +++ b/changelog.d/366.feature.rst @@ -0,0 +1 @@ +treq now ships type annotations. diff --git a/pyproject.toml b/pyproject.toml index f97064fa..0f6d8d48 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,3 +13,70 @@ filename = "CHANGELOG.rst" directory = "changelog.d" title_format = "{version} ({project_date})" issue_format = "`#{issue} `__" + +[tool.mypy] +namespace_packages = true +plugins = "mypy_zope:plugin" + +check_untyped_defs = true +disallow_incomplete_defs = true +disallow_untyped_defs = true +no_implicit_optional = true +show_column_numbers = true +show_error_codes = true +strict_optional = true +warn_no_return = true +warn_redundant_casts = true +warn_return_any = true +warn_unreachable = true +warn_unused_ignores = true + +disallow_any_decorated = false +disallow_any_explicit = false +disallow_any_expr = false +disallow_any_generics = false +disallow_any_unimported = false +disallow_subclassing_any = false +disallow_untyped_calls = false +disallow_untyped_decorators = false +strict_equality = false + +[[tool.mypy.overrides]] +module = [ + "treq.content", +] +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "treq.api", + "treq.auth", + "treq.client", + "treq.multipart", + "treq.response", + "treq.testing", + "treq.test.test_api", + "treq.test.test_auth", + "treq.test.test_client", + "treq.test.test_content", + "treq.test.test_multipart", + "treq.test.test_response", + "treq.test.test_testing", + "treq.test.test_treq_integration", + "treq.test.util", +] +disallow_untyped_defs = false +check_untyped_defs = false + +[[tool.mypy.overrides]] +module = [ + "treq.test.local_httpbin.child", + "treq.test.local_httpbin.parent", + "treq.test.local_httpbin.shared", + "treq.test.local_httpbin.test.test_child", + "treq.test.local_httpbin.test.test_parent", + "treq.test.local_httpbin.test.test_shared", +] +disallow_untyped_defs = false +check_untyped_defs = false +ignore_missing_imports = true diff --git a/setup.py b/setup.py index fcf1891f..52c26c80 100644 --- a/setup.py +++ b/setup.py @@ -27,13 +27,14 @@ package_dir={"": "src"}, setup_requires=["incremental"], use_incremental=True, - python_requires=">=3.6", + python_requires=">=3.7", install_requires=[ "incremental", "requests >= 2.1.0", "hyperlink >= 21.0.0", - "Twisted[tls] >= 22.10.0", + "Twisted[tls] >= 22.10.0", # For #11635 "attrs", + "typing_extensions >= 3.10.0", ], extras_require={ "dev": [ @@ -46,7 +47,7 @@ "sphinx<7.0.0", # Removal of 'style' key breaks RTD. ], }, - package_data={"treq": ["_version"]}, + package_data={"treq": ["py.typed"]}, author="David Reid", author_email="dreid@dreid.org", maintainer="Tom Most", diff --git a/src/treq/__init__.py b/src/treq/__init__.py index 771fb88e..530c6c6b 100644 --- a/src/treq/__init__.py +++ b/src/treq/__init__.py @@ -1,11 +1,20 @@ -from __future__ import absolute_import, division, print_function +from treq.api import delete, get, head, patch, post, put, request +from treq.content import collect, content, json_content, text_content -from ._version import __version__ +from ._version import __version__ as _version -from treq.api import head, get, post, put, patch, delete, request -from treq.content import collect, content, text_content, json_content +__version__: str = _version.base() -__version__ = __version__.base() - -__all__ = ['head', 'get', 'post', 'put', 'patch', 'delete', 'request', - 'collect', 'content', 'text_content', 'json_content'] +__all__ = [ + "head", + "get", + "post", + "put", + "patch", + "delete", + "request", + "collect", + "content", + "text_content", + "json_content", +] diff --git a/src/treq/_agentspy.py b/src/treq/_agentspy.py index c7ad7f66..4ee80b20 100644 --- a/src/treq/_agentspy.py +++ b/src/treq/_agentspy.py @@ -1,11 +1,11 @@ # Copyright (c) The treq Authors. # See LICENSE for details. -from typing import Callable, List, Optional, Tuple # noqa +from typing import Callable, List, Optional, Tuple import attr from twisted.internet.defer import Deferred from twisted.web.http_headers import Headers -from twisted.web.iweb import IAgent, IBodyProducer, IResponse # noqa +from twisted.web.iweb import IAgent, IBodyProducer, IResponse from zope.interface import implementer @@ -21,11 +21,11 @@ class RequestRecord: :ivar deferred: The :class:`Deferred` returned by :meth:`IAgent.request` """ - method = attr.ib() # type: bytes - uri = attr.ib() # type: bytes - headers = attr.ib() # type: Optional[Headers] - bodyProducer = attr.ib() # type: Optional[IBodyProducer] - deferred = attr.ib() # type: Deferred + method: bytes = attr.field() + uri: bytes = attr.field() + headers: Optional[Headers] = attr.field() + bodyProducer: Optional[IBodyProducer] = attr.field() + deferred: "Deferred[IResponse]" = attr.field() @implementer(IAgent) @@ -38,10 +38,15 @@ class _AgentSpy: A function called with each :class:`RequestRecord` """ - _callback = attr.ib() # type: Callable[Tuple[RequestRecord], None] + _callback: Callable[[RequestRecord], None] = attr.ib() - def request(self, method, uri, headers=None, bodyProducer=None): - # type: (bytes, bytes, Optional[Headers], Optional[IBodyProducer]) -> Deferred[IResponse] # noqa + def request( + self, + method: bytes, + uri: bytes, + headers: Optional[Headers] = None, + bodyProducer: Optional[IBodyProducer] = None, + ) -> "Deferred[IResponse]": if not isinstance(method, bytes): raise TypeError( "method must be bytes, not {!r} of type {}".format(method, type(method)) @@ -63,14 +68,13 @@ def request(self, method, uri, headers=None, bodyProducer=None): " Is the implementation marked with @implementer(IBodyProducer)?" ).format(bodyProducer) ) - d = Deferred() + d: "Deferred[IResponse]" = Deferred() record = RequestRecord(method, uri, headers, bodyProducer, d) self._callback(record) return d -def agent_spy(): - # type: () -> Tuple[IAgent, List[RequestRecord]] +def agent_spy() -> Tuple[IAgent, List[RequestRecord]]: """ Record HTTP requests made with an agent @@ -87,6 +91,6 @@ def agent_spy(): - A list of calls made to the agent's :meth:`~twisted.web.iweb.IAgent.request()` method """ - records = [] + records: List[RequestRecord] = [] agent = _AgentSpy(records.append) return agent, records diff --git a/src/treq/_types.py b/src/treq/_types.py new file mode 100644 index 00000000..b758f5e9 --- /dev/null +++ b/src/treq/_types.py @@ -0,0 +1,104 @@ +# Copyright (c) The treq Authors. +# See LICENSE for details. +import io +from http.cookiejar import CookieJar +from typing import Any, Dict, Iterable, List, Mapping, Tuple, Union + +from hyperlink import DecodedURL, EncodedURL +from twisted.internet.interfaces import (IReactorPluggableNameResolver, + IReactorTCP, IReactorTime) +from twisted.web.http_headers import Headers +from twisted.web.iweb import IBodyProducer + + +class _ITreqReactor(IReactorTCP, IReactorTime, IReactorPluggableNameResolver): + """ + The kind of reactor treq needs for type-checking purposes. + + This is an approximation of the actual requirement, which comes from the + `twisted.internet.endpoints.HostnameEndpoint` used by the `Agent` + implementation: + + > Provider of IReactorTCP, IReactorTime and either + > IReactorPluggableNameResolver or IReactorPluggableResolver. + + We don't model the `IReactorPluggableResolver` option because it is + deprecated. + """ + + +_S = Union[bytes, str] + +_URLType = Union[ + str, + bytes, + EncodedURL, + DecodedURL, +] + +_ParamsType = Union[ + Mapping[str, Union[str, Tuple[str, ...], List[str]]], + List[Tuple[str, str]], +] + +_HeadersType = Union[ + Headers, + Dict[_S, _S], + Dict[_S, List[_S]], +] + +_CookiesType = Union[ + CookieJar, + Mapping[str, str], +] + +_WholeBody = Union[ + bytes, + io.BytesIO, + io.BufferedReader, + IBodyProducer, +] +""" +Types that define the entire HTTP request body, including those coercible to +`IBodyProducer`. +""" + +# Concrete types are used here because the handling of the *data* parameter +# does lots of isinstance checks. +_BodyFields = Union[ + Dict[str, str], + List[Tuple[str, str]], +] +""" +Types that will be URL- or multipart-encoded before being sent as part of the +HTTP request body. +""" + +_DataType = Union[_WholeBody, _BodyFields] +""" +Values accepted for the *data* parameter + +Note that this is a simplification. Only `_BodyFields` may be supplied if the +*files* parameter is passed. +""" + +_FileValue = Union[ + str, + bytes, + Tuple[str, str, IBodyProducer], +] +""" +Either a scalar string, or a file to upload as (filename, content type, +IBodyProducer) +""" + +_FilesType = Union[ + Mapping[str, _FileValue], + Iterable[Tuple[str, _FileValue]], +] +""" +Values accepted for the *files* parameter. +""" + +# Soon... 🤞 https://github.com/python/mypy/issues/731 +_JSONType = Any diff --git a/src/treq/auth.py b/src/treq/auth.py index 3a778cea..ffae1ff4 100644 --- a/src/treq/auth.py +++ b/src/treq/auth.py @@ -3,7 +3,7 @@ from __future__ import absolute_import, division, print_function import binascii -from typing import Union # noqa +from typing import Union from twisted.web.http_headers import Headers from twisted.web.iweb import IAgent @@ -46,8 +46,9 @@ def request(self, method, uri, headers=None, bodyProducer=None): method, uri, headers=requestHeaders, bodyProducer=bodyProducer) -def add_basic_auth(agent, username, password): - # type: (IAgent, Union[str, bytes], Union[str, bytes]) -> IAgent +def add_basic_auth( + agent: IAgent, username: Union[str, bytes], password: Union[str, bytes] +) -> IAgent: """ Wrap an agent to add HTTP basic authentication diff --git a/src/treq/client.py b/src/treq/client.py index 1b09fb0b..5a029372 100644 --- a/src/treq/client.py +++ b/src/treq/client.py @@ -1,49 +1,49 @@ import io import mimetypes import uuid -import warnings -from collections.abc import Mapping -from http.cookiejar import CookieJar, Cookie -from urllib.parse import quote_plus, urlencode as _urlencode +from collections import abc +from http.cookiejar import Cookie, CookieJar +from json import dumps as json_dumps +from typing import (Any, Callable, Iterable, Iterator, List, Mapping, + Optional, Tuple, Union) +from urllib.parse import quote_plus +from urllib.parse import urlencode as _urlencode -from twisted.internet.interfaces import IProtocol +from hyperlink import DecodedURL, EncodedURL +from requests.cookies import merge_cookies from twisted.internet.defer import Deferred -from twisted.python.components import proxyForInterface +from twisted.internet.interfaces import IProtocol +from twisted.python.components import proxyForInterface, registerAdapter from twisted.python.filepath import FilePath -from hyperlink import DecodedURL, EncodedURL - +from twisted.web.client import (BrowserLikeRedirectAgent, ContentDecoderAgent, + CookieAgent, FileBodyProducer, GzipDecoder, + IAgent, RedirectAgent) from twisted.web.http_headers import Headers from twisted.web.iweb import IBodyProducer, IResponse -from twisted.web.client import ( - FileBodyProducer, - RedirectAgent, - BrowserLikeRedirectAgent, - ContentDecoderAgent, - GzipDecoder, - CookieAgent -) - -from twisted.python.components import registerAdapter -from json import dumps as json_dumps - -from treq.auth import add_auth from treq import multipart +from treq._types import (_CookiesType, _DataType, _FilesType, _FileValue, + _HeadersType, _ITreqReactor, _JSONType, _ParamsType, + _URLType) +from treq.auth import add_auth from treq.response import _Response -from requests.cookies import merge_cookies -_NOTHING = object() +class _Nothing: + """Type of the sentinel `_NOTHING`""" + + +_NOTHING = _Nothing() -def urlencode(query, doseq): +def urlencode(query: _ParamsType, doseq: bool) -> bytes: s = _urlencode(query, doseq) - if not isinstance(s, bytes): - s = s.encode("ascii") - return s + return s.encode("ascii") -def _scoped_cookiejar_from_dict(url_object, cookie_dict): +def _scoped_cookiejar_from_dict( + url_object: EncodedURL, cookie_dict: Optional[Mapping[str, str]] +) -> CookieJar: """ Create a CookieJar from a dictionary whose cookies are all scoped to the given URL's origin. @@ -55,14 +55,14 @@ def _scoped_cookiejar_from_dict(url_object, cookie_dict): if cookie_dict is None: return cookie_jar for k, v in cookie_dict.items(): - secure = url_object.scheme == 'https' + secure = url_object.scheme == "https" port_specified = not ( (url_object.scheme == "https" and url_object.port == 443) or (url_object.scheme == "http" and url_object.port == 80) ) port = str(url_object.port) if port_specified else None domain = url_object.host - netscape_domain = domain if '.' in domain else domain + '.local' + netscape_domain = domain if "." in domain else domain + ".local" cookie_jar.set_cookie( Cookie( @@ -71,11 +71,9 @@ def _scoped_cookiejar_from_dict(url_object, cookie_dict): port=port, secure=secure, port_specified=port_specified, - # Contents name=k, value=v, - # Constant/always-the-same stuff version=0, path="/", @@ -87,28 +85,28 @@ def _scoped_cookiejar_from_dict(url_object, cookie_dict): path_specified=False, domain_specified=False, domain_initial_dot=False, - rest=[], + rest={}, ) ) return cookie_jar -class _BodyBufferingProtocol(proxyForInterface(IProtocol)): +class _BodyBufferingProtocol(proxyForInterface(IProtocol)): # type: ignore def __init__(self, original, buffer, finished): self.original = original self.buffer = buffer self.finished = finished - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: self.buffer.append(data) self.original.dataReceived(data) - def connectionLost(self, reason): + def connectionLost(self, reason: Exception) -> None: self.original.connectionLost(reason) self.finished.errback(reason) -class _BufferedResponse(proxyForInterface(IResponse)): +class _BufferedResponse(proxyForInterface(IResponse)): # type: ignore def __init__(self, original): self.original = original self._buffer = [] @@ -130,11 +128,7 @@ def deliverBody(self, protocol): self._waiting = Deferred() self._waiting.addBoth(self._deliverWaiting) self.original.deliverBody( - _BodyBufferingProtocol( - protocol, - self._buffer, - self._waiting - ) + _BodyBufferingProtocol(protocol, self._buffer, self._waiting) ) elif self._finished: for segment in self._buffer: @@ -145,79 +139,89 @@ def deliverBody(self, protocol): class HTTPClient: - def __init__(self, agent, cookiejar=None, - data_to_body_producer=IBodyProducer): + def __init__( + self, + agent: IAgent, + cookiejar: Optional[CookieJar] = None, + data_to_body_producer: Callable[[Any], IBodyProducer] = IBodyProducer, + ) -> None: self._agent = agent if cookiejar is None: cookiejar = CookieJar() self._cookiejar = cookiejar self._data_to_body_producer = data_to_body_producer - def get(self, url, **kwargs): + def get(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": """ See :func:`treq.get()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('GET', url, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("GET", url, **kwargs) - def put(self, url, data=None, **kwargs): + def put( + self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any + ) -> "Deferred[_Response]": """ See :func:`treq.put()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('PUT', url, data=data, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("PUT", url, data=data, **kwargs) - def patch(self, url, data=None, **kwargs): + def patch( + self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any + ) -> "Deferred[_Response]": """ See :func:`treq.patch()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('PATCH', url, data=data, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("PATCH", url, data=data, **kwargs) - def post(self, url, data=None, **kwargs): + def post( + self, url: _URLType, data: Optional[_DataType] = None, **kwargs: Any + ) -> "Deferred[_Response]": """ See :func:`treq.post()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('POST', url, data=data, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("POST", url, data=data, **kwargs) - def head(self, url, **kwargs): + def head(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": """ See :func:`treq.head()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('HEAD', url, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("HEAD", url, **kwargs) - def delete(self, url, **kwargs): + def delete(self, url: _URLType, **kwargs: Any) -> "Deferred[_Response]": """ See :func:`treq.delete()`. """ - kwargs.setdefault('_stacklevel', 3) - return self.request('DELETE', url, **kwargs) + kwargs.setdefault("_stacklevel", 3) + return self.request("DELETE", url, **kwargs) def request( self, - method, - url, + method: str, + url: _URLType, *, - params=None, - headers=None, - data=None, - files=None, - json=_NOTHING, - auth=None, - cookies=None, - allow_redirects=True, - browser_like_redirects=False, - unbuffered=False, - reactor=None, - timeout=None, - _stacklevel=2, - ): + params: Optional[_ParamsType] = None, + headers: Optional[_HeadersType] = None, + data: Optional[_DataType] = None, + files: Optional[_FilesType] = None, + json: Union[_JSONType, _Nothing] = _NOTHING, + auth: Optional[Tuple[Union[str, bytes], Union[str, bytes]]] = None, + cookies: Optional[_CookiesType] = None, + allow_redirects: bool = True, + browser_like_redirects: bool = False, + unbuffered: bool = False, + reactor: Optional[_ITreqReactor] = None, + timeout: Optional[float] = None, + _stacklevel: int = 2, + ) -> "Deferred[_Response]": """ See :func:`treq.request()`. """ - method = method.encode('ascii').upper() + method_: bytes = method.encode("ascii").upper() if isinstance(url, DecodedURL): parsed_url = url.encoded_url @@ -228,7 +232,7 @@ def request( # bytes in the path and querystring. parsed_url = EncodedURL.from_text(url) else: - parsed_url = EncodedURL.from_text(url.decode('ascii')) + parsed_url = EncodedURL.from_text(url.decode("ascii")) # Join parameters provided in the URL # and the ones passed as argument. @@ -237,20 +241,21 @@ def request( query=parsed_url.query + tuple(_coerced_query_params(params)) ) - url = parsed_url.to_uri().to_text().encode('ascii') + url = parsed_url.to_uri().to_text().encode("ascii") headers = self._request_headers(headers, _stacklevel + 1) - bodyProducer, contentType = self._request_body(data, files, json, - stacklevel=_stacklevel + 1) + bodyProducer, contentType = self._request_body( + data, files, json, stacklevel=_stacklevel + 1 + ) if contentType is not None: - headers.setRawHeaders(b'Content-Type', [contentType]) + headers.setRawHeaders(b"Content-Type", [contentType]) if not isinstance(cookies, CookieJar): cookies = _scoped_cookiejar_from_dict(parsed_url, cookies) cookies = merge_cookies(self._cookiejar, cookies) - wrapped_agent = CookieAgent(self._agent, cookies) + wrapped_agent: IAgent = CookieAgent(self._agent, cookies) if allow_redirects: if browser_like_redirects: @@ -258,18 +263,19 @@ def request( else: wrapped_agent = RedirectAgent(wrapped_agent) - wrapped_agent = ContentDecoderAgent(wrapped_agent, - [(b'gzip', GzipDecoder)]) + wrapped_agent = ContentDecoderAgent(wrapped_agent, [(b"gzip", GzipDecoder)]) if auth: wrapped_agent = add_auth(wrapped_agent, auth) d = wrapped_agent.request( - method, url, headers=headers, - bodyProducer=bodyProducer) + method_, url, headers=headers, bodyProducer=bodyProducer + ) if reactor is None: - from twisted.internet import reactor + from twisted.internet import reactor # type: ignore + assert reactor is not None + if timeout: delayedCall = reactor.callLater(timeout, d.cancel) @@ -285,12 +291,11 @@ def gotResult(result): return d.addCallback(_Response, cookies) - def _request_headers(self, headers, stacklevel): + def _request_headers( + self, headers: Optional[_HeadersType], stacklevel: int + ) -> Headers: """ Convert the *headers* argument to a :class:`Headers` instance - - :returns: - :class:`twisted.web.http_headers.Headers` """ if isinstance(headers, dict): h = Headers({}) @@ -300,14 +305,10 @@ def _request_headers(self, headers, stacklevel): elif isinstance(v, list): h.setRawHeaders(k, v) else: - warnings.warn( - ( - "The value of headers key {!r} has non-string type {}" - " and will be dropped." - " This will raise TypeError in the next treq release." - ).format(k, type(v)), - DeprecationWarning, - stacklevel=stacklevel, + raise TypeError( + "The value of headers key {!r} has non-string type {}.".format( + k, type(v) + ) ) return h if isinstance(headers, Headers): @@ -315,18 +316,20 @@ def _request_headers(self, headers, stacklevel): if headers is None: return Headers({}) - warnings.warn( + raise TypeError( ( "headers must be a dict, twisted.web.http_headers.Headers, or None," - " but found {}, which will be ignored." - " This will raise TypeError in the next treq release." - ).format(type(headers)), - DeprecationWarning, - stacklevel=stacklevel, + " but found {}." + ).format(type(headers)) ) - return Headers({}) - def _request_body(self, data, files, json, stacklevel): + def _request_body( + self, + data: Optional[_DataType], + files: Optional[_FilesType], + json: Union[_JSONType, _Nothing], + stacklevel: int, + ) -> Tuple[Optional[IBodyProducer], Optional[bytes]]: """ Here we choose a right producer based on the parameters passed in. @@ -354,38 +357,47 @@ def _request_body(self, data, files, json, stacklevel): JSON-encodable data, or the sentinel `_NOTHING`. The sentinel is necessary because ``None`` is a valid JSON value. """ - if json is not _NOTHING and (files or data): - warnings.warn( - ( - "Argument 'json' will be ignored because '{}' was also passed." - " This will raise TypeError in the next treq release." - ).format("data" if data else "files"), - DeprecationWarning, - stacklevel=stacklevel, + if json is not _NOTHING: + if files or data: + raise TypeError( + "Argument 'json' cannot be combined with '{}'.".format( + "data" if data else "files" + ) + ) + return ( + self._data_to_body_producer( + json_dumps(json, separators=(",", ":")).encode("utf-8"), + ), + b"application/json; charset=UTF-8", ) if files: # If the files keyword is present we will issue a # multipart/form-data request as it suits better for cases # with files and/or large objects. - files = list(_convert_files(files)) - boundary = str(uuid.uuid4()).encode('ascii') + fields: List[Tuple[str, _FileValue]] = [] if data: - data = _convert_params(data) - else: - data = [] + for field in _convert_params(data): + fields.append(field) + for field in _convert_files(files): + fields.append(field) + boundary = str(uuid.uuid4()).encode("ascii") return ( - multipart.MultiPartProducer(data + files, boundary=boundary), - b'multipart/form-data; boundary=' + boundary, + multipart.MultiPartProducer(fields, boundary=boundary), + b"multipart/form-data; boundary=" + boundary, ) # Otherwise stick to x-www-form-urlencoded format # as it's generally faster for smaller requests. if isinstance(data, (dict, list, tuple)): return ( + # FIXME: The use of doseq here is not permitted in the types, and + # sequence values aren't supported in the files codepath. It is + # maintained here for backwards compatibility. See + # https://github.com/twisted/treq/issues/360. self._data_to_body_producer(urlencode(data, doseq=True)), - b'application/x-www-form-urlencoded', + b"application/x-www-form-urlencoded", ) elif data: return ( @@ -393,22 +405,13 @@ def _request_body(self, data, files, json, stacklevel): None, ) - if json is not _NOTHING: - return ( - self._data_to_body_producer( - json_dumps(json, separators=(u',', u':')).encode('utf-8'), - ), - b'application/json; charset=UTF-8', - ) - return None, None -def _convert_params(params): - if hasattr(params, "iteritems"): - return list(sorted(params.iteritems())) - elif hasattr(params, "items"): - return list(sorted(params.items())) +def _convert_params(params: _DataType) -> Iterable[Tuple[str, str]]: + items_method = getattr(params, "items", None) + if items_method: + return list(sorted(items_method())) elif isinstance(params, (tuple, list)): return list(params) else: @@ -466,8 +469,7 @@ def _convert_files(files): yield (param, (file_name, content_type, IBodyProducer(fobj))) -def _query_quote(v): - # (Any) -> Text +def _query_quote(v: Any) -> str: """ Percent-encode a querystring name or value. @@ -485,7 +487,7 @@ def _query_quote(v): return q -def _coerced_query_params(params): +def _coerced_query_params(params: _ParamsType) -> Iterator[Tuple[str, str]]: """ Carefully coerce *params* in the same way as `urllib.parse.urlencode()` @@ -501,10 +503,9 @@ def _coerced_query_params(params): :returns: A generator that yields two-tuples containing percent-encoded text strings. - :rtype: - Iterator[Tuple[Text, Text]] """ - if isinstance(params, Mapping): + items: Iterable[Tuple[str, Union[str, Tuple[str, ...], List[str]]]] + if isinstance(params, abc.Mapping): items = params.items() else: items = params @@ -518,20 +519,20 @@ def _coerced_query_params(params): yield key_quoted, _query_quote(value) -def _from_bytes(orig_bytes): +def _from_bytes(orig_bytes: bytes) -> IBodyProducer: return FileBodyProducer(io.BytesIO(orig_bytes)) -def _from_file(orig_file): +def _from_file(orig_file: Union[io.BytesIO, io.BufferedReader]) -> IBodyProducer: return FileBodyProducer(orig_file) -def _guess_content_type(filename): +def _guess_content_type(filename: str) -> Optional[str]: if filename: guessed = mimetypes.guess_type(filename)[0] else: guessed = None - return guessed or 'application/octet-stream' + return guessed or "application/octet-stream" registerAdapter(_from_bytes, bytes, IBodyProducer) diff --git a/src/treq/content.py b/src/treq/content.py index 8982ce9f..27dba507 100644 --- a/src/treq/content.py +++ b/src/treq/content.py @@ -1,15 +1,18 @@ import cgi import json +from typing import Any, Callable, List, Optional, cast from twisted.internet.defer import Deferred, succeed +from twisted.internet.protocol import Protocol, connectionDone from twisted.python.failure import Failure -from twisted.internet.protocol import Protocol from twisted.web.client import ResponseDone from twisted.web.http import PotentialDataLoss +from twisted.web.http_headers import Headers +from twisted.web.iweb import IResponse -def _encoding_from_headers(headers): - content_types = headers.getRawHeaders(u'content-type') +def _encoding_from_headers(headers: Headers) -> Optional[str]: + content_types = headers.getRawHeaders("content-type") if content_types is None: return None @@ -17,27 +20,36 @@ def _encoding_from_headers(headers): # content-type headers. content_type, params = cgi.parse_header(content_types[-1]) - if 'charset' in params: - return params.get('charset').strip("'\"") + charset = params.get("charset") + if charset: + return charset.strip("'\"") - if content_type == 'application/json': - return 'UTF-8' + if content_type == "application/json": + return "UTF-8" + + return None class _BodyCollector(Protocol): - def __init__(self, finished, collector): + finished: "Optional[Deferred[None]]" + + def __init__( + self, finished: "Deferred[None]", collector: Callable[[bytes], None] + ) -> None: self.finished = finished self.collector = collector - def dataReceived(self, data): + def dataReceived(self, data: bytes) -> None: try: self.collector(data) except BaseException: - self.transport.loseConnection() - self.finished.errback(Failure()) + if self.transport: + self.transport.loseConnection() + if self.finished: + self.finished.errback(Failure()) self.finished = None - def connectionLost(self, reason): + def connectionLost(self, reason: Failure = connectionDone) -> None: if self.finished is None: return if reason.check(ResponseDone): @@ -49,7 +61,9 @@ def connectionLost(self, reason): self.finished.errback(reason) -def collect(response, collector): +def collect( + response: IResponse, collector: Callable[[bytes], None] +) -> "Deferred[None]": """ Incrementally collect the body of the response. @@ -69,12 +83,12 @@ def collect(response, collector): if response.length == 0: return succeed(None) - d = Deferred() + d: "Deferred[None]" = Deferred() response.deliverBody(_BodyCollector(d, collector)) return d -def content(response): +def content(response: IResponse) -> "Deferred[bytes]": """ Read the contents of an HTTP response. @@ -85,13 +99,15 @@ def content(response): :rtype: Deferred that fires with the content as a str. """ - _content = [] + _content: List[bytes] = [] d = collect(response, _content.append) - d.addCallback(lambda _: b''.join(_content)) - return d + return cast( + "Deferred[bytes]", + d.addCallback(lambda _: b"".join(_content)), + ) -def json_content(response, **kwargs): +def json_content(response: IResponse, **kwargs: Any) -> "Deferred[Any]": """ Read the contents of an HTTP response and attempt to decode it as JSON. @@ -105,13 +121,11 @@ def json_content(response, **kwargs): :rtype: Deferred that fires with the decoded JSON. """ # RFC7159 (8.1): Default JSON character encoding is UTF-8 - d = text_content(response, encoding='utf-8') + d = text_content(response, encoding="utf-8") + return d.addCallback(lambda text: json.loads(text, **kwargs)) - d.addCallback(lambda text: json.loads(text, **kwargs)) - return d - -def text_content(response, encoding='ISO-8859-1'): +def text_content(response: IResponse, encoding: str = "ISO-8859-1") -> "Deferred[str]": """ Read the contents of an HTTP response and decode it with an appropriate charset, which may be guessed from the ``Content-Type`` header. @@ -122,7 +136,8 @@ def text_content(response, encoding='ISO-8859-1'): :rtype: Deferred that fires with a unicode string. """ - def _decode_content(c): + + def _decode_content(c: bytes) -> str: e = _encoding_from_headers(response.headers) @@ -132,5 +147,4 @@ def _decode_content(c): return c.decode(encoding) d = content(response) - d.addCallback(_decode_content) - return d + return cast("Deferred[str]", d.addCallback(_decode_content)) diff --git a/src/treq/multipart.py b/src/treq/multipart.py index 5309a95c..56becbfa 100644 --- a/src/treq/multipart.py +++ b/src/treq/multipart.py @@ -1,18 +1,31 @@ # Copyright (c) Twisted Matrix Laboratories. # See LICENSE for details. -from uuid import uuid4 -from io import BytesIO from contextlib import closing +from io import BytesIO +from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union, cast +from uuid import uuid4 -from twisted.internet import defer, task +from twisted.internet import task +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IConsumer +from twisted.python.failure import Failure from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer - +from typing_extensions import TypeAlias, Literal from zope.interface import implementer +from treq._types import _S, _FilesType, _FileValue + CRLF = b"\r\n" +_Consumer: TypeAlias = "Union[IConsumer, _LengthConsumer]" +_UnknownLength = Literal["'twisted.web.iweb.UNKNOWN_LENGTH'"] +_Length: TypeAlias = Union[int, _UnknownLength] +_FieldValue = Union[bytes, Tuple[str, str, IBodyProducer]] +_Field: TypeAlias = Tuple[str, _FieldValue] + + @implementer(IBodyProducer) class MultiPartProducer: """ @@ -46,22 +59,31 @@ class MultiPartProducer: schedule all reads. :ivar boundary: The generated boundary used in form-data encoding - :type boundary: `bytes` """ - def __init__(self, fields, boundary=None, cooperator=task): - self._fields = list(_sorted_by_type(_converted(fields))) - self._currentProducer = None + length: _Length + boundary: bytes + _currentProducer: Optional[IBodyProducer] = None + _task: Optional[task.CooperativeTask] = None + + def __init__( + self, + fields: _FilesType, + boundary: Optional[Union[str, bytes]] = None, + cooperator: task.Cooperator = cast(task.Cooperator, task), + ) -> None: + self._fields = _sorted_by_type(_converted(fields)) self._cooperate = cooperator.cooperate - self.boundary = boundary or uuid4().hex - - if isinstance(self.boundary, str): - self.boundary = self.boundary.encode('ascii') + if not boundary: + boundary = uuid4().hex.encode("ascii") + if isinstance(boundary, str): + boundary = boundary.encode("ascii") + self.boundary = boundary self.length = self._calculateLength() - def startProducing(self, consumer): + def startProducing(self, consumer: IConsumer) -> "Deferred[None]": """ Start a cooperative task which will read bytes from the input file and write them to `consumer`. Return a `Deferred` which fires after all @@ -69,29 +91,34 @@ def startProducing(self, consumer): :param consumer: Any `IConsumer` provider """ - self._task = self._cooperate(self._writeLoop(consumer)) - d = self._task.whenDone() + self._task = self._cooperate(self._writeLoop(consumer)) # type: ignore + # whenDone returns the iterator that was passed to cooperate, so who + # cares what type it has? It's an edge signal; we ignore its value. + d: "Deferred[Any]" = self._task.whenDone() - def maybeStopped(reason): + def maybeStopped(reason: Failure) -> "Deferred[None]": reason.trap(task.TaskStopped) - return defer.Deferred() - d.addCallbacks(lambda ignored: None, maybeStopped) + return Deferred() + + d = cast("Deferred[None]", d.addCallbacks(lambda ignored: None, maybeStopped)) return d - def stopProducing(self): + def stopProducing(self) -> None: """ Permanently stop writing bytes from the file to the consumer by stopping the underlying `CooperativeTask`. """ + assert self._task is not None if self._currentProducer: self._currentProducer.stopProducing() self._task.stop() - def pauseProducing(self): + def pauseProducing(self) -> None: """ Temporarily suspend copying bytes from the input file to the consumer by pausing the `CooperativeTask` which drives that activity. """ + assert self._task is not None if self._currentProducer: # Having a current producer means that we are in # the paused state because we've returned @@ -103,18 +130,19 @@ def pauseProducing(self): else: self._task.pause() - def resumeProducing(self): + def resumeProducing(self) -> None: """ Undo the effects of a previous `pauseProducing` and resume copying bytes to the consumer by resuming the `CooperativeTask` which drives the write activity. """ + assert self._task is not None if self._currentProducer: self._currentProducer.resumeProducing() else: self._task.resume() - def _calculateLength(self): + def _calculateLength(self) -> _Length: """ Determine how many bytes the overall form post would consume. The easiest way is to calculate is to generate of `fObj` @@ -126,7 +154,7 @@ def _calculateLength(self): pass return consumer.length - def _getBoundary(self, final=False): + def _getBoundary(self, final: bool = False) -> bytes: """ Returns a boundary line, either final (the one that ends the form data request or a regular, the one that separates the boundaries) @@ -136,7 +164,7 @@ def _getBoundary(self, final=False): f = b"--" if final else b"" return b"--" + self.boundary + f - def _writeLoop(self, consumer): + def _writeLoop(self, consumer: _Consumer) -> Iterable[Optional[Deferred]]: """ Return an iterator which generates the multipart/form-data request including the encoded objects @@ -158,30 +186,36 @@ def _writeLoop(self, consumer): # but with CRLF characters before it and after the line. # This is very important. # proper boundary is "CRLF--boundary-valueCRLF" - consumer.write( - (CRLF if index != 0 else b"") + self._getBoundary() + CRLF) + consumer.write((CRLF if index != 0 else b"") + self._getBoundary() + CRLF) yield self._writeField(name, value, consumer) consumer.write(CRLF + self._getBoundary(final=True) + CRLF) - def _writeField(self, name, value, consumer): - if isinstance(value, str): + def _writeField( + self, name: str, value: _FieldValue, consumer: _Consumer + ) -> Optional[Deferred]: + if isinstance(value, bytes): self._writeString(name, value, consumer) - elif isinstance(value, tuple): + return None + else: filename, content_type, producer = value - return self._writeFile( - name, filename, content_type, producer, consumer) + return self._writeFile(name, filename, content_type, producer, consumer) - def _writeString(self, name, value, consumer): + def _writeString(self, name: str, value: bytes, consumer: _Consumer) -> None: cdisp = _Header(b"Content-Disposition", b"form-data") cdisp.add_param(b"name", name) consumer.write(bytes(cdisp) + CRLF + CRLF) - - encoded = value.encode("utf-8") - consumer.write(encoded) + consumer.write(value) self._currentProducer = None - def _writeFile(self, name, filename, content_type, producer, consumer): + def _writeFile( + self, + name: str, + filename: str, + content_type: str, + producer: IBodyProducer, + consumer: _Consumer, + ) -> "Optional[Deferred[None]]": cdisp = _Header(b"Content-Disposition", b"form-data") cdisp.add_param(b"name", name) if filename: @@ -191,11 +225,13 @@ def _writeFile(self, name, filename, content_type, producer, consumer): consumer.write(bytes(_Header(b"Content-Type", content_type)) + CRLF) if producer.length != UNKNOWN_LENGTH: consumer.write( - bytes(_Header(b"Content-Length", producer.length)) + CRLF) + bytes(_Header(b"Content-Length", str(producer.length))) + CRLF + ) consumer.write(CRLF) if isinstance(consumer, _LengthConsumer): consumer.write(producer.length) + return None else: self._currentProducer = producer @@ -204,24 +240,21 @@ def unset(val): return val d = producer.startProducing(consumer) - d.addCallback(unset) - return d + return cast("Deferred[None]", d.addCallback(unset)) -def _escape(value): +def _escape(value: Union[str, bytes]) -> str: """ This function prevents header values from corrupting the request, a newline in the file name parameter makes form-data request unreadable for majority of parsers. """ - if not isinstance(value, (bytes, str)): - value = str(value) if isinstance(value, bytes): - value = value.decode('utf-8') - return value.replace(u"\r", u"").replace(u"\n", u"").replace(u'"', u'\\"') + value = value.decode("utf-8") + return value.replace("\r", "").replace("\n", "").replace('"', '\\"') -def _enforce_unicode(value): +def _enforce_unicode(value: Any) -> str: """ This function enforces the strings passed to be unicode, so we won't need to guess what's the encoding of the binary strings passed in. @@ -238,38 +271,49 @@ def _enforce_unicode(value): return value.decode("utf-8") except UnicodeDecodeError: raise ValueError( - "Supplied raw bytes that are not ascii/utf-8." - " When supplying raw string make sure it's ascii or utf-8" - ", or work with unicode if you are not sure") + "Supplied raw bytes that are not ASCII/UTF-8." + " When supplying raw string make sure it's ASCII or UTF-8" + ", or work with unicode if you are not sure" + ) else: - raise ValueError( - "Unsupported field type: %s" % (value.__class__.__name__,)) + raise ValueError("Unsupported field type: %s" % (value.__class__.__name__,)) -def _converted(fields): - if hasattr(fields, "iteritems"): - fields = fields.iteritems() - elif hasattr(fields, "items"): - fields = fields.items() +def _converted(fields: _FilesType) -> Iterable[_Field]: + """ + Convert any of the multitude of formats we accept for the *fields* + parameter into the form we work with internally. + """ + fields_: Iterable[Tuple[str, _FileValue]] + if hasattr(fields, "items"): + assert isinstance(fields, Mapping) + fields_ = fields.items() + else: + fields_ = fields - for name, value in fields: + for name, value in fields_: + # NOTE: While `name` is typed as `str` we still support UTF-8 `bytes` here + # for backward compatibility, thus this call to decode. name = _enforce_unicode(name) if isinstance(value, (tuple, list)): if len(value) != 3: - raise ValueError( - "Expected tuple: (filename, content type, producer)") + raise ValueError("Expected tuple: (filename, content type, producer)") filename, content_type, producer = value filename = _enforce_unicode(filename) if filename else None yield name, (filename, content_type, producer) - elif isinstance(value, (bytes, str)): - yield name, _enforce_unicode(value) + elif isinstance(value, str): + yield name, value.encode("utf-8") + + elif isinstance(value, bytes): + yield name, value else: raise ValueError( - "Unsupported value, expected string, unicode " - "or tuple (filename, content type, IBodyProducer)") + "Unsupported value, expected str, bytes, " + "or tuple (filename, content type, IBodyProducer)" + ) class _LengthConsumer: @@ -284,21 +328,25 @@ class _LengthConsumer: """ - def __init__(self): + length: _Length + + def __init__(self) -> None: self.length = 0 - def write(self, value): + def write(self, value: Union[bytes, _Length]) -> None: # this means that we have encountered # unknown length producer # so we need to stop attempts calculating - if self.length is UNKNOWN_LENGTH: + if self.length == UNKNOWN_LENGTH: return + assert isinstance(self.length, int) - if value is UNKNOWN_LENGTH: - self.length = value + if value == UNKNOWN_LENGTH: + self.length = cast(_UnknownLength, UNKNOWN_LENGTH) elif isinstance(value, int): self.length += value else: + assert isinstance(value, bytes) self.length += len(value) @@ -311,15 +359,21 @@ class because it encodes unicode fields using =? bla bla ?= that, everyone wants utf-8 raw bytes. """ - def __init__(self, name, value, params=None): + + def __init__( + self, + name: bytes, + value: _S, + params: Optional[List[Tuple[_S, _S]]] = None, + ): self.name = name self.value = value self.params = params or [] - def add_param(self, name, value): + def add_param(self, name: _S, value: _S) -> None: self.params.append((name, value)) - def __bytes__(self): + def __bytes__(self) -> bytes: with closing(BytesIO()) as h: h.write(self.name + b": " + _escape(self.value).encode("us-ascii")) if self.params: @@ -327,24 +381,23 @@ def __bytes__(self): h.write(b"; ") h.write(_escape(name).encode("us-ascii")) h.write(b"=") - h.write(b'"' + _escape(val).encode('utf-8') + b'"') + h.write(b'"' + _escape(val).encode("utf-8") + b'"') h.seek(0) return h.read() - def __str__(self): - return self.__bytes__() - -def _sorted_by_type(fields): +def _sorted_by_type(fields: Iterable[_Field]) -> List[_Field]: """Sorts params so that strings are placed before files. That makes a request more readable, as generally files are bigger. It also provides deterministic order of fields what is easier for testing. """ + def key(p): key, val = p if isinstance(val, (bytes, str)): return (0, key) else: return (1, key) + return sorted(fields, key=key) diff --git a/src/treq/py.typed b/src/treq/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/src/treq/response.py b/src/treq/response.py index d13c3edb..8b87b326 100644 --- a/src/treq/response.py +++ b/src/treq/response.py @@ -1,13 +1,12 @@ -from twisted.python.components import proxyForInterface -from twisted.web.iweb import IResponse, UNKNOWN_LENGTH -from twisted.python import reflect - from requests.cookies import cookiejar_from_dict +from twisted.python import reflect +from twisted.python.components import proxyForInterface +from twisted.web.iweb import UNKNOWN_LENGTH, IResponse from treq.content import collect, content, json_content, text_content -class _Response(proxyForInterface(IResponse)): +class _Response(proxyForInterface(IResponse)): # type: ignore """ A wrapper for :class:`twisted.web.iweb.IResponse` which manages cookies and adds a few convenience methods. @@ -23,14 +22,15 @@ def __repr__(self): status code, Content-Type header, and body size, if available. """ if self.original.length == UNKNOWN_LENGTH: - size = 'unknown size' + size = "unknown size" else: - size = '{:,d} bytes'.format(self.original.length) + size = "{:,d} bytes".format(self.original.length) # Display non-ascii bits of the content-type header as backslash # escapes. - content_type_bytes = b', '.join( - self.original.headers.getRawHeaders(b'content-type', ())) - content_type = repr(content_type_bytes).lstrip('b')[1:-1] + content_type_bytes = b", ".join( + self.original.headers.getRawHeaders(b"content-type", ()) + ) + content_type = repr(content_type_bytes).lstrip("b")[1:-1] return "<{} {} '{:.40s}' {}>".format( reflect.qual(self.__class__), self.original.code, @@ -71,7 +71,7 @@ def json(self, **kwargs): """ return json_content(self.original, **kwargs) - def text(self, encoding='ISO-8859-1'): + def text(self, encoding="ISO-8859-1"): """ Read the entire body all at once as text, per :func:`treq.text_content()`. @@ -93,8 +93,7 @@ def history(self): history = [] while response.previousResponse is not None: - history.append(_Response(response.previousResponse, - self._cookiejar)) + history.append(_Response(response.previousResponse, self._cookiejar)) response = response.previousResponse history.reverse() diff --git a/src/treq/test/local_httpbin/child.py b/src/treq/test/local_httpbin/child.py index 4a6914e9..f6cb2216 100644 --- a/src/treq/test/local_httpbin/child.py +++ b/src/treq/test/local_httpbin/child.py @@ -8,7 +8,7 @@ import datetime import sys -import httpbin +import httpbin # type: ignore from twisted.internet.defer import Deferred, inlineCallbacks from twisted.internet.endpoints import TCP4ServerEndpoint, SSL4ServerEndpoint @@ -16,7 +16,7 @@ from twisted.internet.ssl import (Certificate, CertificateOptions) -from OpenSSL.crypto import PKey, X509 +from OpenSSL.crypto import PKey, X509 # type: ignore from twisted.python.threadpool import ThreadPool from twisted.web.server import Site diff --git a/src/treq/test/local_httpbin/test/test_child.py b/src/treq/test/local_httpbin/test/test_child.py index 2a08d4e3..01a10fe8 100644 --- a/src/treq/test/local_httpbin/test/test_child.py +++ b/src/treq/test/local_httpbin/test/test_child.py @@ -37,7 +37,7 @@ class CertificatesForAuthorityAndServerTests(SynchronousTestCase): """ def setUp(self): - self.hostname = u".example.org" + self.hostname = ".example.org" ( self.ca_cert, self.server_private_key, @@ -54,7 +54,7 @@ def test_pkey_x509_paired(self): server_private_key = self.server_private_key.to_cryptography_key() server_x509_cert = self.server_x509_cert.to_cryptography() - plaintext = b'plaintext' + plaintext = b"plaintext" ciphertext = server_x509_cert.public_key().encrypt( plaintext, padding.PKCS1v15(), @@ -81,7 +81,7 @@ def test_ca_signed_x509(self): server_x509_cert.signature, server_x509_cert.tbs_certificate_bytes, padding.PKCS1v15(), - server_x509_cert.signature_hash_algorithm + server_x509_cert.signature_hash_algorithm, ) def test_x509_matches_hostname(self): @@ -99,6 +99,7 @@ class FakeThreadPoolState(object): """ State for :py:class:`FakeThreadPool`. """ + init_call_count = attr.ib(default=0) start_call_count = attr.ib(default=0) @@ -108,6 +109,7 @@ class FakeThreadPool(object): """ A fake :py:class:`twisted.python.threadpool.ThreadPool` """ + _state = attr.ib() def init(self): @@ -149,8 +151,8 @@ def test_threadpool_management(self): self.assertEqual(self.fake_threadpool_state.init_call_count, 1) self.assertEqual(self.fake_threadpool_state.start_call_count, 1) - self.assertEqual(len(self.reactor.triggers['before']['shutdown']), 1) - [(stop, _, _)] = self.reactor.triggers['before']['shutdown'] + self.assertEqual(len(self.reactor.triggers["before"]["shutdown"]), 1) + [(stop, _, _)] = self.reactor.triggers["before"]["shutdown"] self.assertEqual(stop, self.fake_threadpool.stop) @@ -170,7 +172,7 @@ def test_tls_listener_matches_description(self): and the host, port, and CA certificate are returned in its description. """ - expected_host = 'host' + expected_host = "host" expected_port = 123 description_deferred = child._serve_tls( @@ -182,9 +184,7 @@ def test_tls_listener_matches_description(self): self.assertEqual(len(self.reactor.sslServers), 1) - [ - (actual_port, actual_site, _, _, actual_host) - ] = self.reactor.sslServers + [(actual_port, actual_site, _, _, actual_host)] = self.reactor.sslServers self.assertEqual(actual_host, expected_host) self.assertEqual(actual_port, expected_port) @@ -211,7 +211,7 @@ def test_tcp_listener_matches_description(self): A TCP listeneris established on the request host and port, and the host and port are returned in its description. """ - expected_host = 'host' + expected_host = "host" expected_port = 123 description_deferred = child._serve_tcp( @@ -223,9 +223,7 @@ def test_tcp_listener_matches_description(self): self.assertEqual(len(self.reactor.tcpServers), 1) - [ - (actual_port, actual_site, _, actual_host) - ] = self.reactor.tcpServers + [(actual_port, actual_site, _, actual_host)] = self.reactor.tcpServers self.assertEqual(actual_host, expected_host) self.assertEqual(actual_port, expected_port) @@ -243,6 +241,7 @@ class FlushableBytesIOState(object): """ State for :py:class:`FlushableBytesIO` """ + bio = attr.ib(default=attr.Factory(io.BytesIO)) flush_count = attr.ib(default=0) @@ -252,6 +251,7 @@ class FlushableBytesIO(object): """ A :py:class:`io.BytesIO` wrapper that records flushes. """ + _state = attr.ib() def write(self, data): @@ -267,6 +267,7 @@ class BufferedStandardOut(object): A standard out that whose ``buffer`` is a :py:class:`FlushableBytesIO` instance. """ + buffer = attr.ib() @@ -284,9 +285,7 @@ def test_description_written(self): An :py:class:`shared._HTTPBinDescription` is written to standard out and the line flushed. """ - description = shared._HTTPBinDescription(host="host", - port=123, - cacert="cacert") + description = shared._HTTPBinDescription(host="host", port=123, cacert="cacert") child._output_process_description(description, self.stdout) @@ -294,7 +293,7 @@ def test_description_written(self): self.assertEqual( written, - b'{"cacert": "cacert", "host": "host", "port": 123}' + b'\n', + b'{"cacert": "cacert", "host": "host", "port": 123}' + b"\n", ) self.assertEqual(self.stdout_state.flush_count, 1) @@ -354,17 +353,14 @@ def output_process_description(self, description, *args, **kwargs): self.output_process_description_calls.append(description) return self.output_process_description_returns - def assertDescriptionAndDeferred(self, - description_deferred, - forever_deferred): + def assertDescriptionAndDeferred(self, description_deferred, forever_deferred): """ Assert that firing ``description_deferred`` outputs the description but that ``forever_deferred`` never fires. """ description_deferred.callback("description") - self.assertEqual(self.output_process_description_calls, - ["description"]) + self.assertEqual(self.output_process_description_calls, ["description"]) self.assertNoResult(forever_deferred) @@ -378,9 +374,7 @@ def test_default_arguments(self): self.assertEqual( self.serve_tcp_calls, - [ - (self.reactor, 'localhost', 0, self.make_httpbin_site_returns) - ] + [(self.reactor, "localhost", 0, self.make_httpbin_site_returns)], ) self.assertDescriptionAndDeferred( @@ -393,13 +387,11 @@ def test_https(self): The ``--https`` command line argument serves ``httpbin`` over HTTPS, returning a :py:class:`Deferred` that never fires. """ - deferred = self.forever_httpbin(self.reactor, ['--https']) + deferred = self.forever_httpbin(self.reactor, ["--https"]) self.assertEqual( self.serve_tls_calls, - [ - (self.reactor, 'localhost', 0, self.make_httpbin_site_returns) - ] + [(self.reactor, "localhost", 0, self.make_httpbin_site_returns)], ) self.assertDescriptionAndDeferred( @@ -413,19 +405,18 @@ def test_host(self): provided host, returning a :py:class:`Deferred` that never fires. """ - deferred = self.forever_httpbin(self.reactor, - ['--host', 'example.org']) + deferred = self.forever_httpbin(self.reactor, ["--host", "example.org"]) self.assertEqual( self.serve_tcp_calls, [ ( self.reactor, - 'example.org', + "example.org", 0, self.make_httpbin_site_returns, ) - ] + ], ) self.assertDescriptionAndDeferred( @@ -439,13 +430,11 @@ def test_port(self): the provided port, returning a :py:class:`Deferred` that never fires. """ - deferred = self.forever_httpbin(self.reactor, ['--port', '91']) + deferred = self.forever_httpbin(self.reactor, ["--port", "91"]) self.assertEqual( self.serve_tcp_calls, - [ - (self.reactor, 'localhost', 91, self.make_httpbin_site_returns) - ] + [(self.reactor, "localhost", 91, self.make_httpbin_site_returns)], ) self.assertDescriptionAndDeferred( diff --git a/src/treq/test/test_agentspy.py b/src/treq/test/test_agentspy.py index 6103f6ae..4450dd89 100644 --- a/src/treq/test/test_agentspy.py +++ b/src/treq/test/test_agentspy.py @@ -15,7 +15,7 @@ class APISpyTests(SynchronousTestCase): The agent_spy API provides an agent that records each request made to it. """ - def test_provides_iagent(self): + def test_provides_iagent(self) -> None: """ The agent returned by agent_spy() provides the IAgent interface. """ @@ -23,7 +23,7 @@ def test_provides_iagent(self): self.assertTrue(IAgent.providedBy(agent)) - def test_records(self): + def test_records(self) -> None: """ Each request made with the agent is recorded. """ @@ -43,7 +43,7 @@ def test_records(self): ], ) - def test_record_attributes(self): + def test_record_attributes(self) -> None: """ Each parameter passed to `request` is available as an attribute of the RequestRecord. Additionally, the deferred returned by the call is @@ -62,15 +62,15 @@ def test_record_attributes(self): self.assertIs(rr.bodyProducer, body) self.assertIs(rr.deferred, deferred) - def test_type_validation(self): + def test_type_validation(self) -> None: """ The request method enforces correctness by raising TypeError when passed parameters of the wrong type. """ agent, _ = agent_spy() - self.assertRaises(TypeError, agent.request, u"method not bytes", b"uri") - self.assertRaises(TypeError, agent.request, b"method", u"uri not bytes") + self.assertRaises(TypeError, agent.request, "method not bytes", b"uri") + self.assertRaises(TypeError, agent.request, b"method", "uri not bytes") self.assertRaises( TypeError, agent.request, b"method", b"uri", {"not": "headers"} ) diff --git a/src/treq/test/test_api.py b/src/treq/test/test_api.py index 2d9a1a17..e11531cf 100644 --- a/src/treq/test/test_api.py +++ b/src/treq/test/test_api.py @@ -1,13 +1,14 @@ from __future__ import absolute_import, division -from twisted.web.iweb import IAgent -from twisted.web.client import HTTPConnectionPool -from twisted.trial.unittest import TestCase from twisted.internet import defer +from twisted.trial.unittest import TestCase +from twisted.web.client import HTTPConnectionPool +from twisted.web.iweb import IAgent from zope.interface import implementer import treq -from treq.api import default_reactor, default_pool, set_global_pool, get_global_pool +from treq.api import (default_pool, default_reactor, get_global_pool, + set_global_pool) try: from twisted.internet.testing import MemoryReactorClock @@ -32,7 +33,7 @@ def getConnection(self, key, endpoint): class TreqAPITests(TestCase): - def test_default_pool(self): + def test_default_pool(self) -> None: """ The module-level API uses the global connection pool by default. """ @@ -44,7 +45,7 @@ def test_default_pool(self): self.assertEqual(pool.requests, 1) self.failureResultOf(d, TabError) - def test_cached_pool(self): + def test_cached_pool(self) -> None: """ The first use of the module-level API populates the global connection pool, which is used for all subsequent requests. @@ -61,7 +62,7 @@ def test_cached_pool(self): self.assertEqual(pool.requests, 6) - def test_custom_pool(self): + def test_custom_pool(self) -> None: """ `treq.post()` accepts a *pool* argument to use for the request. The global pool is unaffected. @@ -74,7 +75,7 @@ def test_custom_pool(self): self.failureResultOf(d, TabError) self.assertIsNot(pool, get_global_pool()) - def test_custom_agent(self): + def test_custom_agent(self) -> None: """ A custom Agent is used if specified. """ @@ -93,13 +94,10 @@ def request(self, method, uri, headers=None, bodyProducer=None): self.assertNoResult(d) self.assertEqual(1, custom_agent.requests) - def test_request_invalid_param(self): + def test_request_invalid_param(self) -> None: """ - `treq.request()` warns that it ignores unknown keyword arguments, but - this is deprecated. - - This test verifies that stacklevel is set appropriately when issuing - the warning. + `treq.request()` raises `TypeError` when it receives unknown keyword + arguments. """ with self.assertRaises(TypeError) as c: treq.request( @@ -111,30 +109,22 @@ def test_request_invalid_param(self): self.assertIn("invalid", str(c.exception)) - def test_post_json_with_data(self): + def test_post_json_with_data(self) -> None: """ - `treq.post()` warns that mixing *data* and *json* is deprecated. - - This test verifies that stacklevel is set appropriately when issuing - the warning. + `treq.post()` raises TypeError when the *data* and *json* arguments + are mixed. """ - self.failureResultOf( + with self.assertRaises(TypeError) as c: treq.post( "https://test.example/", data={"hello": "world"}, json={"goodnight": "moon"}, pool=SyntacticAbominationHTTPConnectionPool(), ) - ) - [w] = self.flushWarnings([self.test_post_json_with_data]) - self.assertEqual(DeprecationWarning, w["category"]) self.assertEqual( - ( - "Argument 'json' will be ignored because 'data' was also passed." - " This will raise TypeError in the next treq release." - ), - w["message"], + "Argument 'json' cannot be combined with 'data'.", + str(c.exception), ) @@ -143,7 +133,7 @@ class DefaultReactorTests(TestCase): Test `treq.api.default_reactor()` """ - def test_passes_reactor(self): + def test_passes_reactor(self) -> None: """ `default_reactor()` returns any reactor passed. """ @@ -151,7 +141,7 @@ def test_passes_reactor(self): self.assertIs(default_reactor(reactor), reactor) - def test_uses_default_reactor(self): + def test_uses_default_reactor(self) -> None: """ `default_reactor()` returns the global reactor when passed ``None``. """ @@ -165,11 +155,11 @@ class DefaultPoolTests(TestCase): Test `treq.api.default_pool`. """ - def setUp(self): + def setUp(self) -> None: set_global_pool(None) self.reactor = MemoryReactorClock() - def test_persistent_false(self): + def test_persistent_false(self) -> None: """ When *persistent=False* is passed a non-persistent pool is created. """ @@ -178,7 +168,7 @@ def test_persistent_false(self): self.assertTrue(isinstance(pool, HTTPConnectionPool)) self.assertFalse(pool.persistent) - def test_persistent_false_not_stored(self): + def test_persistent_false_not_stored(self) -> None: """ When *persistent=False* is passed the resulting pool is not stored as the global pool. @@ -187,7 +177,7 @@ def test_persistent_false_not_stored(self): self.assertIsNot(pool, get_global_pool()) - def test_persistent_false_new(self): + def test_persistent_false_new(self) -> None: """ When *persistent=False* is passed a new pool is returned each time. """ @@ -196,7 +186,7 @@ def test_persistent_false_new(self): self.assertIsNot(pool1, pool2) - def test_pool_none_persistent_none(self): + def test_pool_none_persistent_none(self) -> None: """ When *persistent=None* is passed a _persistent_ pool is created for backwards compatibility. @@ -205,7 +195,7 @@ def test_pool_none_persistent_none(self): self.assertTrue(pool.persistent) - def test_pool_none_persistent_true(self): + def test_pool_none_persistent_true(self) -> None: """ When *persistent=True* is passed a persistent pool is created and stored as the global pool. @@ -215,7 +205,7 @@ def test_pool_none_persistent_true(self): self.assertTrue(isinstance(pool, HTTPConnectionPool)) self.assertTrue(pool.persistent) - def test_cached_global_pool(self): + def test_cached_global_pool(self) -> None: """ When *persistent=True* or *persistent=None* is passed the pool created is cached as the global pool. @@ -225,7 +215,7 @@ def test_cached_global_pool(self): self.assertEqual(pool1, pool2) - def test_specified_pool(self): + def test_specified_pool(self) -> None: """ When the user passes a pool it is returned directly. The *persistent* argument is ignored. It is not cached as the global pool. diff --git a/src/treq/test/test_client.py b/src/treq/test/test_client.py index 52897ce6..0a650042 100644 --- a/src/treq/test/test_client.py +++ b/src/treq/test/test_client.py @@ -541,46 +541,28 @@ def test_request_unsupported_params_combination(self): def test_request_json_with_data(self): """ Passing `HTTPClient.request()` both *data* and *json* parameters is - invalid because *json* is ignored. This behavior is deprecated. + invalid because they conflict, producing `TypeError`. """ - self.client.request( - "POST", - "http://example.com/", - data=BytesIO(b"..."), - json=None, # NB: None is a valid value. It encodes to b'null'. - ) - - [w] = self.flushWarnings([self.test_request_json_with_data]) - self.assertEqual(DeprecationWarning, w["category"]) - self.assertEqual( - ( - "Argument 'json' will be ignored because 'data' was also passed." - " This will raise TypeError in the next treq release." - ), - w['message'], - ) + with self.assertRaises(TypeError): + self.client.request( + "POST", + "http://example.com/", + data=BytesIO(b"..."), + json=None, # NB: None is a valid value. It encodes to b'null'. + ) def test_request_json_with_files(self): """ Passing `HTTPClient.request()` both *files* and *json* parameters is - invalid because *json* is ignored. This behavior is deprecated. + invalid because they confict, producing `TypeError`. """ - self.client.request( - "POST", - "http://example.com/", - files={"f1": ("foo.txt", "text/plain", BytesIO(b"...\n"))}, - json=["this is ignored"], - ) - - [w] = self.flushWarnings([self.test_request_json_with_files]) - self.assertEqual(DeprecationWarning, w["category"]) - self.assertEqual( - ( - "Argument 'json' will be ignored because 'files' was also passed." - " This will raise TypeError in the next treq release." - ), - w['message'], - ) + with self.assertRaises(TypeError): + self.client.request( + "POST", + "http://example.com/", + files={"f1": ("foo.txt", "text/plain", BytesIO(b"...\n"))}, + json=["this is ignored"], + ) def test_request_dict_headers(self): self.client.request('GET', 'http://example.com/', headers={ @@ -621,37 +603,20 @@ def test_request_headers_invalid_type(self): `HTTPClient.request()` warns that headers of an unexpected type are invalid and that this behavior is deprecated. """ - self.client.request('GET', 'http://example.com', headers=[]) - - [w] = self.flushWarnings([self.test_request_headers_invalid_type]) - self.assertEqual(DeprecationWarning, w['category']) - self.assertIn( - "headers must be a dict, twisted.web.http_headers.Headers, or None,", - w['message'], - ) + with self.assertRaises(TypeError): + self.client.request('GET', 'http://example.com', headers=[]) def test_request_dict_headers_invalid_values(self): """ `HTTPClient.request()` warns that non-string header values are dropped and that this behavior is deprecated. """ - self.client.request('GET', 'http://example.com', headers=OrderedDict([ - ('none', None), - ('one', 1), - ('ok', 'string'), - ])) - - [w1, w2] = self.flushWarnings([self.test_request_dict_headers_invalid_values]) - self.assertEqual(DeprecationWarning, w1['category']) - self.assertEqual(DeprecationWarning, w2['category']) - self.assertIn( - "The value of headers key 'none' has non-string type", - w1['message'], - ) - self.assertIn( - "The value of headers key 'one' has non-string type", - w2['message'], - ) + with self.assertRaises(TypeError): + self.client.request('GET', 'http://example.com', headers=OrderedDict([ + ('none', None), + ('one', 1), + ('ok', 'string'), + ])) def test_request_invalid_param(self): """ diff --git a/src/treq/test/test_multipart.py b/src/treq/test/test_multipart.py index 7736fbd9..5cbaade5 100644 --- a/src/treq/test/test_multipart.py +++ b/src/treq/test/test_multipart.py @@ -3,6 +3,7 @@ import cgi import sys +from typing import cast, AnyStr from io import BytesIO @@ -10,6 +11,7 @@ from zope.interface.verify import verifyObject from twisted.internet import task +from twisted.internet.testing import StringTransport from twisted.web.client import FileBodyProducer from twisted.web.iweb import UNKNOWN_LENGTH, IBodyProducer @@ -57,7 +59,7 @@ def getOutput(self, producer, with_producer=False): else: return output.getvalue() - def newLines(self, value): + def newLines(self, value: AnyStr) -> AnyStr: if isinstance(value, str): return value.replace(u"\n", u"\r\n") @@ -72,7 +74,7 @@ def test_interface(self): verifyObject( IBodyProducer, MultiPartProducer({}))) - def test_unknownLength(self): + def test_unknownLength(self) -> None: """ If the L{MultiPartProducer} is constructed with a file-like object passed as a parameter without either a C{seek} or C{tell} method, @@ -93,14 +95,14 @@ def tell(self): """ producer = MultiPartProducer( - {"f": ("name", None, FileBodyProducer(CantTell()))}) + {"f": ("name", "application/octet-stream", FileBodyProducer(CantTell()))}) self.assertEqual(UNKNOWN_LENGTH, producer.length) producer = MultiPartProducer( - {"f": ("name", None, FileBodyProducer(CantSeek()))}) + {"f": ("name", "application/octet-stream", FileBodyProducer(CantSeek()))}) self.assertEqual(UNKNOWN_LENGTH, producer.length) - def test_knownLengthOnFile(self): + def test_knownLengthOnFile(self) -> None: """ If the L{MultiPartProducer} is constructed with a file-like object with both C{seek} and C{tell} methods, its C{length} attribute is set to the @@ -110,7 +112,7 @@ def test_knownLengthOnFile(self): inputFile = BytesIO(inputBytes) inputFile.seek(5) producer = MultiPartProducer({ - "field": ('file name', None, FileBodyProducer( + "field": ('file name', "application/octet-stream", FileBodyProducer( inputFile, cooperator=self.cooperator))}) # Make sure we are generous enough not to alter seek position: @@ -119,33 +121,39 @@ def test_knownLengthOnFile(self): # Total length is hard to calculate manually # as it contains a lot of headers parameters, newlines and boundaries # let's assert for now that it's no less than the input parameter - self.assertTrue(producer.length > len(inputBytes)) + self.assertNotEqual(producer.length, UNKNOWN_LENGTH) + self.assertTrue(cast(int, producer.length) > len(inputBytes)) # Calculating length should not touch producers self.assertTrue(producer._currentProducer is None) - def test_defaultCooperator(self): + def test_defaultCooperator(self) -> None: """ If no L{Cooperator} instance is passed to L{MultiPartProducer}, the global cooperator is used. """ producer = MultiPartProducer({ - "field": ('file name', None, FileBodyProducer( + "field": ("file name", "application/octet-stream", FileBodyProducer( BytesIO(b"yo"), cooperator=self.cooperator)) }) self.assertEqual(task.cooperate, producer._cooperate) - def test_startProducing(self): + def test_startProducing(self) -> None: """ L{MultiPartProducer.startProducing} starts writing bytes from the input file to the given L{IConsumer} and returns a L{Deferred} which fires when they have all been written. """ - consumer = output = BytesIO() + consumer = output = StringTransport() + + # We historically accepted bytes for field names and continue to allow + # it for compatibility, but the types don't permit it because it makes + # them even more complicated and awful. So here we verify that that works. + field = cast(str, b"field") producer = MultiPartProducer({ - b"field": ('file name', "text/hello-world", FileBodyProducer( + field: ("file name", "text/hello-world", FileBodyProducer( BytesIO(b"Hello, World"), cooperator=self.cooperator)) }, cooperator=self.cooperator, boundary=b"heyDavid") @@ -165,16 +173,16 @@ def test_startProducing(self): Hello, World --heyDavid-- -"""), output.getvalue()) +"""), output.value()) self.assertEqual(None, self.successResultOf(complete)) - def test_inputClosedAtEOF(self): + def test_inputClosedAtEOF(self) -> None: """ When L{MultiPartProducer} reaches end-of-file on the input file given to it, the input file is closed. """ inputFile = BytesIO(b"hello, world!") - consumer = BytesIO() + consumer = StringTransport() producer = MultiPartProducer({ "field": ( @@ -192,7 +200,7 @@ def test_inputClosedAtEOF(self): self.assertTrue(inputFile.closed) - def test_failedReadWhileProducing(self): + def test_failedReadWhileProducing(self) -> None: """ If a read from the input file fails while producing bytes to the consumer, the L{Deferred} returned by @@ -212,7 +220,7 @@ def read(self, count): cooperator=self.cooperator)) }, cooperator=self.cooperator, boundary=b"heyDavid") - complete = producer.startProducing(BytesIO()) + complete = producer.startProducing(StringTransport()) while self._scheduled: self._scheduled.pop(0)() @@ -244,13 +252,13 @@ def test_stopProducing(self): self._scheduled.pop(0)() self.assertNoResult(complete) - def test_pauseProducing(self): + def test_pauseProducing(self) -> None: """ L{MultiPartProducer.pauseProducing} temporarily suspends writing bytes from the input file to the given L{IConsumer}. """ inputFile = BytesIO(b"hello, world!") - consumer = output = BytesIO() + consumer = output = StringTransport() producer = MultiPartProducer({ "field": ( @@ -263,7 +271,7 @@ def test_pauseProducing(self): complete = producer.startProducing(consumer) self._scheduled.pop(0)() - currentValue = output.getvalue() + currentValue = output.value() self.assertTrue(currentValue) producer.pauseProducing() @@ -274,17 +282,17 @@ def test_pauseProducing(self): self._scheduled.pop(0)() # Since the producer is paused, no new data should be here. - self.assertEqual(output.getvalue(), currentValue) + self.assertEqual(output.value(), currentValue) self.assertNoResult(complete) - def test_resumeProducing(self): + def test_resumeProducing(self) -> None: """ L{MultoPartProducer.resumeProducing} re-commences writing bytes from the input file to the given L{IConsumer} after it was previously paused with L{MultiPartProducer.pauseProducing}. """ inputFile = BytesIO(b"hello, world!") - consumer = output = BytesIO() + consumer = output = StringTransport() producer = MultiPartProducer({ "field": ( @@ -297,15 +305,15 @@ def test_resumeProducing(self): producer.startProducing(consumer) self._scheduled.pop(0)() - currentValue = output.getvalue() + currentValue = output.value() self.assertTrue(currentValue) producer.pauseProducing() producer.resumeProducing() self._scheduled.pop(0)() # make sure we started producing new data after resume - self.assertTrue(len(currentValue) < len(output.getvalue())) + self.assertTrue(len(currentValue) < len(output.value())) - def test_unicodeString(self): + def test_unicodeString(self) -> None: """ Make sure unicode string is passed properly """ @@ -325,19 +333,28 @@ def test_unicodeString(self): self.assertEqual(producer.length, len(expected)) self.assertEqual(expected, output) - def test_failOnByteStrings(self): + def test_bytesPassThrough(self) -> None: """ - If byte string is passed as a param and we don't know - the encoding, fail early to prevent corrupted form posts + If byte string is passed as a param it is passed through + unchanged. """ - self.assertRaises( - ValueError, - MultiPartProducer, { - "afield": u"это моя строчечка".encode("utf-32"), - }, - cooperator=self.cooperator, boundary=b"heyDavid") + output, producer = self.getOutput( + MultiPartProducer({ + "bfield": b'\x00\x01\x02\x03', + }, cooperator=self.cooperator, boundary=b"heyDavid"), + with_producer=True) + + expected = ( + b"--heyDavid\r\n" + b'Content-Disposition: form-data; name="bfield"\r\n' + b'\r\n' + b'\x00\x01\x02\x03\r\n' + b'--heyDavid--\r\n' + ) + self.assertEqual(producer.length, len(expected)) + self.assertEqual(expected, output) - def test_failOnUnknownParams(self): + def test_failOnUnknownParams(self) -> None: """ If byte string is passed as a param and we don't know the encoding, fail early to prevent corrupted form posts @@ -366,7 +383,7 @@ def test_failOnUnknownParams(self): }, cooperator=self.cooperator, boundary=b"heyDavid") - def test_twoFields(self): + def test_twoFields(self) -> None: """ Make sure multiple fields are rendered properly. """ diff --git a/src/treq/testing.py b/src/treq/testing.py index 33d8c94d..627b29b6 100644 --- a/src/treq/testing.py +++ b/src/treq/testing.py @@ -206,6 +206,15 @@ def startProducing(self, consumer): consumer.write(self.body) return succeed(None) + def stopProducing(self): + raise NotImplementedError() + + def pauseProducing(self): + raise NotImplementedError() + + def resumeProducing(self): + raise NotImplementedError() + def _reject_files(f): """ diff --git a/tox.ini b/tox.ini index 7d5eb933..b22c5db0 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,18 @@ commands = {envbindir}/coverage run -p \ {envbindir}/trial {posargs:treq} +[testenv:mypy] +basepython = python3.8 +deps = + mypy==1.0.1 + mypy-zope==0.9.1 + types-requests +commands = + mypy \ + --cache-dir="{toxworkdir}/mypy_cache" \ + {tty:--pretty:} \ + {posargs:src} + [testenv:flake8] python = python3.11 skip_install = True