Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set explicit Origin in CORS preflight response if allow_credentials is True and allow_origins is wildcard #1113

Merged
merged 9 commits into from Apr 14, 2021
15 changes: 8 additions & 7 deletions starlette/middleware/cors.py
Expand Up @@ -41,8 +41,6 @@ def __init__(
preflight_headers = {}
if "*" in allow_origins:
preflight_headers["Access-Control-Allow-Origin"] = "*"
else:
preflight_headers["Vary"] = "Origin"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
Expand All @@ -56,6 +54,7 @@ def __init__(
preflight_headers["Access-Control-Allow-Credentials"] = "true"

self.app = app
self.allow_credentials = allow_credentials
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
Expand Down Expand Up @@ -101,16 +100,18 @@ def preflight_response(self, request_headers: Headers) -> Response:
requested_method = request_headers["access-control-request-method"]
requested_headers = request_headers.get("access-control-request-headers")

headers = dict(self.preflight_headers)
headers = MutableHeaders(headers=self.preflight_headers)
failures = []

if self.is_allowed_origin(origin=requested_origin):
if not self.allow_all_origins:
if not self.allow_all_origins or self.allow_credentials:
# If self.allow_all_origins is True, then the
# "Access-Control-Allow-Origin" header is already set to "*".
# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
headers["Access-Control-Allow-Origin"] = requested_origin
# Similarly, if it's an allowed origin and credentials are
# allowed, we also have to mirror back the Origin header.
self.allow_explicit_origin(headers, requested_origin)
else:
failures.append("origin")

Expand All @@ -131,9 +132,9 @@ def preflight_response(self, request_headers: Headers) -> Response:
# if we do.
if failures:
failure_text = "Disallowed CORS " + ", ".join(failures)
return PlainTextResponse(failure_text, status_code=400, headers=headers)
return PlainTextResponse(failure_text, status_code=400, headers=dict(headers))

return PlainTextResponse("OK", status_code=200, headers=headers)
return PlainTextResponse("OK", status_code=200, headers=dict(headers))
jcwilson marked this conversation as resolved.
Show resolved Hide resolved

async def simple_response(
self, scope: Scope, receive: Receive, send: Send, request_headers: Headers
Expand Down
111 changes: 110 additions & 1 deletion tests/middleware/test_cors.py
Expand Up @@ -22,6 +22,62 @@ def homepage(request):

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_all_except_credentials():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
expose_headers=["X-Status"],
)

@app.route("/")
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
Expand All @@ -33,6 +89,8 @@ def homepage(request):
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert "access-control-allow-credentials" not in response.headers
assert "vary" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
Expand All @@ -41,6 +99,7 @@ def homepage(request):
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -77,13 +136,15 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert "access-control-allow-credentials" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -116,6 +177,38 @@ def homepage(request):
response = client.options("/", headers=headers)
assert response.status_code == 400
assert response.text == "Disallowed CORS origin, method, headers"
assert "access-control-allow-origin" not in response.headers


def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_credentials=True,
)

@app.route("/")
def homepage(request):
return # pragma: no cover

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "POST",
}
response = client.options(
"/",
headers=headers,
)
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"


def test_cors_preflight_allow_all_methods():
Expand Down Expand Up @@ -175,6 +268,7 @@ def test_cors_allow_origin_regex():
CORSMiddleware,
allow_headers=["X-Example", "Content-Type"],
allow_origin_regex="https://.*",
allow_credentials=True,
)

@app.route("/")
Expand All @@ -189,8 +283,17 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test diallowed standard response
# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed standard response
# Note that enforcement is a browser concern. The disallowed-ness is reflected
# in the lack of an "access-control-allow-origin" header in the response.
headers = {"Origin": "http://example.org"}
Expand All @@ -212,6 +315,7 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed pre-flight response
headers = {
Expand Down Expand Up @@ -249,6 +353,7 @@ def homepage(request):
response.headers["access-control-allow-origin"]
== "https://subdomain.example.org"
)
assert "access-control-allow-credentials" not in response.headers

# Test diallowed standard response
headers = {"Origin": "https://subdomain.example.org.hacker.com"}
Expand All @@ -275,6 +380,7 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers


def test_cors_vary_header_defaults_to_origin():
Expand Down Expand Up @@ -365,11 +471,14 @@ def homepage(request):
client = TestClient(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers

response = client.get(
"/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
)
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
assert "access-control-allow-credentials" not in response.headers

response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers