diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 45262d3df..0b3f505e7 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -30,8 +30,12 @@ def __init__( if allow_origin_regex is not None: compiled_allow_origin_regex = re.compile(allow_origin_regex) + allow_all_origins = "*" in allow_origins + allow_all_headers = "*" in allow_headers + preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + simple_headers = {} - if "*" in allow_origins: + if allow_all_origins: simple_headers["Access-Control-Allow-Origin"] = "*" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" @@ -39,10 +43,11 @@ def __init__( simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) preflight_headers = {} - if "*" in allow_origins: - preflight_headers["Access-Control-Allow-Origin"] = "*" - else: + if preflight_explicit_allow_origin: + # The origin value will be set in preflight_response() if it is allowed. preflight_headers["Vary"] = "Origin" + else: + preflight_headers["Access-Control-Allow-Origin"] = "*" preflight_headers.update( { "Access-Control-Allow-Methods": ", ".join(allow_methods), @@ -50,7 +55,7 @@ def __init__( } ) allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) - if allow_headers and "*" not in allow_headers: + if allow_headers and not allow_all_headers: preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) if allow_credentials: preflight_headers["Access-Control-Allow-Credentials"] = "true" @@ -59,8 +64,9 @@ def __init__( self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] - self.allow_all_origins = "*" in allow_origins - self.allow_all_headers = "*" in allow_headers + self.allow_all_origins = allow_all_origins + self.allow_all_headers = allow_all_headers + self.preflight_explicit_allow_origin = preflight_explicit_allow_origin self.allow_origin_regex = compiled_allow_origin_regex self.simple_headers = simple_headers self.preflight_headers = preflight_headers @@ -105,11 +111,9 @@ def preflight_response(self, request_headers: Headers) -> Response: failures = [] if self.is_allowed_origin(origin=requested_origin): - if not self.allow_all_origins: - # If self.allow_all_origins is True, then the - # "Access-Control-Allow-Origin" header is already set to "*". - # If we only allow specific origins, then we have to mirror back - # the Origin header in the response. + if self.preflight_explicit_allow_origin: + # The "else" case is already accounted for in self.preflight_headers + # and the value would be "*". headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 3d9b16f97..7a250a241 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -22,6 +22,62 @@ def homepage(request): client = TestClient(app) + # Test pre-flight response + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "X-Example", + } + response = client.options("/", headers=headers) + assert response.status_code == 200 + assert response.text == "OK" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-headers"] == "X-Example" + assert response.headers["access-control-allow-credentials"] == "true" + assert response.headers["vary"] == "Origin" + + # Test standard response + headers = {"Origin": "https://example.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "*" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test standard credentialed response + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test non-CORS response + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Homepage" + assert "access-control-allow-origin" not in response.headers + + +def test_cors_allow_all_except_credentials(): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_headers=["*"], + allow_methods=["*"], + expose_headers=["X-Status"], + ) + + @app.route("/") + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = TestClient(app) + # Test pre-flight response headers = { "Origin": "https://example.org", @@ -33,6 +89,8 @@ def homepage(request): assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-allow-headers"] == "X-Example" + assert "access-control-allow-credentials" not in response.headers + assert "vary" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -41,6 +99,7 @@ def homepage(request): assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-expose-headers"] == "X-Status" + assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") @@ -77,6 +136,7 @@ def homepage(request): assert response.headers["access-control-allow-headers"] == ( "Accept, Accept-Language, Content-Language, Content-Type, X-Example" ) + assert "access-control-allow-credentials" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -84,6 +144,7 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") @@ -116,6 +177,38 @@ def homepage(request): response = client.options("/", headers=headers) assert response.status_code == 400 assert response.text == "Disallowed CORS origin, method, headers" + assert "access-control-allow-origin" not in response.headers + + +def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["POST"], + allow_credentials=True, + ) + + @app.route("/") + def homepage(request): + return # pragma: no cover + + client = TestClient(app) + + # Test pre-flight response + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "POST", + } + response = client.options( + "/", + headers=headers, + ) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" + assert response.headers["vary"] == "Origin" def test_cors_preflight_allow_all_methods(): @@ -175,6 +268,7 @@ def test_cors_allow_origin_regex(): CORSMiddleware, allow_headers=["X-Example", "Content-Type"], allow_origin_regex="https://.*", + allow_credentials=True, ) @app.route("/") @@ -189,8 +283,17 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" - # Test diallowed standard response + # Test standard credentialed response + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test disallowed standard response # Note that enforcement is a browser concern. The disallowed-ness is reflected # in the lack of an "access-control-allow-origin" header in the response. headers = {"Origin": "http://example.org"} @@ -212,6 +315,7 @@ def homepage(request): assert response.headers["access-control-allow-headers"] == ( "Accept, Accept-Language, Content-Language, Content-Type, X-Example" ) + assert response.headers["access-control-allow-credentials"] == "true" # Test disallowed pre-flight response headers = { @@ -249,6 +353,7 @@ def homepage(request): response.headers["access-control-allow-origin"] == "https://subdomain.example.org" ) + assert "access-control-allow-credentials" not in response.headers # Test diallowed standard response headers = {"Origin": "https://subdomain.example.org.hacker.com"} @@ -275,6 +380,7 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers def test_cors_vary_header_defaults_to_origin(): @@ -365,11 +471,14 @@ def homepage(request): client = TestClient(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in response.headers response = client.get( "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} ) assert response.headers["access-control-allow-origin"] == "https://someplace.org" + assert "access-control-allow-credentials" not in response.headers response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in response.headers