diff --git a/starlette/routing.py b/starlette/routing.py index 7e10b16f9..1aa2cdb6d 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -111,13 +111,16 @@ def compile_path( path: str, ) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]: """ - Given a path string, like: "/{username:str}", return a three-tuple + Given a path string, like: "/{username:str}", + or a host string, like: "{subdomain}.mydomain.org", return a three-tuple of (regex, format, {param_name:convertor}). regex: "/(?P[^/]+)" format: "/{username}" convertors: {"username": StringConvertor()} """ + is_host = not path.startswith("/") + path_regex = "^" path_format = "" duplicated_params = set() @@ -150,7 +153,13 @@ def compile_path( ending = "s" if len(duplicated_params) > 1 else "" raise ValueError(f"Duplicated param name{ending} {names} at path {path}") - path_regex += re.escape(path[idx:].split(":")[0]) + "$" + if is_host: + # Align with `Host.matches()` behavior, which ignores port. + hostname = path[idx:].split(":")[0] + path_regex += re.escape(hostname) + "$" + else: + path_regex += re.escape(path[idx:]) + "$" + path_format += path[idx:] return re.compile(path_regex), path_format, param_convertors @@ -429,6 +438,7 @@ class Host(BaseRoute): def __init__( self, host: str, app: ASGIApp, name: typing.Optional[str] = None ) -> None: + assert not host.startswith("/"), "Host must not start with '/'" self.host = host self.app = app self.name = name diff --git a/tests/test_routing.py b/tests/test_routing.py index e8adaca48..e2b1c3dfc 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -28,6 +28,11 @@ def user_me(request): return Response(content, media_type="text/plain") +def disable_user(request): + content = "User " + request.path_params["username"] + " disabled" + return Response(content, media_type="text/plain") + + def user_no_match(request): # pragma: no cover content = "User fixed no match" return Response(content, media_type="text/plain") @@ -109,6 +114,7 @@ async def websocket_params(session: WebSocket): Route("/", endpoint=users), Route("/me", endpoint=user_me), Route("/{username}", endpoint=user), + Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]), Route("/nomatch", endpoint=user_no_match), ], ), @@ -189,6 +195,11 @@ def test_router(client): assert response.url == "http://testserver/users/tomchristie" assert response.text == "User tomchristie" + response = client.put("/users/tomchristie:disable") + assert response.status_code == 200 + assert response.url == "http://testserver/users/tomchristie:disable" + assert response.text == "User tomchristie disabled" + response = client.get("/users/nomatch") assert response.status_code == 200 assert response.text == "User nomatch" @@ -429,7 +440,9 @@ def test_host_routing(test_client_factory): response = client.get("/") assert response.status_code == 200 - client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/") + client = test_client_factory( + mixed_hosts_app, base_url="https://port.example.org:3600/" + ) response = client.get("/users") assert response.status_code == 404 @@ -437,6 +450,13 @@ def test_host_routing(test_client_factory): response = client.get("/") assert response.status_code == 200 + # Port in requested Host is irrelevant. + + client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/") + + response = client.get("/") + assert response.status_code == 200 + client = test_client_factory( mixed_hosts_app, base_url="https://port.example.org:5600/" )