Skip to content

Commit

Permalink
Update starlette to 0.21.0.
Browse files Browse the repository at this point in the history
- Adapt tests suite after breaking changes to the starlette's TestClient
- Fix issues found by mypy caused by more precise type annotations in starlette
  • Loading branch information
Paweł Rubin committed Oct 7, 2022
1 parent c6aa28b commit 5d8a327
Show file tree
Hide file tree
Showing 26 changed files with 80 additions and 71 deletions.
2 changes: 1 addition & 1 deletion fastapi/security/api_key.py
Expand Up @@ -54,7 +54,7 @@ def __init__(
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
api_key: str = request.headers.get(self.model.name)
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
raise HTTPException(
Expand Down
8 changes: 4 additions & 4 deletions fastapi/security/http.py
Expand Up @@ -38,7 +38,7 @@ def __init__(
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
Expand Down Expand Up @@ -67,7 +67,7 @@ def __init__(
async def __call__( # type: ignore
self, request: Request
) -> Optional[HTTPBasicCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if self.realm:
unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'}
Expand Down Expand Up @@ -113,7 +113,7 @@ def __init__(
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
Expand Down Expand Up @@ -148,7 +148,7 @@ def __init__(
async def __call__(
self, request: Request
) -> Optional[HTTPAuthorizationCredentials]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
if not (authorization and scheme and credentials):
if self.auto_error:
Expand Down
6 changes: 3 additions & 3 deletions fastapi/security/oauth2.py
Expand Up @@ -126,7 +126,7 @@ def __init__(
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
Expand Down Expand Up @@ -157,7 +157,7 @@ def __init__(
)

async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(
)

async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
if self.auto_error:
Expand Down
2 changes: 1 addition & 1 deletion fastapi/security/open_id_connect_url.py
Expand Up @@ -23,7 +23,7 @@ def __init__(
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
authorization: str = request.headers.get("Authorization")
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
raise HTTPException(
Expand Down
6 changes: 4 additions & 2 deletions fastapi/security/utils.py
@@ -1,7 +1,9 @@
from typing import Tuple
from typing import Optional, Tuple


def get_authorization_scheme_param(authorization_header_value: str) -> Tuple[str, str]:
def get_authorization_scheme_param(
authorization_header_value: Optional[str],
) -> Tuple[str, str]:
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Expand Up @@ -38,7 +38,7 @@ classifiers = [
"Topic :: Internet :: WWW/HTTP",
]
dependencies = [
"starlette==0.20.4",
"starlette>=0.21.0,<0.22.0",
"pydantic >=1.6.2,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0",
]
dynamic = ["version"]
Expand Down Expand Up @@ -69,6 +69,7 @@ test = [
"python-jose[cryptography] >=3.3.0,<4.0.0",
"pyyaml >=5.3.1,<7.0.0",
"passlib[bcrypt] >=1.7.2,<2.0.0",
"trio >=0.19,<0.22.0",

# types
"types-ujson ==5.4.0",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_enforce_once_required_parameter.py
Expand Up @@ -101,7 +101,7 @@ def test_schema():


def test_get_invalid():
response = client.get("/foo", params={"client_id": None})
response = client.get("/foo")
assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY


Expand Down
2 changes: 1 addition & 1 deletion tests/test_extra_routes.py
Expand Up @@ -333,7 +333,7 @@ def test_get_api_route_not_decorated():


def test_delete():
response = client.delete("/items/foo", json={"name": "Foo"})
response = client.request("DELETE", "/items/foo", json={"name": "Foo"})
assert response.status_code == 200, response.text
assert response.json() == {"item_id": "foo", "item": {"name": "Foo", "price": None}}

Expand Down
2 changes: 1 addition & 1 deletion tests/test_get_request_body.py
Expand Up @@ -104,5 +104,5 @@ def test_openapi_schema():

def test_get_with_body():
body = {"name": "Foo", "description": "Some description", "price": 5.5}
response = client.get("/product", json=body)
response = client.request("GET", "/product", json=body)
assert response.json() == body
9 changes: 6 additions & 3 deletions tests/test_param_include_in_schema.py
Expand Up @@ -33,8 +33,6 @@ async def hidden_query(
return {"hidden_query": hidden_query}


client = TestClient(app)

openapi_shema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
Expand Down Expand Up @@ -161,6 +159,7 @@ async def hidden_query(


def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200
assert response.json() == openapi_shema
Expand All @@ -184,7 +183,8 @@ def test_openapi_schema():
],
)
def test_hidden_cookie(path, cookies, expected_status, expected_response):
response = client.get(path, cookies=cookies)
client = TestClient(app, cookies=cookies)
response = client.get(path)
assert response.status_code == expected_status
assert response.json() == expected_response

Expand All @@ -207,12 +207,14 @@ def test_hidden_cookie(path, cookies, expected_status, expected_response):
],
)
def test_hidden_header(path, headers, expected_status, expected_response):
client = TestClient(app)
response = client.get(path, headers=headers)
assert response.status_code == expected_status
assert response.json() == expected_response


def test_hidden_path():
client = TestClient(app)
response = client.get("/hidden_path/hidden_path")
assert response.status_code == 200
assert response.json() == {"hidden_path": "hidden_path"}
Expand All @@ -234,6 +236,7 @@ def test_hidden_path():
],
)
def test_hidden_query(path, expected_status, expected_response):
client = TestClient(app)
response = client.get(path)
assert response.status_code == expected_status
assert response.json() == expected_response
7 changes: 4 additions & 3 deletions tests/test_security_api_key_cookie.py
Expand Up @@ -22,8 +22,6 @@ def read_current_user(current_user: User = Depends(get_current_user)):
return current_user


client = TestClient(app)

openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
Expand Down Expand Up @@ -51,18 +49,21 @@ def read_current_user(current_user: User = Depends(get_current_user)):


def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
client = TestClient(app, cookies={"key": "secret"})
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}


def test_security_api_key_no_key():
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.json() == {"detail": "Not authenticated"}
7 changes: 4 additions & 3 deletions tests/test_security_api_key_cookie_description.py
Expand Up @@ -22,8 +22,6 @@ def read_current_user(current_user: User = Depends(get_current_user)):
return current_user


client = TestClient(app)

openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
Expand Down Expand Up @@ -56,18 +54,21 @@ def read_current_user(current_user: User = Depends(get_current_user)):


def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
client = TestClient(app, cookies={"key": "secret"})
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}


def test_security_api_key_no_key():
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 403, response.text
assert response.json() == {"detail": "Not authenticated"}
7 changes: 4 additions & 3 deletions tests/test_security_api_key_cookie_optional.py
Expand Up @@ -29,8 +29,6 @@ def read_current_user(current_user: User = Depends(get_current_user)):
return current_user


client = TestClient(app)

openapi_schema = {
"openapi": "3.0.2",
"info": {"title": "FastAPI", "version": "0.1.0"},
Expand Down Expand Up @@ -58,18 +56,21 @@ def read_current_user(current_user: User = Depends(get_current_user)):


def test_openapi_schema():
client = TestClient(app)
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


def test_security_api_key():
response = client.get("/users/me", cookies={"key": "secret"})
client = TestClient(app, cookies={"key": "secret"})
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}


def test_security_api_key_no_key():
client = TestClient(app)
response = client.get("/users/me")
assert response.status_code == 200, response.text
assert response.json() == {"msg": "Create an account first"}
6 changes: 3 additions & 3 deletions tests/test_tuples.py
Expand Up @@ -252,16 +252,16 @@ def test_tuple_with_model_invalid():


def test_tuple_form_valid():
response = client.post("/tuple-form/", data=[("values", "1"), ("values", "2")])
response = client.post("/tuple-form/", data={"values": ("1", "2")})
assert response.status_code == 200, response.text
assert response.json() == [1, 2]


def test_tuple_form_invalid():
response = client.post(
"/tuple-form/", data=[("values", "1"), ("values", "2"), ("values", "3")]
"/tuple-form/", content=[("values", "1"), ("values", "2"), ("values", "3")]
)
assert response.status_code == 422, response.text

response = client.post("/tuple-form/", data=[("values", "1")])
response = client.post("/tuple-form/", content=[("values", "1")])
assert response.status_code == 422, response.text
Expand Up @@ -9,6 +9,6 @@ def test_middleware():
assert response.status_code == 200, response.text

client = TestClient(app)
response = client.get("/", allow_redirects=False)
response = client.get("/", follow_redirects=False)
assert response.status_code == 307, response.text
assert response.headers["location"] == "https://testserver/"
16 changes: 9 additions & 7 deletions tests/test_tutorial/test_body/test_tutorial001.py
Expand Up @@ -176,7 +176,7 @@ def test_post_broken_body():
response = client.post(
"/items/",
headers={"content-type": "application/json"},
data="{some broken json}",
content="{some broken json}",
)
assert response.status_code == 422, response.text
assert response.json() == {
Expand Down Expand Up @@ -214,7 +214,7 @@ def test_post_form_for_json():
def test_explicit_content_type():
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
headers={"Content-Type": "application/json"},
)
assert response.status_code == 200, response.text
Expand All @@ -223,7 +223,7 @@ def test_explicit_content_type():
def test_geo_json():
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
headers={"Content-Type": "application/geo+json"},
)
assert response.status_code == 200, response.text
Expand All @@ -232,7 +232,7 @@ def test_geo_json():
def test_no_content_type_is_json():
response = client.post(
"/items/",
data='{"name": "Foo", "price": 50.5}',
content='{"name": "Foo", "price": 50.5}',
)
assert response.status_code == 200, response.text
assert response.json() == {
Expand All @@ -255,17 +255,19 @@ def test_wrong_headers():
]
}

response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"})
response = client.post(
"/items/", content=data, headers={"Content-Type": "text/plain"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict

response = client.post(
"/items/", data=data, headers={"Content-Type": "application/geo+json-seq"}
"/items/", content=data, headers={"Content-Type": "application/geo+json-seq"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict
response = client.post(
"/items/", data=data, headers={"Content-Type": "application/not-really-json"}
"/items/", content=data, headers={"Content-Type": "application/not-really-json"}
)
assert response.status_code == 422, response.text
assert response.json() == invalid_dict
Expand Down

0 comments on commit 5d8a327

Please sign in to comment.