Skip to content

Commit

Permalink
Fix regression on route names with colons (#1675)
Browse files Browse the repository at this point in the history
Co-authored-by: Bodo Graumann <mail@bodograumann.de>
  • Loading branch information
florimondmanca and bodograumann committed Jun 4, 2022
1 parent 9ef1b91 commit b588ebe
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 3 deletions.
14 changes: 12 additions & 2 deletions starlette/routing.py
Expand Up @@ -111,13 +111,16 @@ def compile_path(
path: str,
) -> typing.Tuple[typing.Pattern, str, typing.Dict[str, Convertor]]:
"""
Given a path string, like: "/{username:str}", return a three-tuple
Given a path string, like: "/{username:str}",
or a host string, like: "{subdomain}.mydomain.org", return a three-tuple
of (regex, format, {param_name:convertor}).
regex: "/(?P<username>[^/]+)"
format: "/{username}"
convertors: {"username": StringConvertor()}
"""
is_host = not path.startswith("/")

path_regex = "^"
path_format = ""
duplicated_params = set()
Expand Down Expand Up @@ -150,7 +153,13 @@ def compile_path(
ending = "s" if len(duplicated_params) > 1 else ""
raise ValueError(f"Duplicated param name{ending} {names} at path {path}")

path_regex += re.escape(path[idx:].split(":")[0]) + "$"
if is_host:
# Align with `Host.matches()` behavior, which ignores port.
hostname = path[idx:].split(":")[0]
path_regex += re.escape(hostname) + "$"
else:
path_regex += re.escape(path[idx:]) + "$"

path_format += path[idx:]

return re.compile(path_regex), path_format, param_convertors
Expand Down Expand Up @@ -429,6 +438,7 @@ class Host(BaseRoute):
def __init__(
self, host: str, app: ASGIApp, name: typing.Optional[str] = None
) -> None:
assert not host.startswith("/"), "Host must not start with '/'"
self.host = host
self.app = app
self.name = name
Expand Down
22 changes: 21 additions & 1 deletion tests/test_routing.py
Expand Up @@ -28,6 +28,11 @@ def user_me(request):
return Response(content, media_type="text/plain")


def disable_user(request):
content = "User " + request.path_params["username"] + " disabled"
return Response(content, media_type="text/plain")


def user_no_match(request): # pragma: no cover
content = "User fixed no match"
return Response(content, media_type="text/plain")
Expand Down Expand Up @@ -109,6 +114,7 @@ async def websocket_params(session: WebSocket):
Route("/", endpoint=users),
Route("/me", endpoint=user_me),
Route("/{username}", endpoint=user),
Route("/{username}:disable", endpoint=disable_user, methods=["PUT"]),
Route("/nomatch", endpoint=user_no_match),
],
),
Expand Down Expand Up @@ -189,6 +195,11 @@ def test_router(client):
assert response.url == "http://testserver/users/tomchristie"
assert response.text == "User tomchristie"

response = client.put("/users/tomchristie:disable")
assert response.status_code == 200
assert response.url == "http://testserver/users/tomchristie:disable"
assert response.text == "User tomchristie disabled"

response = client.get("/users/nomatch")
assert response.status_code == 200
assert response.text == "User nomatch"
Expand Down Expand Up @@ -429,14 +440,23 @@ def test_host_routing(test_client_factory):
response = client.get("/")
assert response.status_code == 200

client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/")
client = test_client_factory(
mixed_hosts_app, base_url="https://port.example.org:3600/"
)

response = client.get("/users")
assert response.status_code == 404

response = client.get("/")
assert response.status_code == 200

# Port in requested Host is irrelevant.

client = test_client_factory(mixed_hosts_app, base_url="https://port.example.org/")

response = client.get("/")
assert response.status_code == 200

client = test_client_factory(
mixed_hosts_app, base_url="https://port.example.org:5600/"
)
Expand Down

0 comments on commit b588ebe

Please sign in to comment.