Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for functools.partial in WebsocketRoute #1356

Merged
merged 3 commits into from Dec 11, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
5 changes: 4 additions & 1 deletion starlette/routing.py
Expand Up @@ -276,7 +276,10 @@ def __init__(
self.endpoint = endpoint
self.name = get_name(endpoint) if name is None else name

if inspect.isfunction(endpoint) or inspect.ismethod(endpoint):
endpoint_handler = endpoint
while isinstance(endpoint_handler, functools.partial):
endpoint_handler = endpoint_handler.func
if inspect.isfunction(endpoint_handler) or inspect.ismethod(endpoint_handler):
# Endpoint is function or method. Treat it as `func(websocket)`.
self.app = websocket_session(endpoint)
else:
Expand Down
81 changes: 53 additions & 28 deletions tests/test_routing.py
Expand Up @@ -32,6 +32,28 @@ def user_no_match(request): # pragma: no cover
return Response(content, media_type="text/plain")


async def partial_endpoint(arg, request):
return JSONResponse({"arg": arg})


async def partial_ws_endpoint(websocket: WebSocket):
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()


class PartialRoutes:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})

@classmethod
async def async_ws_endpoint(cls, websocket: WebSocket):
await websocket.accept()
await websocket.send_json({"url": str(websocket.url)})
await websocket.close()


app = Router(
[
Route("/", endpoint=homepage, methods=["GET"]),
Expand All @@ -44,6 +66,21 @@ def user_no_match(request): # pragma: no cover
Route("/nomatch", endpoint=user_no_match),
],
),
Mount(
"/partial",
routes=[
Route("/", endpoint=functools.partial(partial_endpoint, "foo")),
Route(
"/cls",
endpoint=functools.partial(PartialRoutes.async_endpoint, "foo"),
),
WebSocketRoute("/ws", endpoint=functools.partial(partial_ws_endpoint)),
WebSocketRoute(
"/ws/cls",
endpoint=functools.partial(PartialRoutes.async_ws_endpoint),
),
],
),
Mount("/static", app=Response("xxxxx", media_type="image/png")),
]
)
Expand Down Expand Up @@ -91,14 +128,14 @@ def path_with_parentheses(request):


@app.websocket_route("/ws")
async def websocket_endpoint(session):
async def websocket_endpoint(session: WebSocket):
await session.accept()
await session.send_text("Hello, world!")
await session.close()


@app.websocket_route("/ws/{room}")
async def websocket_params(session):
async def websocket_params(session: WebSocket):
await session.accept()
await session.send_text(f"Hello, {session.path_params['room']}!")
await session.close()
Expand Down Expand Up @@ -628,40 +665,28 @@ def run_shutdown():
pass # pragma: nocover


class AsyncEndpointClassMethod:
@classmethod
async def async_endpoint(cls, arg, request):
return JSONResponse({"arg": arg})


async def _partial_async_endpoint(arg, request):
return JSONResponse({"arg": arg})


partial_async_endpoint = functools.partial(_partial_async_endpoint, "foo")
partial_cls_async_endpoint = functools.partial(
AsyncEndpointClassMethod.async_endpoint, "foo"
)

partial_async_app = Router(
routes=[
Route("/", partial_async_endpoint),
Route("/cls", partial_cls_async_endpoint),
]
)


def test_partial_async_endpoint(test_client_factory):
test_client = test_client_factory(partial_async_app)
response = test_client.get("/")
test_client = test_client_factory(app)
response = test_client.get("/partial")
assert response.status_code == 200
assert response.json() == {"arg": "foo"}

cls_method_response = test_client.get("/cls")
cls_method_response = test_client.get("/partial/cls")
assert cls_method_response.status_code == 200
assert cls_method_response.json() == {"arg": "foo"}


def test_partial_async_ws_endpoint(test_client_factory):
test_client = test_client_factory(app)
with test_client.websocket_connect("/partial/ws") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws"}

with test_client.websocket_connect("/partial/ws/cls") as websocket:
data = websocket.receive_json()
assert data == {"url": "ws://testserver/partial/ws/cls"}


def test_duplicated_param_names():
with pytest.raises(
ValueError,
Expand Down