diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 0c9bdfb38..582aa942a 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -157,11 +157,16 @@ async def send( # If request includes any cookie headers, then we must respond # with the specific origin instead of '*'. if self.allow_all_origins and has_cookie: - headers["Access-Control-Allow-Origin"] = origin + self.allow_explicit_origin(headers, origin) # If we only allow specific origins, then we have to mirror back # the Origin header in the response. elif not self.allow_all_origins and self.is_allowed_origin(origin=origin): - headers["Access-Control-Allow-Origin"] = origin - headers.add_vary_header("Origin") + self.allow_explicit_origin(headers, origin) + await send(message) + + @staticmethod + def allow_explicit_origin(headers: MutableHeaders, origin: str) -> None: + headers["Access-Control-Allow-Origin"] = origin + headers.add_vary_header("Origin") diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 2048cb11e..121902b0a 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -245,12 +245,28 @@ def homepage(request): assert response.headers["vary"] == "Origin" -def test_cors_vary_header_is_properly_set(): +def test_cors_vary_header_is_not_set_for_non_credentialed_request(): app = Starlette() - app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) + app.add_middleware(CORSMiddleware, allow_origins=["*"]) - headers = {"Origin": "https://example.org"} + @app.route("/") + def homepage(request): + return PlainTextResponse( + "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} + ) + + client = TestClient(app) + + response = client.get("/", headers={"Origin": "https://someplace.org"}) + assert response.status_code == 200 + assert response.headers["vary"] == "Accept-Encoding" + + +def test_cors_vary_header_is_properly_set_for_credentialed_request(): + app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["*"]) @app.route("/") def homepage(request): @@ -260,13 +276,34 @@ def homepage(request): client = TestClient(app) - response = client.get("/", headers=headers) + response = client.get( + "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} + ) + assert response.status_code == 200 + assert response.headers["vary"] == "Accept-Encoding, Origin" + + +def test_cors_vary_header_is_properly_set_when_allow_origins_is_not_wildcard(): + app = Starlette() + + app.add_middleware(CORSMiddleware, allow_origins=["https://example.org"]) + + @app.route("/") + def homepage(request): + return PlainTextResponse( + "Homepage", status_code=200, headers={"Vary": "Accept-Encoding"} + ) + + client = TestClient(app) + + response = client.get("/", headers={"Origin": "https://example.org"}) assert response.status_code == 200 assert response.headers["vary"] == "Accept-Encoding, Origin" def test_cors_allowed_origin_does_not_leak_between_credentialed_requests(): app = Starlette() + app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_headers=["*"], allow_methods=["*"] )