Skip to content

Commit

Permalink
test under py3.10 (#1070)
Browse files Browse the repository at this point in the history
Co-authored-by: Marcelo Trylesinski <marcelotryle@gmail.com>
  • Loading branch information
graingert and Kludex committed Jul 30, 2021
1 parent de53c23 commit c87c783
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 54 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test-suite.yml
Expand Up @@ -13,7 +13,7 @@ jobs:
runs-on: "${{ matrix.os }}"
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9"]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10.0-beta.4"]
os: [windows-latest, ubuntu-latest]
steps:
- uses: "actions/checkout@v2"
Expand Down
2 changes: 1 addition & 1 deletion docs/deployment.md
Expand Up @@ -92,7 +92,7 @@ Options:
--ssl-certfile TEXT SSL certificate file
--ssl-keyfile-password TEXT SSL keyfile password
--ssl-version INTEGER SSL version to use (see stdlib ssl module's)
[default: 2]
[default: 17]
--ssl-cert-reqs INTEGER Whether client certificate is required (see
stdlib ssl module's) [default: 0]
--ssl-ca-certs TEXT CA certificates file
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Expand Up @@ -165,7 +165,7 @@ Options:
--ssl-certfile TEXT SSL certificate file
--ssl-keyfile-password TEXT SSL keyfile password
--ssl-version INTEGER SSL version to use (see stdlib ssl module's)
[default: 2]
[default: 17]
--ssl-cert-reqs INTEGER Whether client certificate is required (see
stdlib ssl module's) [default: 0]
--ssl-ca-certs TEXT CA certificates file
Expand Down
6 changes: 4 additions & 2 deletions requirements.txt
Expand Up @@ -20,12 +20,14 @@ types-pyyaml
trustme
cryptography
coverage
httpx==0.16.*
httpx>=0.18.2
pytest-asyncio==0.14.*
async_generator; python_version < '3.7'


# Documentation
mkdocs
mkdocs>=1.2.2
mkdocs-material

# py3.10 workarounds
https://github.com/aaugustin/websockets/archive/4e1dac362a3639b9cf0e5bcf382601e6e32cfede.tar.gz#egg=websockets; python_version >= '3.10'
26 changes: 19 additions & 7 deletions tests/conftest.py
Expand Up @@ -13,7 +13,7 @@ def tls_certificate_authority() -> trustme.CA:

@pytest.fixture
def tls_certificate(tls_certificate_authority: trustme.CA) -> trustme.LeafCert:
return tls_certificate_authority.issue_server_cert(
return tls_certificate_authority.issue_cert(
"localhost",
"127.0.0.1",
"::1",
Expand All @@ -33,9 +33,9 @@ def tls_ca_certificate_private_key_path(tls_certificate_authority: trustme.CA):


@pytest.fixture
def tls_ca_certificate_private_key_encrypted_path(tls_certificate_authority):
def tls_certificate_private_key_encrypted_path(tls_certificate):
private_key = serialization.load_pem_private_key(
tls_certificate_authority.private_key_pem.bytes(),
tls_certificate.private_key_pem.bytes(),
password=None,
backend=default_backend(),
)
Expand All @@ -49,13 +49,25 @@ def tls_ca_certificate_private_key_encrypted_path(tls_certificate_authority):


@pytest.fixture
def tls_certificate_pem_path(tls_certificate: trustme.LeafCert):
def tls_certificate_private_key_path(tls_certificate: trustme.CA):
with tls_certificate.private_key_pem.tempfile() as private_key:
yield private_key


@pytest.fixture
def tls_certificate_key_and_chain_path(tls_certificate: trustme.LeafCert):
with tls_certificate.private_key_and_cert_chain_pem.tempfile() as cert_pem:
yield cert_pem


@pytest.fixture
def tls_ca_ssl_context(tls_certificate: trustme.LeafCert) -> ssl.SSLContext:
ssl_ctx = ssl.SSLContext()
tls_certificate.configure_cert(ssl_ctx)
def tls_certificate_server_cert_path(tls_certificate: trustme.LeafCert):
with tls_certificate.cert_chain_pems[0].tempfile() as cert_pem:
yield cert_pem


@pytest.fixture
def tls_ca_ssl_context(tls_certificate_authority: trustme.CA) -> ssl.SSLContext:
ssl_ctx = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
tls_certificate_authority.configure_trust(ssl_ctx)
return ssl_ctx
23 changes: 19 additions & 4 deletions tests/protocols/test_http.py
@@ -1,3 +1,4 @@
import asyncio
import contextlib
import logging

Expand Down Expand Up @@ -123,12 +124,15 @@ def set_protocol(self, protocol):
pass


class MockLoop:
class MockLoop(asyncio.AbstractEventLoop):
def __init__(self, event_loop):
self.tasks = []
self.later = []
self.loop = event_loop

def is_running(self):
return True

def create_task(self, coroutine):
self.tasks.insert(0, coroutine)
return MockTask()
Expand All @@ -138,7 +142,14 @@ def call_later(self, delay, callback, *args):

def run_one(self):
coroutine = self.tasks.pop()
self.loop.run_until_complete(coroutine)
self.run_until_complete(coroutine)

def run_until_complete(self, coroutine):
asyncio._set_running_loop(None)
try:
return self.loop.run_until_complete(coroutine)
finally:
asyncio._set_running_loop(self)

def close(self):
self.loop.close()
Expand All @@ -161,13 +172,17 @@ def add_done_callback(self, callback):
@contextlib.contextmanager
def get_connected_protocol(app, protocol_cls, event_loop, **kwargs):
loop = MockLoop(event_loop)
asyncio._set_running_loop(loop)
transport = MockTransport()
config = Config(app=app, **kwargs)
server_state = ServerState()
protocol = protocol_cls(config=config, server_state=server_state, _loop=loop)
protocol.connection_made(transport)
yield protocol
protocol.loop.close()
try:
yield protocol
finally:
protocol.loop.close()
asyncio._set_running_loop(None)


@pytest.mark.parametrize("protocol_cls", HTTP_PROTOCOLS)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_auto_detection.py
@@ -1,5 +1,7 @@
import asyncio

import pytest

from uvicorn.config import Config
from uvicorn.loops.auto import auto_loop_setup
from uvicorn.main import ServerState
Expand Down Expand Up @@ -37,19 +39,19 @@ def test_loop_auto():
assert isinstance(policy, asyncio.events.BaseDefaultEventLoopPolicy)
expected_loop = "asyncio" if uvloop is None else "uvloop"
assert type(policy).__module__.startswith(expected_loop)
loop = asyncio.get_event_loop()
loop.close()


def test_http_auto():
@pytest.mark.asyncio
async def test_http_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoHTTPProtocol(config=config, server_state=server_state)
expected_http = "H11Protocol" if httptools is None else "HttpToolsProtocol"
assert type(protocol).__name__ == expected_http


def test_websocket_auto():
@pytest.mark.asyncio
async def test_websocket_auto():
config = Config(app=app)
server_state = ServerState()
protocol = AutoWebSocketsProtocol(config=config, server_state=server_state)
Expand Down
4 changes: 2 additions & 2 deletions tests/test_config.py
Expand Up @@ -175,10 +175,10 @@ def test_ssl_config(
assert config.is_ssl is True


def test_ssl_config_combined(tls_certificate_pem_path: str) -> None:
def test_ssl_config_combined(tls_certificate_key_and_chain_path: str) -> None:
config = Config(
app=asgi_app,
ssl_certfile=tls_certificate_pem_path,
ssl_certfile=tls_certificate_key_and_chain_path,
)
config.load()

Expand Down
34 changes: 26 additions & 8 deletions tests/test_ssl.py
Expand Up @@ -13,14 +13,17 @@ async def app(scope, receive, send):

@pytest.mark.asyncio
async def test_run(
tls_ca_ssl_context, tls_ca_certificate_pem_path, tls_ca_certificate_private_key_path
tls_ca_ssl_context,
tls_certificate_server_cert_path,
tls_certificate_private_key_path,
tls_ca_certificate_pem_path,
):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_keyfile=tls_ca_certificate_private_key_path,
ssl_certfile=tls_ca_certificate_pem_path,
ssl_keyfile=tls_certificate_private_key_path,
ssl_certfile=tls_certificate_server_cert_path,
ssl_ca_certs=tls_ca_certificate_pem_path,
)
async with run_server(config):
Expand All @@ -31,13 +34,13 @@ async def test_run(

@pytest.mark.asyncio
async def test_run_chain(
tls_ca_ssl_context, tls_certificate_pem_path, tls_ca_certificate_pem_path
tls_ca_ssl_context, tls_certificate_key_and_chain_path, tls_ca_certificate_pem_path
):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_certfile=tls_certificate_pem_path,
ssl_certfile=tls_certificate_key_and_chain_path,
ssl_ca_certs=tls_ca_certificate_pem_path,
)
async with run_server(config):
Expand All @@ -46,18 +49,33 @@ async def test_run_chain(
assert response.status_code == 204


@pytest.mark.asyncio
async def test_run_chain_only(tls_ca_ssl_context, tls_certificate_key_and_chain_path):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_certfile=tls_certificate_key_and_chain_path,
)
async with run_server(config):
async with httpx.AsyncClient(verify=tls_ca_ssl_context) as client:
response = await client.get("https://127.0.0.1:8000")
assert response.status_code == 204


@pytest.mark.asyncio
async def test_run_password(
tls_ca_ssl_context,
tls_certificate_server_cert_path,
tls_ca_certificate_pem_path,
tls_ca_certificate_private_key_encrypted_path,
tls_certificate_private_key_encrypted_path,
):
config = Config(
app=app,
loop="asyncio",
limit_max_requests=1,
ssl_keyfile=tls_ca_certificate_private_key_encrypted_path,
ssl_certfile=tls_ca_certificate_pem_path,
ssl_keyfile=tls_certificate_private_key_encrypted_path,
ssl_certfile=tls_certificate_server_cert_path,
ssl_keyfile_password="uvicorn password for the win",
ssl_ca_certs=tls_ca_certificate_pem_path,
)
Expand Down
3 changes: 1 addition & 2 deletions uvicorn/config.py
Expand Up @@ -73,8 +73,7 @@
INTERFACES: List[InterfaceType] = ["auto", "asgi3", "asgi2", "wsgi"]


# Fallback to 'ssl.PROTOCOL_SSLv23' in order to support Python < 3.5.3.
SSL_PROTOCOL_VERSION: int = getattr(ssl, "PROTOCOL_TLS", ssl.PROTOCOL_SSLv23)
SSL_PROTOCOL_VERSION: int = ssl.PROTOCOL_TLS_SERVER


LOGGING_CONFIG: dict = {
Expand Down
16 changes: 2 additions & 14 deletions uvicorn/loops/asyncio.py
@@ -1,19 +1,7 @@
import asyncio
import platform
import selectors
import sys


def asyncio_setup() -> None: # pragma: no cover
loop: asyncio.AbstractEventLoop
if (
sys.version_info.major >= 3
and sys.version_info.minor >= 8
and platform.system() == "Windows"
):
selector = selectors.SelectSelector()
loop = asyncio.SelectorEventLoop(selector)
asyncio.set_event_loop(loop)
else:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
if sys.version_info >= (3, 8) and sys.platform == "win32":
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
31 changes: 25 additions & 6 deletions uvicorn/protocols/websockets/websockets_impl.py
@@ -1,7 +1,8 @@
import asyncio
import http
import inspect
import logging
from typing import Callable
from typing import TYPE_CHECKING, Callable
from urllib.parse import unquote

import websockets
Expand All @@ -24,7 +25,26 @@ def is_serving(self):
return not self.closing


class WebSocketProtocol(websockets.WebSocketServerProtocol):
# special case logger kwarg in websockets >=10
# https://github.com/aaugustin/websockets/issues/1021#issuecomment-886222136
if (
TYPE_CHECKING
or "logger" in inspect.signature(websockets.WebSocketServerProtocol).parameters
):

class _LoggerMixin:
pass


else:

class _LoggerMixin:
def __init__(self, *args, logger, **kwargs):
super().__init__(*args, **kwargs)
self.logger = logging.LoggerAdapter(logger, {"websocket": self})


class WebSocketProtocol(_LoggerMixin, websockets.WebSocketServerProtocol):
def __init__(
self, config, server_state, on_connection_lost: Callable = None, _loop=None
):
Expand All @@ -35,7 +55,6 @@ def __init__(
self.app = config.loaded_app
self.on_connection_lost = on_connection_lost
self.loop = _loop or asyncio.get_event_loop()
self.logger = logging.getLogger("uvicorn.error")
self.root_path = config.root_path

# Shared server state
Expand All @@ -59,14 +78,14 @@ def __init__(
self.transfer_data_task = None

self.ws_server = Server()

super().__init__(
ws_handler=self.ws_handler,
ws_server=self.ws_server,
max_size=self.config.ws_max_size,
ping_interval=self.config.ws_ping_interval,
ping_timeout=self.config.ws_ping_timeout,
extensions=[ServerPerMessageDeflateFactory()],
logger=logging.getLogger("uvicorn.error"),
)

def connection_made(self, transport):
Expand All @@ -76,7 +95,7 @@ def connection_made(self, transport):
self.client = get_remote_addr(transport)
self.scheme = "wss" if is_ssl(transport) else "ws"

if self.logger.level <= TRACE_LOG_LEVEL:
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection made", prefix)

Expand All @@ -85,7 +104,7 @@ def connection_made(self, transport):
def connection_lost(self, exc):
self.connections.remove(self)

if self.logger.level <= TRACE_LOG_LEVEL:
if self.logger.isEnabledFor(TRACE_LOG_LEVEL):
prefix = "%s:%d - " % tuple(self.client) if self.client else ""
self.logger.log(TRACE_LOG_LEVEL, "%sWebSocket connection lost", prefix)

Expand Down
5 changes: 3 additions & 2 deletions uvicorn/server.py
Expand Up @@ -64,8 +64,9 @@ def __init__(self, config: Config) -> None:

def run(self, sockets: Optional[List[socket.socket]] = None) -> None:
self.config.setup_event_loop()
loop = asyncio.get_event_loop()
loop.run_until_complete(self.serve(sockets=sockets))
if sys.version_info >= (3, 7):
return asyncio.run(self.serve(sockets=sockets))
return asyncio.get_event_loop().run_until_complete(self.serve(sockets=sockets))

async def serve(self, sockets: Optional[List[socket.socket]] = None) -> None:
process_id = os.getpid()
Expand Down

0 comments on commit c87c783

Please sign in to comment.