From e804fff1a8decaa95e5ad283866b328985ed20ca Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Thu, 26 Nov 2020 03:03:36 -0800 Subject: [PATCH 1/7] Set explicit Origin in CORS preflight response if allow_credentials is True and allow_origins is wildcard When making a preflight request, the browser makes no indication as to whether the actual subsequent request will pass up credentials. However, unless the preflight response explicitly allows the request's `Origin` in the `Access-Control-Response-Header`, the browser will fail the CORS check and prevent the actual follow-up CORS request. This means that responding with the `*` wildcard is not sufficient to allow preflighted credentialed requests. The current workaround is to provide an equivalently permissive `allow_origin_regex` pattern. The `simple_response()` code already performs similar logic which currently only applies to non-preflighted requests since the browser would never make a preflighted request that hits this code due to this issue: ``` if self.allow_all_origins and has_cookie: headers["Access-Control-Allow-Origin"] = origin ``` This just bring the two halves inline with each other. --- starlette/middleware/cors.py | 5 +- tests/middleware/test_cors.py | 108 +++++++++++++++++++++++++++++++++- 2 files changed, 111 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 0c9bdfb38..33e85d72b 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -56,6 +56,7 @@ def __init__( preflight_headers["Access-Control-Allow-Credentials"] = "true" self.app = app + self.allow_credentials = allow_credentials self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] @@ -105,11 +106,13 @@ def preflight_response(self, request_headers: Headers) -> Response: failures = [] if self.is_allowed_origin(origin=requested_origin): - if not self.allow_all_origins: + if not self.allow_all_origins or self.allow_credentials: # If self.allow_all_origins is True, then the # "Access-Control-Allow-Origin" header is already set to "*". # If we only allow specific origins, then we have to mirror back # the Origin header in the response. + # Similarly, if it's an allowed origin and credentials are + # allowed, we also have to mirror back the Origin header. headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 2048cb11e..7dfc2ab6a 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -22,6 +22,61 @@ def homepage(request): client = TestClient(app) + # Test pre-flight response + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "GET", + "Access-Control-Request-Headers": "X-Example", + } + response = client.options("/", headers=headers) + assert response.status_code == 200 + assert response.text == "OK" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-headers"] == "X-Example" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test standard response + headers = {"Origin": "https://example.org"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "*" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test standard credentialed response + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-expose-headers"] == "X-Status" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test non-CORS response + response = client.get("/") + assert response.status_code == 200 + assert response.text == "Homepage" + assert "access-control-allow-origin" not in response.headers + + +def test_cors_allow_all_except_credentials(): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_headers=["*"], + allow_methods=["*"], + expose_headers=["X-Status"], + ) + + @app.route("/") + def homepage(request): + return PlainTextResponse("Homepage", status_code=200) + + client = TestClient(app) + # Test pre-flight response headers = { "Origin": "https://example.org", @@ -33,6 +88,7 @@ def homepage(request): assert response.text == "OK" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-allow-headers"] == "X-Example" + assert "access-control-allow-credentials" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -41,6 +97,7 @@ def homepage(request): assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-expose-headers"] == "X-Status" + assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") @@ -77,6 +134,7 @@ def homepage(request): assert response.headers["access-control-allow-headers"] == ( "Accept, Accept-Language, Content-Language, Content-Type, X-Example" ) + assert "access-control-allow-credentials" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -84,6 +142,7 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers # Test non-CORS response response = client.get("/") @@ -116,6 +175,37 @@ def homepage(request): response = client.options("/", headers=headers) assert response.status_code == 400 assert response.text == "Disallowed CORS origin, method, headers" + assert "access-control-allow-origin" not in response.headers + + +def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed(): + app = Starlette() + + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["POST"], + allow_credentials=True, + ) + + @app.route("/") + def homepage(request): + return # pragma: no cover + + client = TestClient(app) + + # Test pre-flight response + headers = { + "Origin": "https://example.org", + "Access-Control-Request-Method": "POST", + } + response = client.options( + "/", + headers=headers, + ) + assert response.status_code == 200 + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" def test_cors_allow_origin_regex(): @@ -125,6 +215,7 @@ def test_cors_allow_origin_regex(): CORSMiddleware, allow_headers=["X-Example", "Content-Type"], allow_origin_regex="https://.*", + allow_credentials=True, ) @app.route("/") @@ -139,8 +230,17 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" - # Test diallowed standard response + # Test standard credentialed response + headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"} + response = client.get("/", headers=headers) + assert response.status_code == 200 + assert response.text == "Homepage" + assert response.headers["access-control-allow-origin"] == "https://example.org" + assert response.headers["access-control-allow-credentials"] == "true" + + # Test disallowed standard response # Note that enforcement is a browser concern. The disallowed-ness is reflected # in the lack of an "access-control-allow-origin" header in the response. headers = {"Origin": "http://example.org"} @@ -162,6 +262,7 @@ def homepage(request): assert response.headers["access-control-allow-headers"] == ( "Accept, Accept-Language, Content-Language, Content-Type, X-Example" ) + assert response.headers["access-control-allow-credentials"] == "true" # Test disallowed pre-flight response headers = { @@ -199,6 +300,7 @@ def homepage(request): response.headers["access-control-allow-origin"] == "https://subdomain.example.org" ) + assert "access-control-allow-credentials" not in response.headers # Test diallowed standard response headers = {"Origin": "https://subdomain.example.org.hacker.com"} @@ -225,6 +327,7 @@ def homepage(request): assert response.status_code == 200 assert response.text == "Homepage" assert response.headers["access-control-allow-origin"] == "https://example.org" + assert "access-control-allow-credentials" not in response.headers def test_cors_vary_header_defaults_to_origin(): @@ -278,11 +381,14 @@ def homepage(request): client = TestClient(app) response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in response.headers response = client.get( "/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"} ) assert response.headers["access-control-allow-origin"] == "https://someplace.org" + assert "access-control-allow-credentials" not in response.headers response = client.get("/", headers={"Origin": "https://someplace.org"}) assert response.headers["access-control-allow-origin"] == "*" + assert "access-control-allow-credentials" not in response.headers From 1a28ccea6776cf0f18904416cdb46f79cb37b57f Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Fri, 27 Nov 2020 01:39:17 -0800 Subject: [PATCH 2/7] Add Vary header to preflight response if allow_credentials --- starlette/middleware/cors.py | 2 ++ tests/middleware/test_cors.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 33e85d72b..259caae11 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -41,6 +41,8 @@ def __init__( preflight_headers = {} if "*" in allow_origins: preflight_headers["Access-Control-Allow-Origin"] = "*" + if allow_credentials: + preflight_headers["Vary"] = "Origin" else: preflight_headers["Vary"] = "Origin" preflight_headers.update( diff --git a/tests/middleware/test_cors.py b/tests/middleware/test_cors.py index 7dfc2ab6a..ea8e27bdb 100644 --- a/tests/middleware/test_cors.py +++ b/tests/middleware/test_cors.py @@ -34,6 +34,7 @@ def homepage(request): assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-headers"] == "X-Example" assert response.headers["access-control-allow-credentials"] == "true" + assert response.headers["vary"] == "Origin" # Test standard response headers = {"Origin": "https://example.org"} @@ -89,6 +90,7 @@ def homepage(request): assert response.headers["access-control-allow-origin"] == "*" assert response.headers["access-control-allow-headers"] == "X-Example" assert "access-control-allow-credentials" not in response.headers + assert "vary" not in response.headers # Test standard response headers = {"Origin": "https://example.org"} @@ -206,6 +208,7 @@ def homepage(request): assert response.status_code == 200 assert response.headers["access-control-allow-origin"] == "https://example.org" assert response.headers["access-control-allow-credentials"] == "true" + assert response.headers["vary"] == "Origin" def test_cors_allow_origin_regex(): From 2ac9646d3815b93f5049ff0f2673c1deda25eaee Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Tue, 6 Apr 2021 18:38:31 -0700 Subject: [PATCH 3/7] Use allow_explicit_origin() for preflight request_headers This simplifies the code slightly by using this recently added method. It has some trade-offs, though. We now construct a `MutableHeaders` instead of a simple `dict` when copying the pre-computed preflight headers, and we move the `Vary` header construction out of the pre-computation and into the call handler. I think it makes the code more maintainable and the added per-call computation is minimal. --- starlette/middleware/cors.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 27a78230a..8ee9522a2 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -41,10 +41,6 @@ def __init__( preflight_headers = {} if "*" in allow_origins: preflight_headers["Access-Control-Allow-Origin"] = "*" - if allow_credentials: - preflight_headers["Vary"] = "Origin" - else: - preflight_headers["Vary"] = "Origin" preflight_headers.update( { "Access-Control-Allow-Methods": ", ".join(allow_methods), @@ -104,7 +100,7 @@ def preflight_response(self, request_headers: Headers) -> Response: requested_method = request_headers["access-control-request-method"] requested_headers = request_headers.get("access-control-request-headers") - headers = dict(self.preflight_headers) + headers = MutableHeaders(headers=self.preflight_headers) failures = [] if self.is_allowed_origin(origin=requested_origin): @@ -115,7 +111,7 @@ def preflight_response(self, request_headers: Headers) -> Response: # the Origin header in the response. # Similarly, if it's an allowed origin and credentials are # allowed, we also have to mirror back the Origin header. - headers["Access-Control-Allow-Origin"] = requested_origin + self.allow_explicit_origin(headers, requested_origin) else: failures.append("origin") From b100232090afeed204676e53ae2e60d387649e55 Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Tue, 6 Apr 2021 22:36:16 -0700 Subject: [PATCH 4/7] Convert MutableHeaders to dict for PlainTextResponse --- starlette/middleware/cors.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 8ee9522a2..1c0300218 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -132,9 +132,9 @@ def preflight_response(self, request_headers: Headers) -> Response: # if we do. if failures: failure_text = "Disallowed CORS " + ", ".join(failures) - return PlainTextResponse(failure_text, status_code=400, headers=headers) + return PlainTextResponse(failure_text, status_code=400, headers=dict(headers)) - return PlainTextResponse("OK", status_code=200, headers=headers) + return PlainTextResponse("OK", status_code=200, headers=dict(headers)) async def simple_response( self, scope: Scope, receive: Receive, send: Send, request_headers: Headers From 496983c2f5801b67281f20459face3a5e75777c4 Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Tue, 6 Apr 2021 23:14:31 -0700 Subject: [PATCH 5/7] Revert back to dict() for preflight headers This also names and caches some of the boolean tests in __init__() which we use in later if-blocks. This follows the existing pattern in order to better self-document the code. --- starlette/middleware/cors.py | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 1c0300218..f2f43b3a2 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -30,8 +30,12 @@ def __init__( if allow_origin_regex is not None: compiled_allow_origin_regex = re.compile(allow_origin_regex) + allow_all_origins = "*" in allow_origins + allow_all_headers = "*" in allow_headers + preflight_explicit_allow_origin = not allow_all_origins or allow_credentials + simple_headers = {} - if "*" in allow_origins: + if allow_all_origins: simple_headers["Access-Control-Allow-Origin"] = "*" if allow_credentials: simple_headers["Access-Control-Allow-Credentials"] = "true" @@ -39,7 +43,9 @@ def __init__( simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers) preflight_headers = {} - if "*" in allow_origins: + if preflight_explicit_allow_origin: + preflight_headers["Vary"] = "Origin" + else: preflight_headers["Access-Control-Allow-Origin"] = "*" preflight_headers.update( { @@ -48,18 +54,19 @@ def __init__( } ) allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers)) - if allow_headers and "*" not in allow_headers: + if allow_headers and not allow_all_headers: preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers) if allow_credentials: preflight_headers["Access-Control-Allow-Credentials"] = "true" self.app = app - self.allow_credentials = allow_credentials self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] - self.allow_all_origins = "*" in allow_origins - self.allow_all_headers = "*" in allow_headers + self.allow_credentials = allow_credentials + self.allow_all_origins = allow_all_origins + self.allow_all_headers = allow_all_headers + self.preflight_explicit_allow_origin = preflight_explicit_allow_origin self.allow_origin_regex = compiled_allow_origin_regex self.simple_headers = simple_headers self.preflight_headers = preflight_headers @@ -100,18 +107,18 @@ def preflight_response(self, request_headers: Headers) -> Response: requested_method = request_headers["access-control-request-method"] requested_headers = request_headers.get("access-control-request-headers") - headers = MutableHeaders(headers=self.preflight_headers) + headers = dict(self.preflight_headers) failures = [] if self.is_allowed_origin(origin=requested_origin): - if not self.allow_all_origins or self.allow_credentials: + if self.preflight_explicit_allow_origin: # If self.allow_all_origins is True, then the # "Access-Control-Allow-Origin" header is already set to "*". # If we only allow specific origins, then we have to mirror back # the Origin header in the response. # Similarly, if it's an allowed origin and credentials are # allowed, we also have to mirror back the Origin header. - self.allow_explicit_origin(headers, requested_origin) + headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") @@ -132,9 +139,9 @@ def preflight_response(self, request_headers: Headers) -> Response: # if we do. if failures: failure_text = "Disallowed CORS " + ", ".join(failures) - return PlainTextResponse(failure_text, status_code=400, headers=dict(headers)) + return PlainTextResponse(failure_text, status_code=400, headers=headers) - return PlainTextResponse("OK", status_code=200, headers=dict(headers)) + return PlainTextResponse("OK", status_code=200, headers=headers) async def simple_response( self, scope: Scope, receive: Receive, send: Send, request_headers: Headers From 5ee63ace189b41ea5e6acc2b224ab0adbcba6fbf Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Tue, 6 Apr 2021 23:27:19 -0700 Subject: [PATCH 6/7] Clean up comments --- starlette/middleware/cors.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index f2f43b3a2..5aafa911e 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -44,6 +44,7 @@ def __init__( preflight_headers = {} if preflight_explicit_allow_origin: + # The origin value will be set in preflight_response() if it is allowed. preflight_headers["Vary"] = "Origin" else: preflight_headers["Access-Control-Allow-Origin"] = "*" @@ -112,12 +113,8 @@ def preflight_response(self, request_headers: Headers) -> Response: if self.is_allowed_origin(origin=requested_origin): if self.preflight_explicit_allow_origin: - # If self.allow_all_origins is True, then the - # "Access-Control-Allow-Origin" header is already set to "*". - # If we only allow specific origins, then we have to mirror back - # the Origin header in the response. - # Similarly, if it's an allowed origin and credentials are - # allowed, we also have to mirror back the Origin header. + # The "else" case is already accounted for in self.preflight_headers + # and the value would be "*". headers["Access-Control-Allow-Origin"] = requested_origin else: failures.append("origin") From d15f3b0623d66a0b6b3a4b4b9356af78a1b8b8d9 Mon Sep 17 00:00:00 2001 From: Josh Wilson Date: Wed, 7 Apr 2021 00:34:29 -0700 Subject: [PATCH 7/7] Remove unused self.allow_credentials attribute --- starlette/middleware/cors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/starlette/middleware/cors.py b/starlette/middleware/cors.py index 5aafa911e..0b3f505e7 100644 --- a/starlette/middleware/cors.py +++ b/starlette/middleware/cors.py @@ -64,7 +64,6 @@ def __init__( self.allow_origins = allow_origins self.allow_methods = allow_methods self.allow_headers = [h.lower() for h in allow_headers] - self.allow_credentials = allow_credentials self.allow_all_origins = allow_all_origins self.allow_all_headers = allow_all_headers self.preflight_explicit_allow_origin = preflight_explicit_allow_origin