diff --git a/example/wsgi_example/settings.py b/example/wsgi_example/settings.py
index 44a3afb16..85ae1511f 100644
--- a/example/wsgi_example/settings.py
+++ b/example/wsgi_example/settings.py
@@ -23,7 +23,7 @@
 SECRET_KEY = "0#*&9l*zrdz9u8%oxof_=yf5m8=el2#cv!%)y+yynt$pmvs46o"
 
 # SECURITY WARNING: don't run with debug turned on in production!
-DEBUG = False
+DEBUG = True
 
 ALLOWED_HOSTS = ["*"]
 
diff --git a/indexpy/applications.py b/indexpy/applications.py
index 191453e97..8e48f2361 100644
--- a/indexpy/applications.py
+++ b/indexpy/applications.py
@@ -14,7 +14,6 @@
 from starlette.responses import RedirectResponse
 from starlette.middleware import Middleware
 from starlette.middleware.errors import ServerErrorMiddleware
-from starlette.middleware.wsgi import WSGIMiddleware
 from starlette.middleware.cors import CORSMiddleware
 from starlette.middleware.gzip import GZipMiddleware
 from starlette.middleware.trustedhost import TrustedHostMiddleware
@@ -30,6 +29,7 @@
     after_response_tasks_var,
     finished_response_tasks_var,
 )
+from .wsgi import WSGIMiddleware
 
 
 class Lifespan:
