From bcddbc076eaadc5b870b6579b26324e3ae711bd3 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Wed, 8 Dec 2021 15:50:45 +0100 Subject: [PATCH] Add Mypy checks to tests Add support for functools.partial in WebsocketRoute (#1356) * Add support for functools.partial in WebsocketRoute * remove commented code * Refactor tests for partian endpoint and ws Fix snippet on config.md (#1358) * Fix snippet on config.md * Update docs/config.md Co-authored-by: Amin Alaee Co-authored-by: Amin Alaee remove outdated test --- docs/config.md | 5 +- setup.cfg | 3 +- starlette/routing.py | 5 +- tests/middleware/test_cors.py | 2 +- tests/middleware/test_session.py | 8 ++- tests/test_database.py | 2 + tests/test_datastructures.py | 8 +-- tests/test_formparsers.py | 9 ++-- tests/test_routing.py | 89 ++++++++++++++++++++------------ tests/test_templates.py | 2 +- 10 files changed, 82 insertions(+), 51 deletions(-) diff --git a/docs/config.md b/docs/config.md index f7c2c7b7d..52a801568 100644 --- a/docs/config.md +++ b/docs/config.md @@ -128,8 +128,9 @@ application logic separated: **myproject/settings.py**: ```python +import databases from starlette.config import Config -from starlette.datastructures import URL, Secret +from starlette.datastructures import Secret config = Config(".env") @@ -137,7 +138,7 @@ DEBUG = config('DEBUG', cast=bool, default=False) TESTING = config('TESTING', cast=bool, default=False) SECRET_KEY = config('SECRET_KEY', cast=Secret) -DATABASE_URL = config('DATABASE_URL', cast=URL) +DATABASE_URL = config('DATABASE_URL', cast=databases.DatabaseURL) if TESTING: DATABASE_URL = DATABASE_URL.replace(database='test_' + DATABASE_URL.database) ``` diff --git a/setup.cfg b/setup.cfg index 20c2588a1..09ecda353 100644 --- a/setup.cfg +++ b/setup.cfg @@ -8,8 +8,7 @@ ignore_missing_imports = True [mypy-tests.*] disallow_untyped_defs = False -# https://github.com/encode/starlette/issues/1045 -# check_untyped_defs = True +check_untyped_defs = True [tool:isort] profile = black diff --git a/starlette/routing.py b/starlette/routing.py index 3c11c1b0c..982980c3c 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -276,7 +276,10 @@ def __init__( self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name - if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(websocket)`. self.app = websocket_session(endpoint) else: diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 65252e502..2f0ca3d34 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -258,7 +258,7 @@ def test_cors_allow_all_methods(test_client_factory): ) @app.route( - "/", methods=("delete", "get", "head", "options", "patch", "post", "put") + "/", methods=["delete", "get", "head", "options", "patch", "post", "put"] ) def homepage(request): return PlainTextResponse("Homepage", status_code=200) diff --git a/tests/middleware/test_session.py b/tests/middleware/test_session.py index 42f4447e5..07296bcbb 100644 --- a/tests/middleware/test_session.py +++ b/tests/middleware/test_session.py @@ -66,7 +66,9 @@ def test_session_expires(test_client_factory): # requests removes expired cookies from response.cookies, we need to # fetch session id from the headers and pass it explicitly expired_cookie_header = response.headers["set-cookie"] - expired_session_value = re.search(r"session=([^;]*);", expired_cookie_header)[1] + expired_session_match = re.search(r"session=([^;]*);", expired_cookie_header) + assert expired_session_match is not None + expired_session_value = expired_session_match[1] response = client.get("/view_session", cookies={"session": expired_session_value}) assert response.json() == {"session": {}} @@ -110,7 +112,9 @@ def test_session_cookie_subpath(test_client_factory): client = test_client_factory(app, base_url="http://testserver/second_app") response = client.post("second_app/update_session", json={"some": "data"}) cookie = response.headers["set-cookie"] - cookie_path = re.search(r"; path=(\S+);", cookie).groups()[0] + cookie_path_match = re.search(r"; path=(\S+);", cookie) + assert cookie_path_match is not None + cookie_path = cookie_path_match.groups()[0] assert cookie_path == "/second_app" diff --git a/tests/test_database.py b/tests/test_database.py index c0a4745d1..11f770bb1 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -77,6 +77,7 @@ async def read_note(request): note_id = request.path_params["note_id"] query = notes.select().where(notes.c.id == note_id) result = await database.fetch_one(query) + assert result is not None content = {"text": result["text"], "completed": result["completed"]} return JSONResponse(content) @@ -86,6 +87,7 @@ async def read_note_text(request): note_id = request.path_params["note_id"] query = sqlalchemy.select([notes.c.text]).where(notes.c.id == note_id) result = await database.fetch_one(query) + assert result is not None return JSONResponse(result[0]) diff --git a/tests/test_datastructures.py b/tests/test_datastructures.py index bb71ba870..258fb45c1 100644 --- a/tests/test_datastructures.py +++ b/tests/test_datastructures.py @@ -1,5 +1,3 @@ -import io - import pytest from starlette.datastructures import ( @@ -228,7 +226,7 @@ async def test_upload_file(): def test_formdata(): - upload = io.BytesIO(b"test") + upload = UploadFile("test") form = FormData([("a", "123"), ("a", "456"), ("b", upload)]) assert "a" in form assert "A" not in form @@ -338,10 +336,6 @@ def test_multidict(): q.update(q) assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" - q = MultiDict([("a", "123"), ("b", "456")]) - q.update(None) - assert repr(q) == "MultiDict([('a', '123'), ('b', '456')])" - q = MultiDict([("a", "123"), ("a", "456")]) q.update([("a", "123")]) assert q.getlist("a") == ["123"] diff --git a/tests/test_formparsers.py b/tests/test_formparsers.py index 8a1174e1d..4d614e9e0 100644 --- a/tests/test_formparsers.py +++ b/tests/test_formparsers.py @@ -1,4 +1,5 @@ import os +import typing from starlette.formparsers import UploadFile, _user_safe_decode from starlette.requests import Request @@ -21,9 +22,10 @@ async def app(scope, receive, send): for key, value in data.items(): if isinstance(value, UploadFile): content = await value.read() + content = content.decode() if isinstance(content, bytes) else content output[key] = { "filename": value.filename, - "content": content.decode(), + "content": content, "content_type": value.content_type, } else: @@ -36,16 +38,17 @@ async def app(scope, receive, send): async def multi_items_app(scope, receive, send): request = Request(scope, receive) data = await request.form() - output = {} + output: typing.Dict[str, list] = {} for key, value in data.multi_items(): if key not in output: output[key] = [] if isinstance(value, UploadFile): content = await value.read() + content = content.decode() if isinstance(content, bytes) else content output[key].append( { "filename": value.filename, - "content": content.decode(), + "content": content, "content_type": value.content_type, } ) diff --git a/tests/test_routing.py b/tests/test_routing.py index e1374cc5d..231c581fb 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover return Response(content, media_type="text/plain") +async def partial_endpoint(arg, request): + return JSONResponse({"arg": arg}) + + +async def partial_ws_endpoint(websocket: WebSocket): + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + +class PartialRoutes: + @classmethod + async def async_endpoint(cls, arg, request): + return JSONResponse({"arg": arg}) + + @classmethod + async def async_ws_endpoint(cls, websocket: WebSocket): + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + app = Router( [ Route("/", endpoint=homepage, methods=["GET"]), @@ -44,6 +66,21 @@ def user_no_match(request): # pragma: no cover Route("/nomatch", endpoint=user_no_match), ], ), + Mount( + "/partial", + routes=[ + Route("/", endpoint=functools.partial(partial_endpoint, "foo")), + Route( + "/cls", + endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"), + ), + WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)), + WebSocketRoute( + "/ws/cls", + endpoint=functools.partial(PartialRoutes.async_ws_endpoint), + ), + ], + ), Mount("/static", app=Response("xxxxx", media_type="image/png")), ] ) @@ -91,14 +128,14 @@ def path_with_parentheses(request): @app.websocket_route("/ws") -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket): await session.accept() await session.send_text("Hello, world!") await session.close() @app.websocket_route("/ws/{room}") -async def websocket_params(session): +async def websocket_params(session: WebSocket): await session.accept() await session.send_text(f"Hello, {session.path_params['room']}!") await session.close() @@ -422,13 +459,13 @@ async def subdomain_app(scope, receive, send): await response(scope, receive, send) -subdomain_app = Router( +subdomain_router = Router( routes=[Host("{subdomain}.example.org", app=subdomain_app, name="subdomains")] ) def test_subdomain_routing(test_client_factory): - client = test_client_factory(subdomain_app, base_url="https://foo.example.org/") + client = test_client_factory(subdomain_router, base_url="https://foo.example.org/") response = client.get("/") assert response.status_code == 200 @@ -437,7 +474,7 @@ def test_subdomain_routing(test_client_factory): def test_subdomain_reverse_urls(): assert ( - subdomain_app.url_path_for( + subdomain_router.url_path_for( "subdomains", subdomain="foo", path="/homepage" ).make_absolute_url("https://whatever") == "https://foo.example.org/homepage" @@ -600,6 +637,7 @@ def run_startup(): raise RuntimeError() router = Router(on_startup=[run_startup]) + startup_failed = False async def app(scope, receive, send): async def _send(message): @@ -610,7 +648,6 @@ async def _send(message): await router(scope, receive, _send) - startup_failed = False with pytest.raises(RuntimeError): with test_client_factory(app): pass # pragma: nocover @@ -628,40 +665,28 @@ def run_shutdown(): pass # pragma: nocover -class AsyncEndpointClassMethod: - @classmethod - async def async_endpoint(cls, arg, request): - return JSONResponse({"arg": arg}) - - -async def _partial_async_endpoint(arg, request): - return JSONResponse({"arg": arg}) - - -partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") -partial_cls_async_endpoint = functools.partial( - AsyncEndpointClassMethod.async_endpoint, "foo" -) - -partial_async_app = Router( - routes=[ - Route("/", partial_async_endpoint), - Route("/cls", partial_cls_async_endpoint), - ] -) - - def test_partial_async_endpoint(test_client_factory): - test_client = test_client_factory(partial_async_app) - response = test_client.get("/") + test_client = test_client_factory(app) + response = test_client.get("/partial") assert response.status_code == 200 assert response.json() == {"arg": "foo"} - cls_method_response = test_client.get("/cls") + cls_method_response = test_client.get("/partial/cls") assert cls_method_response.status_code == 200 assert cls_method_response.json() == {"arg": "foo"} +def test_partial_async_ws_endpoint(test_client_factory): + test_client = test_client_factory(app) + with test_client.websocket_connect("/partial/ws") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/partial/ws"} + + with test_client.websocket_connect("/partial/ws/cls") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/partial/ws/cls"} + + def test_duplicated_param_names(): with pytest.raises( ValueError, diff --git a/tests/test_templates.py b/tests/test_templates.py index 073482d65..aa8279348 100644 --- a/tests/test_templates.py +++ b/tests/test_templates.py @@ -28,4 +28,4 @@ async def homepage(request): def test_template_response_requires_request(tmpdir): templates = Jinja2Templates(str(tmpdir)) with pytest.raises(ValueError): - templates.TemplateResponse(None, {}) + templates.TemplateResponse("", {})