From 5864620763f4d54598934a7f1fa47b04bccf57bd Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 26 Jan 2022 09:04:05 -0600 Subject: [PATCH 1/4] add Allow header to 405 responses as required by RFC 7231 --- starlette/routing.py | 5 +++-- tests/test_endpoints.py | 1 + tests/test_exceptions.py | 11 +++++++++++ tests/test_routing.py | 1 + 4 files changed, 16 insertions(+), 2 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index a7d72cb55..c0870e953 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -250,10 +250,11 @@ def url_path_for(self, name: str, **path_params: typing.Any) -> URLPath: async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: if self.methods and scope["method"] not in self.methods: + headers = {"Allow": ", ".join(self.methods)} if "app" in scope: - raise HTTPException(status_code=405) + raise HTTPException(status_code=405, headers=headers) else: - response = PlainTextResponse("Method Not Allowed", status_code=405) + response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) await response(scope, receive, send) else: await self.app(scope, receive, send) diff --git a/tests/test_endpoints.py b/tests/test_endpoints.py index e57d47486..9895a4559 100644 --- a/tests/test_endpoints.py +++ b/tests/test_endpoints.py @@ -40,6 +40,7 @@ def test_http_endpoint_route_method(client): response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" + assert response.headers["allow"] == "GET" def test_websocket_endpoint_on_connect(test_client_factory): diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index cef03359f..703772ba1 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -21,6 +21,10 @@ def not_modified(request): raise HTTPException(status_code=304) +def not_allowed(request): + raise HTTPException(status_code=405) + + class HandledExcAfterResponse: async def __call__(self, scope, receive, send): response = PlainTextResponse("OK", status_code=200) @@ -34,6 +38,7 @@ async def __call__(self, scope, receive, send): Route("/not_acceptable", endpoint=not_acceptable), Route("/no_content", endpoint=no_content), Route("/not_modified", endpoint=not_modified), + Route("/not_allowed", endpoint=not_allowed), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), ] @@ -67,6 +72,12 @@ def test_not_modified(client): assert response.text == "" +def test_not_allowed(client): + response = client.get("/not_allowed") + assert response.status_code == 405 + assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"} + + def test_websockets_should_raise(client): with pytest.raises(RuntimeError): with client.websocket_connect("/runtime_error"): diff --git a/tests/test_routing.py b/tests/test_routing.py index 231c581fb..dc28427ee 100644 --- a/tests/test_routing.py +++ b/tests/test_routing.py @@ -161,6 +161,7 @@ def test_router(client): response = client.post("/") assert response.status_code == 405 assert response.text == "Method Not Allowed" + assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"} response = client.get("/foo") assert response.status_code == 404 From 3390c917dc920284f7e3a5bad42bbede6907eb8e Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 26 Jan 2022 09:08:30 -0600 Subject: [PATCH 2/4] lint --- starlette/routing.py | 4 +++- tests/test_exceptions.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/starlette/routing.py b/starlette/routing.py index c0870e953..84ffcb3fb 100644 --- a/starlette/routing.py +++ b/starlette/routing.py @@ -254,7 +254,9 @@ async def handle(self, scope: Scope, receive: Receive, send: Send) -> None: if "app" in scope: raise HTTPException(status_code=405, headers=headers) else: - response = PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) + response = PlainTextResponse( + "Method Not Allowed", status_code=405, headers=headers + ) await response(scope, receive, send) else: await self.app(scope, receive, send) diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index efc82fced..5181bdd3a 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -20,11 +20,11 @@ def no_content(request): def not_modified(request): raise HTTPException(status_code=304) - + def not_allowed(request): raise HTTPException(status_code=405) - + def with_headers(request): raise HTTPException(status_code=200, headers={"x-potato": "always"}) From 49a60eac69e53b264a6cd6deb5d593ac8b9b6880 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 26 Jan 2022 09:26:52 -0600 Subject: [PATCH 3/4] set headers in HTTPEndpoint --- starlette/endpoints.py | 9 +++++++-- tests/test_exceptions.py | 11 ----------- 2 files changed, 7 insertions(+), 13 deletions(-) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index e27e4fe49..604210758 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -17,6 +17,10 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: self.scope = scope self.receive = receive self.send = send + self._allowed_methods = filter( + lambda method: getattr(self, method.lower(), None) is not None, + ["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], + ) def __await__(self) -> typing.Generator: return self.dispatch().__await__() @@ -43,9 +47,10 @@ async def method_not_allowed(self, request: Request) -> Response: # If we're running inside a starlette application then raise an # exception, so that the configurable exception handler can deal with # returning the response. For plain ASGI apps, just return the response. + headers = {"Allow": ", ".join(self._allowed_methods)} if "app" in self.scope: - raise HTTPException(status_code=405) - return PlainTextResponse("Method Not Allowed", status_code=405) + raise HTTPException(status_code=405, headers=headers) + return PlainTextResponse("Method Not Allowed", status_code=405, headers=headers) class WebSocketEndpoint: diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 5181bdd3a..80307a521 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -21,10 +21,6 @@ def not_modified(request): raise HTTPException(status_code=304) -def not_allowed(request): - raise HTTPException(status_code=405) - - def with_headers(request): raise HTTPException(status_code=200, headers={"x-potato": "always"}) @@ -42,7 +38,6 @@ async def __call__(self, scope, receive, send): Route("/not_acceptable", endpoint=not_acceptable), Route("/no_content", endpoint=no_content), Route("/not_modified", endpoint=not_modified), - Route("/not_allowed", endpoint=not_allowed), Route("/with_headers", endpoint=with_headers), Route("/handled_exc_after_response", endpoint=HandledExcAfterResponse()), WebSocketRoute("/runtime_error", endpoint=raise_runtime_error), @@ -77,12 +72,6 @@ def test_not_modified(client): assert response.text == "" -def test_not_allowed(client): - response = client.get("/not_allowed") - assert response.status_code == 405 - assert set(response.headers["allow"].split(", ")) == {"HEAD", "GET"} - - def test_with_headers(client): response = client.get("/with_headers") assert response.status_code == 200 From 2a4cfa5d213f91d87ba4471d4878e4d20901cecf Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 26 Jan 2022 09:59:53 -0600 Subject: [PATCH 4/4] apply suggestion --- starlette/endpoints.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/starlette/endpoints.py b/starlette/endpoints.py index 604210758..73367c257 100644 --- a/starlette/endpoints.py +++ b/starlette/endpoints.py @@ -17,10 +17,11 @@ def __init__(self, scope: Scope, receive: Receive, send: Send) -> None: self.scope = scope self.receive = receive self.send = send - self._allowed_methods = filter( - lambda method: getattr(self, method.lower(), None) is not None, - ["GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS"], - ) + self._allowed_methods = [ + method + for method in ("GET", "HEAD", "POST", "PUT", "PATCH", "DELETE", "OPTIONS") + if getattr(self, method.lower(), None) is not None + ] def __await__(self) -> typing.Generator: return self.dispatch().__await__()