diff --git a/indexpy/wsgi.py b/indexpy/wsgi.py
new file mode 100644
index 000000000..d65711da7
--- /dev/null
+++ b/indexpy/wsgi.py
@@ -0,0 +1,276 @@
+import asyncio
+import sys
+import threading
+import typing
+from tempfile import SpooledTemporaryFile
+
+from starlette.concurrency import run_in_threadpool
+from starlette.types import Message, Receive, Scope, Send
+
+
+class Body:
+    """
+    File-like ``wsgi.input`` stream shared by the event loop and a WSGI thread.
+
+    The event loop appends request chunks with ``write``; the WSGI application
+    consumes them through the blocking ``read``/``readline``/``readlines``
+    calls, which wake the receiver through ``recv_event`` on demand.
+    """
+
+    def __init__(self, recv_event: asyncio.Event) -> None:
+        self.file = SpooledTemporaryFile(1024 * 1024)
+        self.recv_event = recv_event
+        self.has_more = True
+        # Captured here so reader threads can schedule ``recv_event.set`` on
+        # the loop; asyncio.Event is not thread-safe.
+        self.loop = asyncio.get_event_loop()
+        # Guards ``file``, both offsets and ``has_more`` across threads.
+        self.cond = threading.Condition()
+        self.read_pos = 0
+        self.write_pos = 0
+
+    def feed_eof(self) -> None:
+        """Mark the body as complete and release any blocked reader."""
+        with self.cond:
+            self.has_more = False
+            self.cond.notify_all()
+
+    def _append(self, data: bytes) -> None:
+        """Store ``data`` after the bytes already written and wake readers."""
+        with self.cond:
+            if self.file.closed:
+                return
+            self.file.seek(self.write_pos)
+            self.file.write(data)
+            self.write_pos += len(data)
+            self.cond.notify_all()
+
+    async def write(self, data: bytes) -> None:
+        """Append ``data`` without blocking the event loop on file I/O."""
+        await run_in_threadpool(self._append, data)
+
+    def close(self) -> None:
+        """Close the backing file; writes arriving afterwards are dropped."""
+        with self.cond:
+            self.file.close()
+
+    def _wait_for_data(self) -> None:
+        """Wake the receiver and block until data or EOF arrives (lock held)."""
+        self.loop.call_soon_threadsafe(self.recv_event.set)
+        self.cond.wait()
+
+    def read(self, size: int = -1) -> bytes:
+        """Return ``size`` bytes, or everything up to EOF when negative."""
+        if size is None:
+            size = -1
+        with self.cond:
+            while self.has_more and (size < 0 or self.write_pos - self.read_pos < size):
+                self._wait_for_data()
+            self.file.seek(self.read_pos)
+            data = self.file.read(size)
+            self.read_pos += len(data)
+            return data
+
+    def readline(self, limit: int = -1) -> bytes:
+        """Return one line, at most ``limit`` bytes when it is non-negative."""
+        if limit is None:
+            limit = -1
+        with self.cond:
+            while True:
+                self.file.seek(self.read_pos)
+                line = self.file.readline(limit)
+                complete = line.endswith(b"\n") or 0 <= limit <= len(line)
+                if complete or not self.has_more:
+                    self.read_pos += len(line)
+                    return line
+                # Partial line: wait for more data, then scan again.
+                self._wait_for_data()
+
+    def readlines(self, hint: int = -1) -> typing.List[bytes]:
+        """Return the remaining lines, stopping once ``hint`` bytes are read."""
+        lines = []  # type: typing.List[bytes]
+        total = 0
+        for line in self:
+            lines.append(line)
+            total += len(line)
+            if hint is not None and 0 < hint <= total:
+                break
+        return lines
+
+    def __iter__(self) -> typing.Generator:
+        """Yield lines until the body is exhausted."""
+        while True:
+            line = self.readline()
+            if not line:
+                return
+            yield line
+
+
+def build_environ(scope: Scope, body: Body) -> dict:
+    """
+    Builds a scope and request body into a WSGI environ object.
+    """
+    environ = {
+        "REQUEST_METHOD": scope["method"],
+        "SCRIPT_NAME": scope.get("root_path", ""),
+        "PATH_INFO": scope["path"],
+        "QUERY_STRING": scope["query_string"].decode("ascii"),
+        "SERVER_PROTOCOL": f"HTTP/{scope['http_version']}",
+        "wsgi.version": (1, 0),
+        "wsgi.url_scheme": scope.get("scheme", "http"),
+        "wsgi.input": body,
+        "wsgi.errors": sys.stdout,
+        "wsgi.multithread": True,
+        "wsgi.multiprocess": True,
+        "wsgi.run_once": False,
+    }
+
+    # Get server name and port - required in WSGI, not in ASGI
+    server = scope.get("server") or ("localhost", 80)
+    environ["SERVER_NAME"] = server[0]
+    environ["SERVER_PORT"] = server[1]
+
+    # Get client IP address
+    if scope.get("client"):
+        environ["REMOTE_ADDR"] = scope["client"][0]
+
+    # Go through headers and make them into environ entries
+    for name, value in scope.get("headers", []):
+        name = name.decode("latin1")
+        if name == "content-length":
+            corrected_name = "CONTENT_LENGTH"
+        elif name == "content-type":
+            corrected_name = "CONTENT_TYPE"
+        else:
+            corrected_name = f"HTTP_{name}".upper().replace("-", "_")
+        # HTTPbis say only ASCII chars are allowed in headers, but we latin1 just in case
+        value = value.decode("latin1")
+        if corrected_name in environ:
+            value = environ[corrected_name] + "," + value
+        environ[corrected_name] = value
+    return environ
+
+
+class WSGIMiddleware:
+    """
+    ASGI application that serves requests by calling a WSGI application.
+
+    ``workers`` is accepted for API compatibility but is currently unused.
+    """
+
+    def __init__(self, app: typing.Callable, workers: int = 10) -> None:
+        self.app = app
+
+    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
+        assert scope["type"] == "http"
+        responder = WSGIResponder(self.app, scope)
+        await responder(receive, send)
+
+
+class WSGIResponder:
+    """Runs one WSGI call in a worker thread and relays its I/O to ASGI."""
+
+    def __init__(self, app: typing.Callable, scope: Scope) -> None:
+        self.app = app
+        self.scope = scope
+        self.recv_event = asyncio.Event()
+        self.send_event = asyncio.Event()
+        self.send_queue = []  # type: typing.List[typing.Optional[Message]]
+        self.loop = asyncio.get_event_loop()
+        self.response_started = False
+        self.exc_info = None  # type: typing.Any
+
+    async def __call__(self, receive: Receive, send: Send) -> None:
+        body = Body(self.recv_event)
+        environ = build_environ(self.scope, body)
+        sender = None
+        receiver = None
+        try:
+            sender = self.loop.create_task(self.sender(send))
+            receiver = self.loop.create_task(self.receiver(receive, body))
+            await run_in_threadpool(self.wsgi, environ, self.start_response)
+            self.send_queue.append(None)
+            self.send_event.set()
+            await asyncio.wait_for(sender, None)
+            if self.exc_info is not None:
+                raise self.exc_info[0].with_traceback(
+                    self.exc_info[1], self.exc_info[2]
+                )
+        finally:
+            if sender and not sender.done():
+                sender.cancel()  # pragma: no cover
+            if receiver and not receiver.done():
+                receiver.cancel()  # pragma: no cover
+            body.close()
+
+    async def receiver(self, receive: Receive, body: Body) -> None:
+        """Copy the request body into ``body`` once the app first reads it."""
+        try:
+            await self.recv_event.wait()
+            more_body = True
+            while more_body:
+                message = await receive()
+                await body.write(message.get("body", b""))
+                more_body = message.get("more_body", False)
+        finally:
+            # Always signal EOF so a blocked reader thread cannot hang, even
+            # if receiving fails or this task is cancelled.
+            body.feed_eof()
+
+    # Backward-compatible alias for the original, misspelled method name.
+    recevier = receiver
+
+    async def sender(self, send: Send) -> None:
+        """Forward queued ASGI messages until the ``None`` sentinel arrives."""
+        while True:
+            if self.send_queue:
+                message = self.send_queue.pop(0)
+                if message is None:
+                    return
+                await send(message)
+            else:
+                await self.send_event.wait()
+                self.send_event.clear()
+
+    def start_response(
+        self,
+        status: str,
+        response_headers: typing.List[typing.Tuple[str, str]],
+        exc_info: typing.Any = None,
+    ) -> None:
+        """Queue the ``http.response.start`` message for the first call."""
+        self.exc_info = exc_info
+        if not self.response_started:
+            self.response_started = True
+            status_code_string, _ = status.split(" ", 1)
+            status_code = int(status_code_string)
+            headers = [
+                (name.strip().encode("ascii").lower(), value.strip().encode("ascii"))
+                for name, value in response_headers
+            ]
+            self.send_queue.append(
+                {
+                    "type": "http.response.start",
+                    "status": status_code,
+                    "headers": headers,
+                }
+            )
+            self.loop.call_soon_threadsafe(self.send_event.set)
+
+    def wsgi(self, environ: dict, start_response: typing.Callable) -> None:
+        """Run the WSGI app (in a worker thread) and queue its body chunks."""
+        iterable = self.app(environ, start_response)
+        try:
+            for chunk in iterable:
+                self.send_queue.append(
+                    {"type": "http.response.body", "body": chunk, "more_body": True}
+                )
+                self.loop.call_soon_threadsafe(self.send_event.set)
+        finally:
+            # PEP 3333: the server must call close() on the iterable if present.
+            close = getattr(iterable, "close", None)
+            if close is not None:
+                close()
+
+        self.send_queue.append({"type": "http.response.body", "body": b""})
+        self.loop.call_soon_threadsafe(self.send_event.set)
diff --git a/tests/test_wsgi.py b/tests/test_wsgi.py
new file mode 100644
index 000000000..202726097
--- /dev/null
+++ b/tests/test_wsgi.py
@@ -0,0 +1,143 @@
+import asyncio
+import sys
+
+import pytest
+from starlette.testclient import TestClient
+
+from indexpy.wsgi import Body, WSGIMiddleware, build_environ
+
+
+def hello_world(environ, start_response):
+    status = "200 OK"
+    output = b"Hello World!\n"
+    headers = [
+        ("Content-Type", "text/plain; charset=utf-8"),
+        ("Content-Length", str(len(output))),
+    ]
+    start_response(status, headers)
+    return [output]
+
+
+def echo_body(environ, start_response):
+    status = "200 OK"
+    output = environ["wsgi.input"].read()
+    headers = [
+        ("Content-Type", "text/plain; charset=utf-8"),
+        ("Content-Length", str(len(output))),
+    ]
+    start_response(status, headers)
+    return [output]
+
+
+def raise_exception(environ, start_response):
+    raise RuntimeError("Something went wrong")
+
+
+def return_exc_info(environ, start_response):
+    try:
+        raise RuntimeError("Something went wrong")
+    except RuntimeError:
+        status = "500 Internal Server Error"
+        output = b"Internal Server Error"
+        headers = [
+            ("Content-Type", "text/plain; charset=utf-8"),
+            ("Content-Length", str(len(output))),
+        ]
+        start_response(status, headers, exc_info=sys.exc_info())
+        return [output]
+
+
+def test_wsgi_get():
+    app = WSGIMiddleware(hello_world)
+    client = TestClient(app)
+    response = client.get("/")
+    assert response.status_code == 200
+    assert response.text == "Hello World!\n"
+
+
+def test_wsgi_post():
+    app = WSGIMiddleware(echo_body)
+    client = TestClient(app)
+    response = client.post("/", json={"example": 123})
+    assert response.status_code == 200
+    assert response.text == '{"example": 123}'
+
+
+def test_wsgi_exception():
+    # Note that we're testing the WSGI app directly here.
+    # The HTTP protocol implementations would catch this error and return 500.
+    app = WSGIMiddleware(raise_exception)
+    client = TestClient(app)
+    with pytest.raises(RuntimeError):
+        client.get("/")
+
+
+def test_wsgi_exc_info():
+    # Note that we're testing the WSGI app directly here.
+    # The HTTP protocol implementations would catch this error and return 500.
+    app = WSGIMiddleware(return_exc_info)
+    client = TestClient(app)
+    with pytest.raises(RuntimeError):
+        response = client.get("/")
+
+    app = WSGIMiddleware(return_exc_info)
+    client = TestClient(app, raise_server_exceptions=False)
+    response = client.get("/")
+    assert response.status_code == 500
+    assert response.text == "Internal Server Error"
+
+
+def test_build_environ():
+    scope = {
+        "type": "http",
+        "http_version": "1.1",
+        "method": "GET",
+        "scheme": "https",
+        "path": "/",
+        "query_string": b"a=123&b=456",
+        "headers": [
+            (b"host", b"www.example.org"),
+            (b"content-type", b"application/json"),
+            (b"content-length", b"18"),
+            (b"accept", b"application/json"),
+            (b"accept", b"text/plain"),
+        ],
+        "client": ("134.56.78.4", 1453),
+        "server": ("www.example.org", 443),
+    }
+
+    async def make_body() -> Body:
+        # Body must be created and fed from inside a running event loop.
+        body = Body(asyncio.Event())
+        await body.write(b'{"example":"body"}')
+        body.feed_eof()
+        return body
+
+    loop = asyncio.new_event_loop()
+    try:
+        environ = build_environ(scope, loop.run_until_complete(make_body()))
+    finally:
+        loop.close()
+    stream = environ.pop("wsgi.input")
+    assert stream.read() == b'{"example":"body"}'
+    stream.close()
+    assert environ == {
+        "CONTENT_LENGTH": "18",
+        "CONTENT_TYPE": "application/json",
+        "HTTP_ACCEPT": "application/json,text/plain",
+        "HTTP_HOST": "www.example.org",
+        "PATH_INFO": "/",
+        "QUERY_STRING": "a=123&b=456",
+        "REMOTE_ADDR": "134.56.78.4",
+        "REQUEST_METHOD": "GET",
+        "SCRIPT_NAME": "",
+        "SERVER_NAME": "www.example.org",
+        "SERVER_PORT": 443,
+        "SERVER_PROTOCOL": "HTTP/1.1",
+        "wsgi.errors": sys.stdout,
+        "wsgi.multiprocess": True,
+        "wsgi.multithread": True,
+        "wsgi.run_once": False,
+        "wsgi.url_scheme": "https",
+        "wsgi.version": (1, 0),
+    }