From 7c2ddb748e592befb9e2dc8f776608b01df7cbaf Mon Sep 17 00:00:00 2001 From: abersheerna Date: Sat, 1 Jan 2022 06:50:01 +0000 Subject: [PATCH] Use a2wsgi.WSGIMiddleware replace WSGIMiddleware --- setup.cfg | 2 - setup.py | 1 + tests/middleware/test_wsgi.py | 125 ------------------------ tests/test_config.py | 2 +- uvicorn/config.py | 11 ++- uvicorn/middleware/wsgi.py | 179 ---------------------------------- 6 files changed, 12 insertions(+), 308 deletions(-) delete mode 100644 tests/middleware/test_wsgi.py delete mode 100644 uvicorn/middleware/wsgi.py diff --git a/setup.cfg b/setup.cfg index 3a829c12a..9f39fb7df 100644 --- a/setup.cfg +++ b/setup.cfg @@ -25,8 +25,6 @@ files = uvicorn/protocols/websockets/auto.py, uvicorn/supervisors/__init__.py, uvicorn/middleware/debug.py, - uvicorn/middleware/wsgi.py, - tests/middleware/test_wsgi.py, uvicorn/supervisors/watchgodreload.py, uvicorn/logging.py, uvicorn/middleware/asgi2.py, diff --git a/setup.py b/setup.py index bd0ecc215..1b42718a4 100755 --- a/setup.py +++ b/setup.py @@ -62,6 +62,7 @@ def get_packages(package): "watchgod>=0.6", "python-dotenv>=0.13", "PyYAML>=5.1", + "a2wsgi>=1.4.0,<2.0.0" ] diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py deleted file mode 100644 index a9673c3f6..000000000 --- a/tests/middleware/test_wsgi.py +++ /dev/null @@ -1,125 +0,0 @@ -import sys -from typing import List - -import httpx -import pytest -from asgiref.typing import HTTPRequestEvent, HTTPScope - -from uvicorn._types import Environ, StartResponse -from uvicorn.middleware.wsgi import WSGIMiddleware, build_environ - - -def hello_world(environ: Environ, start_response: StartResponse) -> List[bytes]: - status = "200 OK" - output = b"Hello World!\n" - headers = [ - ("Content-Type", "text/plain; charset=utf-8"), - ("Content-Length", str(len(output))), - ] - start_response(status, headers, None) - return [output] - - -def echo_body(environ: Environ, start_response: StartResponse) -> List[bytes]: - status = "200 OK" - output = environ["wsgi.input"].read() - headers = [ - ("Content-Type", "text/plain; charset=utf-8"), - ("Content-Length", str(len(output))), - ] - start_response(status, headers, None) - return [output] - - -def raise_exception(environ: Environ, start_response: StartResponse) -> RuntimeError: - raise RuntimeError("Something went wrong") - - -def return_exc_info(environ: Environ, start_response: StartResponse) -> List[bytes]: - try: - raise RuntimeError("Something went wrong") - except RuntimeError: - status = "500 Internal Server Error" - output = b"Internal Server Error" - headers = [ - ("Content-Type", "text/plain; charset=utf-8"), - ("Content-Length", str(len(output))), - ] - start_response(status, headers, sys.exc_info()) # type: ignore[arg-type] - return [output] - - -@pytest.mark.asyncio -async def test_wsgi_get() -> None: - app = WSGIMiddleware(hello_world) - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: - response = await client.get("/") - assert response.status_code == 200 - assert response.text == "Hello World!\n" - - -@pytest.mark.asyncio -async def test_wsgi_post() -> None: - app = WSGIMiddleware(echo_body) - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: - response = await client.post("/", json={"example": 123}) - assert response.status_code == 200 - assert response.text == '{"example": 123}' - - -@pytest.mark.asyncio -async def test_wsgi_exception() -> None: - # Note that we're testing the WSGI app directly here. - # The HTTP protocol implementations would catch this error and return 500. - app = WSGIMiddleware(raise_exception) - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: - with pytest.raises(RuntimeError): - await client.get("/") - - -@pytest.mark.asyncio -async def test_wsgi_exc_info() -> None: - # Note that we're testing the WSGI app directly here. - # The HTTP protocol implementations would catch this error and return 500. - app = WSGIMiddleware(return_exc_info) - async with httpx.AsyncClient(app=app, base_url="http://testserver") as client: - with pytest.raises(RuntimeError): - response = await client.get("/") - - app = WSGIMiddleware(return_exc_info) - transport = httpx.ASGITransport( - app=app, - raise_app_exceptions=False, - ) - async with httpx.AsyncClient( - transport=transport, base_url="http://testserver" - ) as client: - response = await client.get("/") - assert response.status_code == 500 - assert response.text == "Internal Server Error" - - -def test_build_environ_encoding() -> None: - scope: HTTPScope = { - "asgi": {"version": "3.0", "spec_version": "2.0"}, - "scheme": "http", - "raw_path": b"/\xe6\x96\x87", - "type": "http", - "http_version": "1.1", - "method": "GET", - "path": "/文", - "root_path": "/文", - "client": None, - "server": None, - "query_string": b"a=123&b=456", - "headers": [(b"key", b"value1"), (b"key", b"value2")], - "extensions": {}, - } - message: HTTPRequestEvent = { - "type": "http.request", - "body": b"", - "more_body": False, - } - environ = build_environ(scope, message, b"") - assert environ["PATH_INFO"] == "/文".encode("utf8").decode("latin-1") - assert environ["HTTP_KEY"] == "value1,value2" diff --git a/tests/test_config.py b/tests/test_config.py index 265d09b36..c436d3cdb 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -15,6 +15,7 @@ import pytest import yaml +from a2wsgi import WSGIMiddleware from asgiref.typing import ASGIApplication, ASGIReceiveCallable, ASGISendCallable, Scope from pytest_mock import MockerFixture @@ -23,7 +24,6 @@ from uvicorn.config import LOGGING_CONFIG, Config from uvicorn.middleware.debug import DebugMiddleware from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware -from uvicorn.middleware.wsgi import WSGIMiddleware from uvicorn.protocols.http.h11_impl import H11Protocol diff --git a/uvicorn/config.py b/uvicorn/config.py index 681ca0caf..89aa16aae 100644 --- a/uvicorn/config.py +++ b/uvicorn/config.py @@ -28,12 +28,21 @@ # enable this functionality. pass +try: + from a2wsgi import WSGIMiddleware +except ImportError: + from uvicorn._types import WSGIApp + + class WSGIMiddleware: # type: ignore + def __init__(self, app: WSGIApp, workers: int = 10): + raise RuntimeError("Please install `a2wsgi` for serving WSGI applications") + + from uvicorn.importer import ImportFromStringError, import_from_string from uvicorn.middleware.asgi2 import ASGI2Middleware from uvicorn.middleware.debug import DebugMiddleware from uvicorn.middleware.message_logger import MessageLoggerMiddleware from uvicorn.middleware.proxy_headers import ProxyHeadersMiddleware -from uvicorn.middleware.wsgi import WSGIMiddleware HTTPProtocolType = Literal["auto", "h11", "httptools"] WSProtocolType = Literal["auto", "none", "websockets", "wsproto"] diff --git a/uvicorn/middleware/wsgi.py b/uvicorn/middleware/wsgi.py deleted file mode 100644 index 74bdfada4..000000000 --- a/uvicorn/middleware/wsgi.py +++ /dev/null @@ -1,179 +0,0 @@ -import asyncio -import concurrent.futures -import io -import sys -from collections import deque -from typing import Deque, Iterable, Optional, Tuple - -from asgiref.typing import ( - ASGIReceiveCallable, - ASGIReceiveEvent, - ASGISendCallable, - ASGISendEvent, - HTTPRequestEvent, - HTTPResponseBodyEvent, - HTTPResponseStartEvent, - HTTPScope, -) - -from uvicorn._types import Environ, ExcInfo, StartResponse, WSGIApp - - -def build_environ(scope: HTTPScope, message: ASGIReceiveEvent, body: bytes) -> Environ: - """ - Builds a scope and request message into a WSGI environ object. - """ - environ = { - "REQUEST_METHOD": scope["method"], - "SCRIPT_NAME": "", - "PATH_INFO": scope["path"].encode("utf8").decode("latin1"), - "QUERY_STRING": scope["query_string"].decode("ascii"), - "SERVER_PROTOCOL": "HTTP/%s" % scope["http_version"], - "wsgi.version": (1, 0), - "wsgi.url_scheme": scope.get("scheme", "http"), - "wsgi.input": io.BytesIO(body), - "wsgi.errors": sys.stdout, - "wsgi.multithread": True, - "wsgi.multiprocess": True, - "wsgi.run_once": False, - } - - # Get server name and port - required in WSGI, not in ASGI - server = scope.get("server") - if server is None: - server = ("localhost", 80) - environ["SERVER_NAME"] = server[0] - environ["SERVER_PORT"] = server[1] - - # Get client IP address - client = scope.get("client") - if client is not None: - environ["REMOTE_ADDR"] = client[0] - - # Go through headers and make them into environ entries - for name, value in scope.get("headers", []): - name_str: str = name.decode("latin1") - if name_str == "content-length": - corrected_name = "CONTENT_LENGTH" - elif name_str == "content-type": - corrected_name = "CONTENT_TYPE" - else: - corrected_name = "HTTP_%s" % name_str.upper().replace("-", "_") - # HTTPbis say only ASCII chars are allowed in headers, but we latin1 - # just in case - value_str: str = value.decode("latin1") - if corrected_name in environ: - corrected_name_environ = environ[corrected_name] - assert isinstance(corrected_name_environ, str) - value_str = corrected_name_environ + "," + value_str - environ[corrected_name] = value_str - return environ - - -class WSGIMiddleware: - def __init__(self, app: WSGIApp, workers: int = 10): - self.app = app - self.executor = concurrent.futures.ThreadPoolExecutor(max_workers=workers) - - async def __call__( - self, scope: HTTPScope, receive: ASGIReceiveCallable, send: ASGISendCallable - ) -> None: - assert scope["type"] == "http" - instance = WSGIResponder(self.app, self.executor, scope) - await instance(receive, send) - - -class WSGIResponder: - def __init__( - self, - app: WSGIApp, - executor: concurrent.futures.ThreadPoolExecutor, - scope: HTTPScope, - ): - self.app = app - self.executor = executor - self.scope = scope - self.status = None - self.response_headers = None - self.send_event = asyncio.Event() - self.send_queue: Deque[Optional[ASGISendEvent]] = deque() - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() - self.response_started = False - self.exc_info: Optional[ExcInfo] = None - - async def __call__( - self, receive: ASGIReceiveCallable, send: ASGISendCallable - ) -> None: - message: HTTPRequestEvent = await receive() # type: ignore[assignment] - body = message.get("body", b"") - more_body = message.get("more_body", False) - while more_body: - body_message: HTTPRequestEvent = await receive() # type: ignore[assignment] - body += body_message.get("body", b"") - more_body = body_message.get("more_body", False) - environ = build_environ(self.scope, message, body) - self.loop = asyncio.get_event_loop() - wsgi = self.loop.run_in_executor( - self.executor, self.wsgi, environ, self.start_response - ) - sender = self.loop.create_task(self.sender(send)) - try: - await asyncio.wait_for(wsgi, None) - finally: - self.send_queue.append(None) - self.send_event.set() - await asyncio.wait_for(sender, None) - if self.exc_info is not None: - raise self.exc_info[0].with_traceback(self.exc_info[1], self.exc_info[2]) - - async def sender(self, send: ASGISendCallable) -> None: - while True: - if self.send_queue: - message = self.send_queue.popleft() - if message is None: - return - await send(message) - else: - await self.send_event.wait() - self.send_event.clear() - - def start_response( - self, - status: str, - response_headers: Iterable[Tuple[str, str]], - exc_info: Optional[ExcInfo] = None, - ) -> None: - self.exc_info = exc_info - if not self.response_started: - self.response_started = True - status_code_str, _ = status.split(" ", 1) - status_code = int(status_code_str) - headers = [ - (name.encode("ascii"), value.encode("ascii")) - for name, value in response_headers - ] - http_response_start_event: HTTPResponseStartEvent = { - "type": "http.response.start", - "status": status_code, - "headers": headers, - } - self.send_queue.append(http_response_start_event) - self.loop.call_soon_threadsafe(self.send_event.set) - - def wsgi(self, environ: Environ, start_response: StartResponse) -> None: - for chunk in self.app(environ, start_response): # type: ignore - response_body: HTTPResponseBodyEvent = { - "type": "http.response.body", - "body": chunk, - "more_body": True, - } - self.send_queue.append(response_body) - self.loop.call_soon_threadsafe(self.send_event.set) - - empty_body: HTTPResponseBodyEvent = { - "type": "http.response.body", - "body": b"", - "more_body": False, - } - self.send_queue.append(empty_body) - self.loop.call_soon_threadsafe(self.send_event.set)