From c87c7834b77583efd5963dfddb08681206701d33 Mon Sep 17 00:00:00 2001 From: Thomas Grainger Date: Fri, 30 Jul 2021 12:56:48 +0100 Subject: [PATCH] test under py3.10 (#1070) Co-authored-by: Marcelo Trylesinski --- .github/workflows/test-suite.yml | 2 +- docs/deployment.md | 2 +- docs/index.md | 2 +- requirements.txt | 6 ++-- tests/conftest.py | 26 ++++++++++---- tests/protocols/test_http.py | 23 ++++++++++--- tests/test_auto_detection.py | 10 +++--- tests/test_config.py | 4 +-- tests/test_ssl.py | 34 ++++++++++++++----- uvicorn/config.py | 3 +- uvicorn/loops/asyncio.py | 16 ++------- .../protocols/websockets/websockets_impl.py | 31 +++++++++++++---- uvicorn/server.py | 5 +-- 13 files changed, 110 insertions(+), 54 deletions(-) diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index b8a82b639..ebae2f071 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -13,7 +13,7 @@ jobs: runs-on: "${{ matrix.os }}" strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.6", "3.7", "3.8", "3.9", "3.10.0-beta.4"] os: [windows-latest, ubuntu-latest] steps: - uses: "actions/checkout@v2" diff --git a/docs/deployment.md b/docs/deployment.md index a251f4a1c..819a93260 100644 --- a/docs/deployment.md +++ b/docs/deployment.md @@ -92,7 +92,7 @@ Options: --ssl-certfile TEXT SSL certificate file --ssl-keyfile-password TEXT SSL keyfile password --ssl-version INTEGER SSL version to use (see stdlib ssl module's) - [default: 2] + [default: 17] --ssl-cert-reqs INTEGER Whether client certificate is required (see stdlib ssl module's) [default: 0] --ssl-ca-certs TEXT CA certificates file diff --git a/docs/index.md b/docs/index.md index 957882c72..1af05b97f 100644 --- a/docs/index.md +++ b/docs/index.md @@ -165,7 +165,7 @@ Options: --ssl-certfile TEXT SSL certificate file --ssl-keyfile-password TEXT SSL keyfile password --ssl-version INTEGER SSL version to use (see stdlib ssl module's) - [default: 2] + [default: 17] --ssl-cert-reqs INTEGER Whether client certificate is required (see stdlib ssl module's) [default: 0] --ssl-ca-certs TEXT CA certificates file diff --git a/requirements.txt b/requirements.txt index 77c9e2d49..9cb3205be 100644 --- a/requirements.txt +++ b/requirements.txt @@ -20,12 +20,14 @@ types-pyyaml trustme cryptography coverage -httpx==0.16.* +httpx>=0.18.2 pytest-asyncio==0.14.* async_generator; python_version < '3.7' # Documentation -mkdocs +mkdocs>=1.2.2 mkdocs-material +# py3.10 workarounds +https://github.com/aaugustin/websockets/archive/4e1dac362a3639b9cf0e5bcf382601e6e32cfede.tar.gz#egg=websockets; python_version >= '3.10' diff --git a/tests/conftest.py b/tests/conftest.py index a5e5a8e12..dc99d8b63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -13,7 +13,7 @@ def tls_certificate_authority() -> trustme.CA: @pytest.fixture def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert: - return tls_certificate_authority.issue_server_cert( + return tls_certificate_authority.issue_cert( "localhost", "127.0.0.1", "::1", @@ -33,9 +33,9 @@ def tls_ca_certificate_private_key_path(tls_certificate_authority: trustme.CA): @pytest.fixture -def tls_ca_certificate_private_key_encrypted_path(tls_certificate_authority): +def tls_certificate_private_key_encrypted_path(tls_certificate): private_key = serialization.load_pem_private_key( - tls_certificate_authority.private_key_pem.bytes(), + tls_certificate.private_key_pem.bytes(), password=None, backend=default_backend(), ) @@ -49,13 +49,25 @@ def tls_ca_certificate_private_key_encrypted_path(tls_certificate_authority): @pytest.fixture -def tls_certificate_pem_path(tls_certificate: trustme.LeafCert): +def tls_certificate_private_key_path(tls_certificate: trustme.CA): + with tls_certificate.private_key_pem.tempfile() as private_key: + yield private_key + + +@pytest.fixture +def tls_certificate_key_and_chain_path(tls_certificate: trustme.LeafCert): with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem: yield cert_pem @pytest.fixture -def tls_ca_ssl_context(tls_certificate: trustme.LeafCert) -> ssl.SSLContext: - ssl_ctx = ssl.SSLContext() - tls_certificate.configure_cert(ssl_ctx) +def tls_certificate_server_cert_path(tls_certificate: trustme.LeafCert): + with tls_certificate.cert_chain_pems[0].tempfile() as cert_pem: + yield cert_pem + + +@pytest.fixture +def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext: + ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH) + tls_certificate_authority.configure_trust(ssl_ctx) return ssl_ctx diff --git a/tests/protocols/test_http.py b/tests/protocols/test_http.py index 577f70edb..826f05fdc 100644 --- a/tests/protocols/test_http.py +++ b/tests/protocols/test_http.py @@ -1,3 +1,4 @@ +import asyncio import contextlib import logging @@ -123,12 +124,15 @@ def set_protocol(self, protocol): pass -class MockLoop: +class MockLoop(asyncio.AbstractEventLoop): def __init__(self, event_loop): self.tasks = [] self.later = [] self.loop = event_loop + def is_running(self): + return True + def create_task(self, coroutine): self.tasks.insert(0, coroutine) return MockTask() @@ -138,7 +142,14 @@ def call_later(self, delay, callback, *args): def run_one(self): coroutine = self.tasks.pop() - self.loop.run_until_complete(coroutine) + self.run_until_complete(coroutine) + + def run_until_complete(self, coroutine): + asyncio._set_running_loop(None) + try: + return self.loop.run_until_complete(coroutine) + finally: + asyncio._set_running_loop(self) def close(self): self.loop.close() @@ -161,13 +172,17 @@ def add_done_callback(self, callback): @contextlib.contextmanager def get_connected_protocol(app, protocol_cls, event_loop, **kwargs): loop = MockLoop(event_loop) + asyncio._set_running_loop(loop) transport = MockTransport() config = Config(app=app, **kwargs) server_state = ServerState() protocol = protocol_cls(config=config, server_state=server_state, _loop=loop) protocol.connection_made(transport) - yield protocol - protocol.loop.close() + try: + yield protocol + finally: + protocol.loop.close() + asyncio._set_running_loop(None) @pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS) diff --git a/tests/test_auto_detection.py b/tests/test_auto_detection.py index 93b84e4aa..f23cb0161 100644 --- a/tests/test_auto_detection.py +++ b/tests/test_auto_detection.py @@ -1,5 +1,7 @@ import asyncio +import pytest + from uvicorn.config import Config from uvicorn.loops.auto import auto_loop_setup from uvicorn.main import ServerState @@ -37,11 +39,10 @@ def test_loop_auto(): assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy) expected_loop = "asyncio" if uvloop is None else "uvloop" assert type(policy).__module__.startswith(expected_loop) - loop = asyncio.get_event_loop() - loop.close() -def test_http_auto(): +@pytest.mark.asyncio +async def test_http_auto(): config = Config(app=app) server_state = ServerState() protocol = AutoHTTPProtocol(config=config, server_state=server_state) @@ -49,7 +50,8 @@ def test_http_auto(): assert type(protocol).__name__ == expected_http -def test_websocket_auto(): +@pytest.mark.asyncio +async def test_websocket_auto(): config = Config(app=app) server_state = ServerState() protocol = AutoWebSocketsProtocol(config=config, server_state=server_state) diff --git a/tests/test_config.py b/tests/test_config.py index 1e26c4cc5..9efbd6c47 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -175,10 +175,10 @@ def test_ssl_config( assert config.is_ssl is True -def test_ssl_config_combined(tls_certificate_pem_path: str) -> None: +def test_ssl_config_combined(tls_certificate_key_and_chain_path: str) -> None: config = Config( app=asgi_app, - ssl_certfile=tls_certificate_pem_path, + ssl_certfile=tls_certificate_key_and_chain_path, ) config.load() diff --git a/tests/test_ssl.py b/tests/test_ssl.py index d44e5de49..346ffe76c 100644 --- a/tests/test_ssl.py +++ b/tests/test_ssl.py @@ -13,14 +13,17 @@ async def app(scope, receive, send): @pytest.mark.asyncio async def test_run( - tls_ca_ssl_context, tls_ca_certificate_pem_path, tls_ca_certificate_private_key_path + tls_ca_ssl_context, + tls_certificate_server_cert_path, + tls_certificate_private_key_path, + tls_ca_certificate_pem_path, ): config = Config( app=app, loop="asyncio", limit_max_requests=1, - ssl_keyfile=tls_ca_certificate_private_key_path, - ssl_certfile=tls_ca_certificate_pem_path, + ssl_keyfile=tls_certificate_private_key_path, + ssl_certfile=tls_certificate_server_cert_path, ssl_ca_certs=tls_ca_certificate_pem_path, ) async with run_server(config): @@ -31,13 +34,13 @@ async def test_run( @pytest.mark.asyncio async def test_run_chain( - tls_ca_ssl_context, tls_certificate_pem_path, tls_ca_certificate_pem_path + tls_ca_ssl_context, tls_certificate_key_and_chain_path, tls_ca_certificate_pem_path ): config = Config( app=app, loop="asyncio", limit_max_requests=1, - ssl_certfile=tls_certificate_pem_path, + ssl_certfile=tls_certificate_key_and_chain_path, ssl_ca_certs=tls_ca_certificate_pem_path, ) async with run_server(config): @@ -46,18 +49,33 @@ async def test_run_chain( assert response.status_code == 204 +@pytest.mark.asyncio +async def test_run_chain_only(tls_ca_ssl_context, tls_certificate_key_and_chain_path): + config = Config( + app=app, + loop="asyncio", + limit_max_requests=1, + ssl_certfile=tls_certificate_key_and_chain_path, + ) + async with run_server(config): + async with httpx.AsyncClient(verify=tls_ca_ssl_context) as client: + response = await client.get("https://127.0.0.1:8000") + assert response.status_code == 204 + + @pytest.mark.asyncio async def test_run_password( tls_ca_ssl_context, + tls_certificate_server_cert_path, tls_ca_certificate_pem_path, - tls_ca_certificate_private_key_encrypted_path, + tls_certificate_private_key_encrypted_path, ): config = Config( app=app, loop="asyncio", limit_max_requests=1, - ssl_keyfile=tls_ca_certificate_private_key_encrypted_path, - ssl_certfile=tls_ca_certificate_pem_path, + ssl_keyfile=tls_certificate_private_key_encrypted_path, + ssl_certfile=tls_certificate_server_cert_path, ssl_keyfile_password="uvicorn password for the win", ssl_ca_certs=tls_ca_certificate_pem_path, ) diff --git a/uvicorn/config.py b/uvicorn/config.py index 98265abcc..d14fc70c8 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -73,8 +73,7 @@ INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"] -# Fallback to 'ssl.PROTOCOL_SSLv23' in order to support Python < 3.5.3. -SSL_PROTOCOL_VERSION: int = getattr(ssl, "PROTOCOL_TLS", ssl.PROTOCOL_SSLv23) +SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER LOGGING_CONFIG: dict = { diff --git a/uvicorn/loops/asyncio.py b/uvicorn/loops/asyncio.py index ba764b03a..dad2f4034 100644 --- a/uvicorn/loops/asyncio.py +++ b/uvicorn/loops/asyncio.py @@ -1,19 +1,7 @@ import asyncio -import platform -import selectors import sys def asyncio_setup() -> None: # pragma: no cover - loop: asyncio.AbstractEventLoop - if ( - sys.version_info.major >= 3 - and sys.version_info.minor >= 8 - and platform.system() == "Windows" - ): - selector = selectors.SelectSelector() - loop = asyncio.SelectorEventLoop(selector) - asyncio.set_event_loop(loop) - else: - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) + if sys.version_info >= (3, 8) and sys.platform == "win32": + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) diff --git a/uvicorn/protocols/websockets/websockets_impl.py b/uvicorn/protocols/websockets/websockets_impl.py index ddc4aa536..6f5c32f56 100644 --- a/uvicorn/protocols/websockets/websockets_impl.py +++ b/uvicorn/protocols/websockets/websockets_impl.py @@ -1,7 +1,8 @@ import asyncio import http +import inspect import logging -from typing import Callable +from typing import TYPE_CHECKING, Callable from urllib.parse import unquote import websockets @@ -24,7 +25,26 @@ def is_serving(self): return not self.closing -class WebSocketProtocol(websockets.WebSocketServerProtocol): +# special case logger kwarg in websockets >=10 +# https://github.com/aaugustin/websockets/issues/1021#issuecomment-886222136 +if ( + TYPE_CHECKING + or "logger" in inspect.signature(websockets.WebSocketServerProtocol).parameters +): + + class _LoggerMixin: + pass + + +else: + + class _LoggerMixin: + def __init__(self, *args, logger, **kwargs): + super().__init__(*args, **kwargs) + self.logger = logging.LoggerAdapter(logger, {"websocket": self}) + + +class WebSocketProtocol(_LoggerMixin, websockets.WebSocketServerProtocol): def __init__( self, config, server_state, on_connection_lost: Callable = None, _loop=None ): @@ -35,7 +55,6 @@ def __init__( self.app = config.loaded_app self.on_connection_lost = on_connection_lost self.loop = _loop or asyncio.get_event_loop() - self.logger = logging.getLogger("uvicorn.error") self.root_path = config.root_path # Shared server state @@ -59,7 +78,6 @@ def __init__( self.transfer_data_task = None self.ws_server = Server() - super().__init__( ws_handler=self.ws_handler, ws_server=self.ws_server, @@ -67,6 +85,7 @@ def __init__( ping_interval=self.config.ws_ping_interval, ping_timeout=self.config.ws_ping_timeout, extensions=[ServerPerMessageDeflateFactory()], + logger=logging.getLogger("uvicorn.error"), ) def connection_made(self, transport): @@ -76,7 +95,7 @@ def connection_made(self, transport): self.client = get_remote_addr(transport) self.scheme = "wss" if is_ssl(transport) else "ws" - if self.logger.level <= TRACE_LOG_LEVEL: + if self.logger.isEnabledFor(TRACE_LOG_LEVEL): prefix = "%s:%d - " % tuple(self.client) if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix) @@ -85,7 +104,7 @@ def connection_made(self, transport): def connection_lost(self, exc): self.connections.remove(self) - if self.logger.level <= TRACE_LOG_LEVEL: + if self.logger.isEnabledFor(TRACE_LOG_LEVEL): prefix = "%s:%d - " % tuple(self.client) if self.client else "" self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix) diff --git a/uvicorn/server.py b/uvicorn/server.py index b6e3e8791..6aa694cae 100644 --- a/uvicorn/server.py +++ b/uvicorn/server.py @@ -64,8 +64,9 @@ def __init__(self, config: Config) -> None: def run(self, sockets: Optional[List[socket.socket]] = None) -> None: self.config.setup_event_loop() - loop = asyncio.get_event_loop() - loop.run_until_complete(self.serve(sockets=sockets)) + if sys.version_info >= (3, 7): + return asyncio.run(self.serve(sockets=sockets)) + return asyncio.get_event_loop().run_until_complete(self.serve(sockets=sockets)) async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None: process_id = os.getpid()