From cd8d58901f10b2deaf01711c7c6f02c0d21adffa Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 10 Dec 2021 15:22:11 +0100 Subject: [PATCH 1/3] Add support for functools.partial in WebsocketRoute --- starlette/routing.py | 12 +++++++++++- tests/test_routing.py | 14 ++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/starlette/routing.py b/starlette/routing.py index 3c11c1b0c..e9cd990d5 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -276,13 +276,23 @@ def __init__( self.endpoint = endpoint self.name = get_name(endpoint) if name is None else name - if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + endpoint_handler = endpoint + while isinstance(endpoint_handler, functools.partial): + endpoint_handler = endpoint_handler.func + if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler): # Endpoint is function or method. Treat it as `func(websocket)`. self.app = websocket_session(endpoint) else: # Endpoint is a class. Treat it as ASGI. self.app = endpoint + # if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): + # # Endpoint is function or method. Treat it as `func(websocket)`. + # self.app = websocket_session(endpoint) + # else: + # # Endpoint is a class. Treat it as ASGI. + # self.app = endpoint + self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: diff --git a/tests/test_routing.py b/tests/test_routing.py index e1374cc5d..91bc26527 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -32,6 +32,12 @@ def user_no_match(request): # pragma: no cover return Response(content, media_type="text/plain") +async def users_ws(websocket: WebSocket): + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + app = Router( [ Route("/", endpoint=homepage, methods=["GET"]), @@ -42,6 +48,7 @@ def user_no_match(request): # pragma: no cover Route("/me", endpoint=user_me), Route("/{username}", endpoint=user), Route("/nomatch", endpoint=user_no_match), + WebSocketRoute("/ws", endpoint=functools.partial(users_ws)), ], ), Mount("/static", app=Response("xxxxx", media_type="image/png")), @@ -662,6 +669,13 @@ def test_partial_async_endpoint(test_client_factory): assert cls_method_response.json() == {"arg": "foo"} +def test_partial_async_ws_endpoint(test_client_factory): + test_client = test_client_factory(app) + with test_client.websocket_connect("/users/ws") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/users/ws"} + + def test_duplicated_param_names(): with pytest.raises( ValueError, From 318ca783b8712f4cb0d463bcf4637fa8075ea2c2 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 10 Dec 2021 15:26:06 +0100 Subject: [PATCH 2/3] remove commented code --- starlette/routing.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index e9cd990d5..982980c3c 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -286,13 +286,6 @@ def __init__( # Endpoint is a class. Treat it as ASGI. self.app = endpoint - # if inspect.isfunction(endpoint) or inspect.ismethod(endpoint): - # # Endpoint is function or method. Treat it as `func(websocket)`. - # self.app = websocket_session(endpoint) - # else: - # # Endpoint is a class. Treat it as ASGI. - # self.app = endpoint - self.path_regex, self.path_format, self.param_convertors = compile_path(path) def matches(self, scope: Scope) -> typing.Tuple[Match, Scope]: From 464d43c6ed04448a1427af1f49e9a5eec426d2e9 Mon Sep 17 00:00:00 2001 From: Amin Alaee Date: Fri, 10 Dec 2021 16:08:02 +0100 Subject: [PATCH 3/3] Refactor tests for partian endpoint and ws --- tests/test_routing.py | 75 +++++++++++++++++++++++++------------------ 1 file changed, 43 insertions(+), 32 deletions(-) diff --git a/tests/test_routing.py b/tests/test_routing.py index 91bc26527..dcb996531 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -32,12 +32,28 @@ def user_no_match(request): # pragma: no cover return Response(content, media_type="text/plain") -async def users_ws(websocket: WebSocket): +async def partial_endpoint(arg, request): + return JSONResponse({"arg": arg}) + + +async def partial_ws_endpoint(websocket: WebSocket): await websocket.accept() await websocket.send_json({"url": str(websocket.url)}) await websocket.close() +class PartialRoutes: + @classmethod + async def async_endpoint(cls, arg, request): + return JSONResponse({"arg": arg}) + + @classmethod + async def async_ws_endpoint(cls, websocket: WebSocket): + await websocket.accept() + await websocket.send_json({"url": str(websocket.url)}) + await websocket.close() + + app = Router( [ Route("/", endpoint=homepage, methods=["GET"]), @@ -48,7 +64,21 @@ async def users_ws(websocket: WebSocket): Route("/me", endpoint=user_me), Route("/{username}", endpoint=user), Route("/nomatch", endpoint=user_no_match), - WebSocketRoute("/ws", endpoint=functools.partial(users_ws)), + ], + ), + Mount( + "/partial", + routes=[ + Route("/", endpoint=functools.partial(partial_endpoint, "foo")), + Route( + "/cls", + endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"), + ), + WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)), + WebSocketRoute( + "/ws/cls", + endpoint=functools.partial(PartialRoutes.async_ws_endpoint), + ), ], ), Mount("/static", app=Response("xxxxx", media_type="image/png")), @@ -98,14 +128,14 @@ def path_with_parentheses(request): @app.websocket_route("/ws") -async def websocket_endpoint(session): +async def websocket_endpoint(session: WebSocket): await session.accept() await session.send_text("Hello, world!") await session.close() @app.websocket_route("/ws/{room}") -async def websocket_params(session): +async def websocket_params(session: WebSocket): await session.accept() await session.send_text(f"Hello, {session.path_params['room']}!") await session.close() @@ -635,45 +665,26 @@ def run_shutdown(): pass # pragma: nocover -class AsyncEndpointClassMethod: - @classmethod - async def async_endpoint(cls, arg, request): - return JSONResponse({"arg": arg}) - - -async def _partial_async_endpoint(arg, request): - return JSONResponse({"arg": arg}) - - -partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo") -partial_cls_async_endpoint = functools.partial( - AsyncEndpointClassMethod.async_endpoint, "foo" -) - -partial_async_app = Router( - routes=[ - Route("/", partial_async_endpoint), - Route("/cls", partial_cls_async_endpoint), - ] -) - - def test_partial_async_endpoint(test_client_factory): - test_client = test_client_factory(partial_async_app) - response = test_client.get("/") + test_client = test_client_factory(app) + response = test_client.get("/partial") assert response.status_code == 200 assert response.json() == {"arg": "foo"} - cls_method_response = test_client.get("/cls") + cls_method_response = test_client.get("/partial/cls") assert cls_method_response.status_code == 200 assert cls_method_response.json() == {"arg": "foo"} def test_partial_async_ws_endpoint(test_client_factory): test_client = test_client_factory(app) - with test_client.websocket_connect("/users/ws") as websocket: + with test_client.websocket_connect("/partial/ws") as websocket: + data = websocket.receive_json() + assert data == {"url": "ws://testserver/partial/ws"} + + with test_client.websocket_connect("/partial/ws/cls") as websocket: data = websocket.receive_json() - assert data == {"url": "ws://testserver/users/ws"} + assert data == {"url": "ws://testserver/partial/ws/cls"} def test_duplicated_param_names():