Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set explicit Origin in CORS preflight response if allow_credentials is True and allow_origins is wildcard #1113

Merged
merged 9 commits into from Apr 14, 2021
29 changes: 17 additions & 12 deletions starlette/middleware/cors.py
Expand Up @@ -30,27 +30,32 @@ def __init__(
if allow_origin_regex is not None:
compiled_allow_origin_regex = re.compile(allow_origin_regex)

allow_all_origins = "*" in allow_origins
allow_all_headers = "*" in allow_headers
preflight_explicit_allow_origin = not allow_all_origins or allow_credentials

simple_headers = {}
if "*" in allow_origins:
if allow_all_origins:
simple_headers["Access-Control-Allow-Origin"] = "*"
if allow_credentials:
simple_headers["Access-Control-Allow-Credentials"] = "true"
if expose_headers:
simple_headers["Access-Control-Expose-Headers"] = ", ".join(expose_headers)

preflight_headers = {}
if "*" in allow_origins:
preflight_headers["Access-Control-Allow-Origin"] = "*"
else:
if preflight_explicit_allow_origin:
# The origin value will be set in preflight_response() if it is allowed.
preflight_headers["Vary"] = "Origin"
else:
preflight_headers["Access-Control-Allow-Origin"] = "*"
preflight_headers.update(
{
"Access-Control-Allow-Methods": ", ".join(allow_methods),
"Access-Control-Max-Age": str(max_age),
}
)
allow_headers = sorted(SAFELISTED_HEADERS | set(allow_headers))
if allow_headers and "*" not in allow_headers:
if allow_headers and not allow_all_headers:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isn't allow_headers redundant here since it's never an empty list? This could be reduced to this, right?:

if not allow_all_headers:
    ...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed, it looks unused.

Also a bit concerned about headers with mixed case in the Access-Control-Allow-Headers header, but if it is an issue it's a separate one.

preflight_headers["Access-Control-Allow-Headers"] = ", ".join(allow_headers)
if allow_credentials:
preflight_headers["Access-Control-Allow-Credentials"] = "true"
Expand All @@ -59,8 +64,10 @@ def __init__(
self.allow_origins = allow_origins
self.allow_methods = allow_methods
self.allow_headers = [h.lower() for h in allow_headers]
self.allow_all_origins = "*" in allow_origins
self.allow_all_headers = "*" in allow_headers
self.allow_credentials = allow_credentials
self.allow_all_origins = allow_all_origins
self.allow_all_headers = allow_all_headers
self.preflight_explicit_allow_origin = preflight_explicit_allow_origin
self.allow_origin_regex = compiled_allow_origin_regex
self.simple_headers = simple_headers
self.preflight_headers = preflight_headers
Expand Down Expand Up @@ -105,11 +112,9 @@ def preflight_response(self, request_headers: Headers) -> Response:
failures = []

if self.is_allowed_origin(origin=requested_origin):
if not self.allow_all_origins:
# If self.allow_all_origins is True, then the
# "Access-Control-Allow-Origin" header is already set to "*".
# If we only allow specific origins, then we have to mirror back
# the Origin header in the response.
if self.preflight_explicit_allow_origin:
# The "else" case is already accounted for in self.preflight_headers
# and the value would be "*".
headers["Access-Control-Allow-Origin"] = requested_origin
else:
failures.append("origin")
Expand Down
111 changes: 110 additions & 1 deletion tests/middleware/test_cors.py
Expand Up @@ -22,6 +22,62 @@ def homepage(request):

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "GET",
"Access-Control-Request-Headers": "X-Example",
}
response = client.options("/", headers=headers)
assert response.status_code == 200
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert response.headers["access-control-allow-credentials"] == "true"

# Test non-CORS response
response = client.get("/")
assert response.status_code == 200
assert response.text == "Homepage"
assert "access-control-allow-origin" not in response.headers


def test_cors_allow_all_except_credentials():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_headers=["*"],
allow_methods=["*"],
expose_headers=["X-Status"],
)

@app.route("/")
def homepage(request):
return PlainTextResponse("Homepage", status_code=200)

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
Expand All @@ -33,6 +89,8 @@ def homepage(request):
assert response.text == "OK"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-allow-headers"] == "X-Example"
assert "access-control-allow-credentials" not in response.headers
assert "vary" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
Expand All @@ -41,6 +99,7 @@ def homepage(request):
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "*"
assert response.headers["access-control-expose-headers"] == "X-Status"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -77,13 +136,15 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert "access-control-allow-credentials" not in response.headers

# Test standard response
headers = {"Origin": "https://example.org"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers

# Test non-CORS response
response = client.get("/")
Expand Down Expand Up @@ -116,6 +177,38 @@ def homepage(request):
response = client.options("/", headers=headers)
assert response.status_code == 400
assert response.text == "Disallowed CORS origin, method, headers"
assert "access-control-allow-origin" not in response.headers


def test_preflight_allows_request_origin_if_origins_wildcard_and_credentials_allowed():
app = Starlette()

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["POST"],
allow_credentials=True,
)

@app.route("/")
def homepage(request):
return # pragma: no cover

client = TestClient(app)

# Test pre-flight response
headers = {
"Origin": "https://example.org",
"Access-Control-Request-Method": "POST",
}
response = client.options(
"/",
headers=headers,
)
assert response.status_code == 200
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"
assert response.headers["vary"] == "Origin"


def test_cors_preflight_allow_all_methods():
Expand Down Expand Up @@ -175,6 +268,7 @@ def test_cors_allow_origin_regex():
CORSMiddleware,
allow_headers=["X-Example", "Content-Type"],
allow_origin_regex="https://.*",
allow_credentials=True,
)

@app.route("/")
Expand All @@ -189,8 +283,17 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test diallowed standard response
# Test standard credentialed response
headers = {"Origin": "https://example.org", "Cookie": "star_cookie=sugar"}
response = client.get("/", headers=headers)
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed standard response
# Note that enforcement is a browser concern. The disallowed-ness is reflected
# in the lack of an "access-control-allow-origin" header in the response.
headers = {"Origin": "http://example.org"}
Expand All @@ -212,6 +315,7 @@ def homepage(request):
assert response.headers["access-control-allow-headers"] == (
"Accept, Accept-Language, Content-Language, Content-Type, X-Example"
)
assert response.headers["access-control-allow-credentials"] == "true"

# Test disallowed pre-flight response
headers = {
Expand Down Expand Up @@ -249,6 +353,7 @@ def homepage(request):
response.headers["access-control-allow-origin"]
== "https://subdomain.example.org"
)
assert "access-control-allow-credentials" not in response.headers

# Test diallowed standard response
headers = {"Origin": "https://subdomain.example.org.hacker.com"}
Expand All @@ -275,6 +380,7 @@ def homepage(request):
assert response.status_code == 200
assert response.text == "Homepage"
assert response.headers["access-control-allow-origin"] == "https://example.org"
assert "access-control-allow-credentials" not in response.headers


def test_cors_vary_header_defaults_to_origin():
Expand Down Expand Up @@ -365,11 +471,14 @@ def homepage(request):
client = TestClient(app)
response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers

response = client.get(
"/", headers={"Cookie": "foo=bar", "Origin": "https://someplace.org"}
)
assert response.headers["access-control-allow-origin"] == "https://someplace.org"
assert "access-control-allow-credentials" not in response.headers

response = client.get("/", headers={"Origin": "https://someplace.org"})
assert response.headers["access-control-allow-origin"] == "*"
assert "access-control-allow-credentials" not in response.headers