Skip to content

Commit

Permalink
Set explicit Origin in CORS preflight response if allow_credentials i…
Browse files Browse the repository at this point in the history
…s True and allow_origins is wildcard (#1113)

* Set explicit Origin in CORS preflight response if allow_credentials is True and allow_origins is wildcard

When making a preflight request, the browser makes no indication as to whether the actual subsequent
request will pass up credentials. However, unless the preflight response explicitly allows the
request's `Origin` in the `Access-Control-Response-Header`, the browser will fail the CORS check and
prevent the actual follow-up CORS request. This means that responding with the `*` wildcard is not
sufficient to allow preflighted credentialed requests. The current workaround is to provide an
equivalently permissive `allow_origin_regex` pattern.

The `simple_response()` code already performs similar logic which currently only applies to
non-preflighted requests since the browser would never make a preflighted request that hits this
code due to this issue:

```
if self.allow_all_origins and has_cookie:
    headers["Access-Control-Allow-Origin"] = origin
```

This just bring the two halves inline with each other.

* Add Vary header to preflight response if allow_credentials

* Use allow_explicit_origin() for preflight request_headers

This simplifies the code slightly by using this recently added method.

It has some trade-offs, though. We now construct a `MutableHeaders` instead of a simple `dict` when
copying the pre-computed preflight headers, and we move the `Vary` header construction out of the
pre-computation and into the call handler.

I think it makes the code more maintainable and the added per-call computation is minimal.

* Convert MutableHeaders to dict for PlainTextResponse

* Revert back to dict() for preflight headers

This also names and caches some of the boolean tests in __init__() which we use in later if-blocks.
This follows the existing pattern in order to better self-document the code.

* Clean up comments

* Remove unused self.allow_credentials attribute
  • Loading branch information
jcwilson committed Apr 14, 2021
1 parent f5ecb53 commit 995d70c
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 13 deletions.
28 changes: 16 additions & 12 deletions starlette/middleware/cors.py
Expand Up @@ -30,27 +30,32 @@ def __init__(
if allow_origin_regex is not None:
compiled_allow_origin_regex = re.compile(allow_origin_regex)

allow_all_origins = "*" in allow_origins
allow_all_headers = "*" in allow_headers
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials

simple_headers = {}
if "*" in allow_origins:
if allow_all_origins:
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)

preflight_headers = {}
if "*" in allow_origins:
preflight_headers["Access-Control-Allow-Origin"] = "*"
else:
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers["Access-Control-Allow-Origin"] = "*"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
if allow_headers and "*" not in allow_headers:
if allow_headers and not allow_all_headers:
preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"
Expand All @@ -59,8 +64,9 @@ def __init__(
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_all_origins = "*" in allow_origins
self.allow_all_headers = "*" in allow_headers
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
self.allow_origin_regex = compiled_allow_origin_regex
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
Expand Down Expand Up @@ -105,11 +111,9 @@ def preflight_response(self, request_headers: Headers) -> Response:
failures = []

if self.is_allowed_origin(origin=requested_origin):
if not self.allow_all_origins:
# If self.allow_all_origins is True, then the
# "Access-Control-Allow-Origin" header is already set to "*".
# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
if self.preflight_explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
else:
failures.append("origin")
Expand Down
111 changes: 110 additions & 1 deletion tests/middleware/test_cors.py
Expand Up @@ -22,6 +22,62 @@ def homepage(request):

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_all_except_credentials():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
expose_headers=["X-Status"],
)

@app.route("/")
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
Expand All @@ -33,6 +89,8 @@ def homepage(request):
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert "access-control-allow-credentials" not in response.headers
assert "vary" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
Expand All @@ -41,6 +99,7 @@ def homepage(request):
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -77,13 +136,15 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert "access-control-allow-credentials" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -116,6 +177,38 @@ def homepage(request):
response = client.options("/", headers=headers)
assert response.status_code == 400
assert response.text == "Disallowed CORS origin, method, headers"
assert "access-control-allow-origin" not in response.headers


def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_credentials=True,
)

@app.route("/")
def homepage(request):
return # pragma: no cover

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "POST",
}
response = client.options(
"/",
headers=headers,
)
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"


def test_cors_preflight_allow_all_methods():
Expand Down Expand Up @@ -175,6 +268,7 @@ def test_cors_allow_origin_regex():
CORSMiddleware,
allow_headers=["X-Example", "Content-Type"],
allow_origin_regex="https://.*",
allow_credentials=True,
)

@app.route("/")
Expand All @@ -189,8 +283,17 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test diallowed standard response
# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed standard response
# Note that enforcement is a browser concern. The disallowed-ness is reflected
# in the lack of an "access-control-allow-origin" header in the response.
headers = {"Origin": "http://example.org"}
Expand All @@ -212,6 +315,7 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed pre-flight response
headers = {
Expand Down Expand Up @@ -249,6 +353,7 @@ def homepage(request):
response.headers["access-control-allow-origin"]
== "https://subdomain.example.org"
)
assert "access-control-allow-credentials" not in response.headers

# Test diallowed standard response
headers = {"Origin": "https://subdomain.example.org.hacker.com"}
Expand All @@ -275,6 +380,7 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers


def test_cors_vary_header_defaults_to_origin():
Expand Down Expand Up @@ -365,11 +471,14 @@ def homepage(request):
client = TestClient(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers

response = client.get(
"/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
)
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
assert "access-control-allow-credentials" not in response.headers

response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers

0 comments on commit 995d70c

Please sign in to comment.