diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 582aa942a..45262d3df 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -6,7 +6,7 @@ from starlette.responses import PlainTextResponse, Response from starlette.types import ASGIApp, Message, Receive, Scope, Send -ALL_METHODS = ("DELETE", "GET", "OPTIONS", "PATCH", "POST", "PUT") +ALL_METHODS = ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT") SAFELISTED_HEADERS = {"Accept", "Accept-Language", "Content-Language", "Content-Type"} diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 121902b0a..3d9b16f97 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -118,6 +118,56 @@ def homepage(request): assert response.text == "Disallowed CORS origin, method, headers" +def test_cors_preflight_allow_all_methods(): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + ) + + @app.route("/") + def homepage(request): + pass # pragma: no cover + + client = TestClient(app) + + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "POST", + } + + for method in ("DELETE", "GET", "HEAD", "OPTIONS", "PATCH", "POST", "PUT"): + response = client.options("/", headers=headers) + assert response.status_code == 200 + assert method in response.headers["access-control-allow-methods"] + + +def test_cors_allow_all_methods(): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + ) + + @app.route( + "/", methods=("delete", "get", "head", "options", "patch", "post", "put") + ) + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = TestClient(app) + + headers = {"Origin": "https://example.org"} + + for method in ("delete", "get", "head", "options", "patch", "post", "put"): + response = getattr(client, method)("/", headers=headers, json={}) + assert response.status_code == 200 + + def test_cors_allow_origin_regex(): app = Starlette()