diff --git a/fastapi/security/api_key.py b/fastapi/security/api_key.py index bca5c721a6996..24ddbf4825907 100644 --- a/fastapi/security/api_key.py +++ b/fastapi/security/api_key.py @@ -54,7 +54,7 @@ def __init__( self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: - api_key: str = request.headers.get(self.model.name) + api_key = request.headers.get(self.model.name) if not api_key: if self.auto_error: raise HTTPException( diff --git a/fastapi/security/http.py b/fastapi/security/http.py index 1b473c69e7cc3..8b677299dde42 100644 --- a/fastapi/security/http.py +++ b/fastapi/security/http.py @@ -38,7 +38,7 @@ def __init__( async def __call__( self, request: Request ) -> Optional[HTTPAuthorizationCredentials]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): if self.auto_error: @@ -67,7 +67,7 @@ def __init__( async def __call__( # type: ignore self, request: Request ) -> Optional[HTTPBasicCredentials]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if self.realm: unauthorized_headers = {"WWW-Authenticate": f'Basic realm="{self.realm}"'} @@ -113,7 +113,7 @@ def __init__( async def __call__( self, request: Request ) -> Optional[HTTPAuthorizationCredentials]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): if self.auto_error: @@ -148,7 +148,7 @@ def __init__( async def __call__( self, request: Request ) -> Optional[HTTPAuthorizationCredentials]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") scheme, credentials = get_authorization_scheme_param(authorization) if not (authorization and scheme and credentials): if self.auto_error: diff --git a/fastapi/security/oauth2.py b/fastapi/security/oauth2.py index 653c3010e58a3..eb6b4277cf8e3 100644 --- a/fastapi/security/oauth2.py +++ b/fastapi/security/oauth2.py @@ -126,7 +126,7 @@ def __init__( self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: raise HTTPException( @@ -157,7 +157,7 @@ def __init__( ) async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": if self.auto_error: @@ -200,7 +200,7 @@ def __init__( ) async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") scheme, param = get_authorization_scheme_param(authorization) if not authorization or scheme.lower() != "bearer": if self.auto_error: diff --git a/fastapi/security/open_id_connect_url.py b/fastapi/security/open_id_connect_url.py index dfe9f7b255e6e..393614f7cbc3b 100644 --- a/fastapi/security/open_id_connect_url.py +++ b/fastapi/security/open_id_connect_url.py @@ -23,7 +23,7 @@ def __init__( self.auto_error = auto_error async def __call__(self, request: Request) -> Optional[str]: - authorization: str = request.headers.get("Authorization") + authorization = request.headers.get("Authorization") if not authorization: if self.auto_error: raise HTTPException( diff --git a/fastapi/security/utils.py b/fastapi/security/utils.py index 2da0dd20f30a6..fa7a450b74e81 100644 --- a/fastapi/security/utils.py +++ b/fastapi/security/utils.py @@ -1,7 +1,9 @@ -from typing import Tuple +from typing import Optional, Tuple -def get_authorization_scheme_param(authorization_header_value: str) -> Tuple[str, str]: +def get_authorization_scheme_param( + authorization_header_value: Optional[str], +) -> Tuple[str, str]: if not authorization_header_value: return "", "" scheme, _, param = authorization_header_value.partition(" ") diff --git a/pyproject.toml b/pyproject.toml index dec4cff70c1a9..591a399249db8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,7 +38,7 @@ classifiers = [ "Topic :: Internet :: WWW/HTTP", ] dependencies = [ - "starlette==0.20.4", + "starlette>=0.21.0,<0.22.0", "pydantic >=1.6.2,!=1.7,!=1.7.1,!=1.7.2,!=1.7.3,!=1.8,!=1.8.1,<2.0.0", ] dynamic = ["version"] diff --git a/tests/test_enforce_once_required_parameter.py b/tests/test_enforce_once_required_parameter.py index ba8c7353fa97f..bf05aa5852ac9 100644 --- a/tests/test_enforce_once_required_parameter.py +++ b/tests/test_enforce_once_required_parameter.py @@ -101,7 +101,7 @@ def test_schema(): def test_get_invalid(): - response = client.get("/foo", params={"client_id": None}) + response = client.get("/foo") assert response.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY diff --git a/tests/test_extra_routes.py b/tests/test_extra_routes.py index 491ba61c68040..e979628a5cc5e 100644 --- a/tests/test_extra_routes.py +++ b/tests/test_extra_routes.py @@ -333,7 +333,7 @@ def test_get_api_route_not_decorated(): def test_delete(): - response = client.delete("/items/foo", json={"name": "Foo"}) + response = client.request("DELETE", "/items/foo", json={"name": "Foo"}) assert response.status_code == 200, response.text assert response.json() == {"item_id": "foo", "item": {"name": "Foo", "price": None}} diff --git a/tests/test_get_request_body.py b/tests/test_get_request_body.py index 88b9d839f5a6b..52a052faab1b2 100644 --- a/tests/test_get_request_body.py +++ b/tests/test_get_request_body.py @@ -104,5 +104,5 @@ def test_openapi_schema(): def test_get_with_body(): body = {"name": "Foo", "description": "Some description", "price": 5.5} - response = client.get("/product", json=body) + response = client.request("GET", "/product", json=body) assert response.json() == body diff --git a/tests/test_param_include_in_schema.py b/tests/test_param_include_in_schema.py index 214f039b67d3e..cb182a1cd4bf3 100644 --- a/tests/test_param_include_in_schema.py +++ b/tests/test_param_include_in_schema.py @@ -33,8 +33,6 @@ async def hidden_query( return {"hidden_query": hidden_query} -client = TestClient(app) - openapi_shema = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, @@ -161,6 +159,7 @@ async def hidden_query( def test_openapi_schema(): + client = TestClient(app) response = client.get("/openapi.json") assert response.status_code == 200 assert response.json() == openapi_shema @@ -184,7 +183,8 @@ def test_openapi_schema(): ], ) def test_hidden_cookie(path, cookies, expected_status, expected_response): - response = client.get(path, cookies=cookies) + client = TestClient(app, cookies=cookies) + response = client.get(path) assert response.status_code == expected_status assert response.json() == expected_response @@ -207,12 +207,14 @@ def test_hidden_cookie(path, cookies, expected_status, expected_response): ], ) def test_hidden_header(path, headers, expected_status, expected_response): + client = TestClient(app) response = client.get(path, headers=headers) assert response.status_code == expected_status assert response.json() == expected_response def test_hidden_path(): + client = TestClient(app) response = client.get("/hidden_path/hidden_path") assert response.status_code == 200 assert response.json() == {"hidden_path": "hidden_path"} @@ -234,6 +236,7 @@ def test_hidden_path(): ], ) def test_hidden_query(path, expected_status, expected_response): + client = TestClient(app) response = client.get(path) assert response.status_code == expected_status assert response.json() == expected_response diff --git a/tests/test_security_api_key_cookie.py b/tests/test_security_api_key_cookie.py index a5b2e44f0ce50..0bf4e9bb3ad25 100644 --- a/tests/test_security_api_key_cookie.py +++ b/tests/test_security_api_key_cookie.py @@ -22,8 +22,6 @@ def read_current_user(current_user: User = Depends(get_current_user)): return current_user -client = TestClient(app) - openapi_schema = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, @@ -51,18 +49,21 @@ def read_current_user(current_user: User = Depends(get_current_user)): def test_openapi_schema(): + client = TestClient(app) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema def test_security_api_key(): - response = client.get("/users/me", cookies={"key": "secret"}) + client = TestClient(app, cookies={"key": "secret"}) + response = client.get("/users/me") assert response.status_code == 200, response.text assert response.json() == {"username": "secret"} def test_security_api_key_no_key(): + client = TestClient(app) response = client.get("/users/me") assert response.status_code == 403, response.text assert response.json() == {"detail": "Not authenticated"} diff --git a/tests/test_security_api_key_cookie_description.py b/tests/test_security_api_key_cookie_description.py index 2cd3565b43ad6..ed4e652394482 100644 --- a/tests/test_security_api_key_cookie_description.py +++ b/tests/test_security_api_key_cookie_description.py @@ -22,8 +22,6 @@ def read_current_user(current_user: User = Depends(get_current_user)): return current_user -client = TestClient(app) - openapi_schema = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, @@ -56,18 +54,21 @@ def read_current_user(current_user: User = Depends(get_current_user)): def test_openapi_schema(): + client = TestClient(app) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema def test_security_api_key(): - response = client.get("/users/me", cookies={"key": "secret"}) + client = TestClient(app, cookies={"key": "secret"}) + response = client.get("/users/me") assert response.status_code == 200, response.text assert response.json() == {"username": "secret"} def test_security_api_key_no_key(): + client = TestClient(app) response = client.get("/users/me") assert response.status_code == 403, response.text assert response.json() == {"detail": "Not authenticated"} diff --git a/tests/test_security_api_key_cookie_optional.py b/tests/test_security_api_key_cookie_optional.py index 96a64f09a6e7d..3e7aa81c07a5a 100644 --- a/tests/test_security_api_key_cookie_optional.py +++ b/tests/test_security_api_key_cookie_optional.py @@ -29,8 +29,6 @@ def read_current_user(current_user: User = Depends(get_current_user)): return current_user -client = TestClient(app) - openapi_schema = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, @@ -58,18 +56,21 @@ def read_current_user(current_user: User = Depends(get_current_user)): def test_openapi_schema(): + client = TestClient(app) response = client.get("/openapi.json") assert response.status_code == 200, response.text assert response.json() == openapi_schema def test_security_api_key(): - response = client.get("/users/me", cookies={"key": "secret"}) + client = TestClient(app, cookies={"key": "secret"}) + response = client.get("/users/me") assert response.status_code == 200, response.text assert response.json() == {"username": "secret"} def test_security_api_key_no_key(): + client = TestClient(app) response = client.get("/users/me") assert response.status_code == 200, response.text assert response.json() == {"msg": "Create an account first"} diff --git a/tests/test_tuples.py b/tests/test_tuples.py index 18ec2d0489912..9fa65e1def082 100644 --- a/tests/test_tuples.py +++ b/tests/test_tuples.py @@ -252,16 +252,16 @@ def test_tuple_with_model_invalid(): def test_tuple_form_valid(): - response = client.post("/tuple-form/", data=[("values", "1"), ("values", "2")]) + response = client.post("/tuple-form/", data={"values": ("1", "2")}) assert response.status_code == 200, response.text assert response.json() == [1, 2] def test_tuple_form_invalid(): response = client.post( - "/tuple-form/", data=[("values", "1"), ("values", "2"), ("values", "3")] + "/tuple-form/", content=[("values", "1"), ("values", "2"), ("values", "3")] ) assert response.status_code == 422, response.text - response = client.post("/tuple-form/", data=[("values", "1")]) + response = client.post("/tuple-form/", content=[("values", "1")]) assert response.status_code == 422, response.text diff --git a/tests/test_tutorial/test_advanced_middleware/test_tutorial001.py b/tests/test_tutorial/test_advanced_middleware/test_tutorial001.py index 17165c0fc61a3..157fa5caf1a03 100644 --- a/tests/test_tutorial/test_advanced_middleware/test_tutorial001.py +++ b/tests/test_tutorial/test_advanced_middleware/test_tutorial001.py @@ -9,6 +9,6 @@ def test_middleware(): assert response.status_code == 200, response.text client = TestClient(app) - response = client.get("/", allow_redirects=False) + response = client.get("/", follow_redirects=False) assert response.status_code == 307, response.text assert response.headers["location"] == "https://testserver/" diff --git a/tests/test_tutorial/test_body/test_tutorial001.py b/tests/test_tutorial/test_body/test_tutorial001.py index 8dbaf15dbef06..65cdc758adc00 100644 --- a/tests/test_tutorial/test_body/test_tutorial001.py +++ b/tests/test_tutorial/test_body/test_tutorial001.py @@ -176,7 +176,7 @@ def test_post_broken_body(): response = client.post( "/items/", headers={"content-type": "application/json"}, - data="{some broken json}", + content="{some broken json}", ) assert response.status_code == 422, response.text assert response.json() == { @@ -214,7 +214,7 @@ def test_post_form_for_json(): def test_explicit_content_type(): response = client.post( "/items/", - data='{"name": "Foo", "price": 50.5}', + content='{"name": "Foo", "price": 50.5}', headers={"Content-Type": "application/json"}, ) assert response.status_code == 200, response.text @@ -223,7 +223,7 @@ def test_explicit_content_type(): def test_geo_json(): response = client.post( "/items/", - data='{"name": "Foo", "price": 50.5}', + content='{"name": "Foo", "price": 50.5}', headers={"Content-Type": "application/geo+json"}, ) assert response.status_code == 200, response.text @@ -232,7 +232,7 @@ def test_geo_json(): def test_no_content_type_is_json(): response = client.post( "/items/", - data='{"name": "Foo", "price": 50.5}', + content='{"name": "Foo", "price": 50.5}', ) assert response.status_code == 200, response.text assert response.json() == { @@ -255,17 +255,19 @@ def test_wrong_headers(): ] } - response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"}) + response = client.post( + "/items/", content=data, headers={"Content-Type": "text/plain"} + ) assert response.status_code == 422, response.text assert response.json() == invalid_dict response = client.post( - "/items/", data=data, headers={"Content-Type": "application/geo+json-seq"} + "/items/", content=data, headers={"Content-Type": "application/geo+json-seq"} ) assert response.status_code == 422, response.text assert response.json() == invalid_dict response = client.post( - "/items/", data=data, headers={"Content-Type": "application/not-really-json"} + "/items/", content=data, headers={"Content-Type": "application/not-really-json"} ) assert response.status_code == 422, response.text assert response.json() == invalid_dict diff --git a/tests/test_tutorial/test_body/test_tutorial001_py310.py b/tests/test_tutorial/test_body/test_tutorial001_py310.py index dd9d9911e402c..83bcb68f30d85 100644 --- a/tests/test_tutorial/test_body/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_body/test_tutorial001_py310.py @@ -185,7 +185,7 @@ def test_post_broken_body(client: TestClient): response = client.post( "/items/", headers={"content-type": "application/json"}, - data="{some broken json}", + content="{some broken json}", ) assert response.status_code == 422, response.text assert response.json() == { @@ -225,7 +225,7 @@ def test_post_form_for_json(client: TestClient): def test_explicit_content_type(client: TestClient): response = client.post( "/items/", - data='{"name": "Foo", "price": 50.5}', + content='{"name": "Foo", "price": 50.5}', headers={"Content-Type": "application/json"}, ) assert response.status_code == 200, response.text @@ -235,7 +235,7 @@ def test_explicit_content_type(client: TestClient): def test_geo_json(client: TestClient): response = client.post( "/items/", - data='{"name": "Foo", "price": 50.5}', + content='{"name": "Foo", "price": 50.5}', headers={"Content-Type": "application/geo+json"}, ) assert response.status_code == 200, response.text @@ -245,7 +245,7 @@ def test_geo_json(client: TestClient): def test_no_content_type_is_json(client: TestClient): response = client.post( "/items/", - data='{"name": "Foo", "price": 50.5}', + content='{"name": "Foo", "price": 50.5}', ) assert response.status_code == 200, response.text assert response.json() == { @@ -269,17 +269,19 @@ def test_wrong_headers(client: TestClient): ] } - response = client.post("/items/", data=data, headers={"Content-Type": "text/plain"}) + response = client.post( + "/items/", content=data, headers={"Content-Type": "text/plain"} + ) assert response.status_code == 422, response.text assert response.json() == invalid_dict response = client.post( - "/items/", data=data, headers={"Content-Type": "application/geo+json-seq"} + "/items/", content=data, headers={"Content-Type": "application/geo+json-seq"} ) assert response.status_code == 422, response.text assert response.json() == invalid_dict response = client.post( - "/items/", data=data, headers={"Content-Type": "application/not-really-json"} + "/items/", content=data, headers={"Content-Type": "application/not-really-json"} ) assert response.status_code == 422, response.text assert response.json() == invalid_dict diff --git a/tests/test_tutorial/test_cookie_params/test_tutorial001.py b/tests/test_tutorial/test_cookie_params/test_tutorial001.py index edccffec1e87e..38ae211db361d 100644 --- a/tests/test_tutorial/test_cookie_params/test_tutorial001.py +++ b/tests/test_tutorial/test_cookie_params/test_tutorial001.py @@ -3,8 +3,6 @@ from docs_src.cookie_params.tutorial001 import app -client = TestClient(app) - openapi_schema = { "openapi": "3.0.2", "info": {"title": "FastAPI", "version": "0.1.0"}, @@ -88,6 +86,7 @@ ], ) def test(path, cookies, expected_status, expected_response): - response = client.get(path, cookies=cookies) + client = TestClient(app, cookies=cookies) + response = client.get(path) assert response.status_code == expected_status assert response.json() == expected_response diff --git a/tests/test_tutorial/test_cookie_params/test_tutorial001_py310.py b/tests/test_tutorial/test_cookie_params/test_tutorial001_py310.py index 5caa5c440039e..5ad52fb5e1258 100644 --- a/tests/test_tutorial/test_cookie_params/test_tutorial001_py310.py +++ b/tests/test_tutorial/test_cookie_params/test_tutorial001_py310.py @@ -70,14 +70,6 @@ } -@pytest.fixture(name="client") -def get_client(): - from docs_src.cookie_params.tutorial001_py310 import app - - client = TestClient(app) - return client - - @needs_py310 @pytest.mark.parametrize( "path,cookies,expected_status,expected_response", @@ -94,7 +86,10 @@ def get_client(): ("/items", {"session": "cookiesession"}, 200, {"ads_id": None}), ], ) -def test(path, cookies, expected_status, expected_response, client: TestClient): - response = client.get(path, cookies=cookies) +def test(path, cookies, expected_status, expected_response): + from docs_src.cookie_params.tutorial001_py310 import app + + client = TestClient(app, cookies=cookies) + response = client.get(path) assert response.status_code == expected_status assert response.json() == expected_response diff --git a/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py index 3eb5822e28816..e6da630e8813a 100644 --- a/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py +++ b/tests/test_tutorial/test_custom_request_and_route/test_tutorial001.py @@ -26,7 +26,7 @@ def test_gzip_request(compress): data = gzip.compress(data) headers["Content-Encoding"] = "gzip" headers["Content-Type"] = "application/json" - response = client.post("/sum", data=data, headers=headers) + response = client.post("/sum", content=data, headers=headers) assert response.json() == {"sum": n} diff --git a/tests/test_tutorial/test_custom_response/test_tutorial006.py b/tests/test_tutorial/test_custom_response/test_tutorial006.py index 72bbfd2777209..9b10916e588a3 100644 --- a/tests/test_tutorial/test_custom_response/test_tutorial006.py +++ b/tests/test_tutorial/test_custom_response/test_tutorial006.py @@ -32,6 +32,6 @@ def test_openapi_schema(): def test_get(): - response = client.get("/typer", allow_redirects=False) + response = client.get("/typer", follow_redirects=False) assert response.status_code == 307, response.text assert response.headers["location"] == "https://typer.tiangolo.com" diff --git a/tests/test_tutorial/test_custom_response/test_tutorial006b.py b/tests/test_tutorial/test_custom_response/test_tutorial006b.py index ac5a76d34d061..b3e60e86a38b9 100644 --- a/tests/test_tutorial/test_custom_response/test_tutorial006b.py +++ b/tests/test_tutorial/test_custom_response/test_tutorial006b.py @@ -27,6 +27,6 @@ def test_openapi_schema(): def test_redirect_response_class(): - response = client.get("/fastapi", allow_redirects=False) + response = client.get("/fastapi", follow_redirects=False) assert response.status_code == 307 assert response.headers["location"] == "https://fastapi.tiangolo.com" diff --git a/tests/test_tutorial/test_custom_response/test_tutorial006c.py b/tests/test_tutorial/test_custom_response/test_tutorial006c.py index 009225e8c58c7..0cb6ddaa330eb 100644 --- a/tests/test_tutorial/test_custom_response/test_tutorial006c.py +++ b/tests/test_tutorial/test_custom_response/test_tutorial006c.py @@ -27,6 +27,6 @@ def test_openapi_schema(): def test_redirect_status_code(): - response = client.get("/pydantic", allow_redirects=False) + response = client.get("/pydantic", follow_redirects=False) assert response.status_code == 302 assert response.headers["location"] == "https://pydantic-docs.helpmanual.io/" diff --git a/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial006.py b/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial006.py index 5533b29571433..330b4e2c791ba 100644 --- a/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial006.py +++ b/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial006.py @@ -47,7 +47,7 @@ def test_openapi_schema(): def test_post(): - response = client.post("/items/", data=b"this is actually not validated") + response = client.post("/items/", content=b"this is actually not validated") assert response.status_code == 200, response.text assert response.json() == { "size": 30, diff --git a/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial007.py b/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial007.py index cb5dbc8eb010d..076f60b2f079d 100644 --- a/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial007.py +++ b/tests/test_tutorial/test_path_operation_advanced_configurations/test_tutorial007.py @@ -58,7 +58,7 @@ def test_post(): - x-men - x-avengers """ - response = client.post("/items/", data=yaml_data) + response = client.post("/items/", content=yaml_data) assert response.status_code == 200, response.text assert response.json() == { "name": "Deadpoolio", @@ -74,7 +74,7 @@ def test_post_broken_yaml(): x - x-men x - x-avengers """ - response = client.post("/items/", data=yaml_data) + response = client.post("/items/", content=yaml_data) assert response.status_code == 422, response.text assert response.json() == {"detail": "Invalid YAML"} @@ -88,7 +88,7 @@ def test_post_invalid(): - x-avengers - sneaky: object """ - response = client.post("/items/", data=yaml_data) + response = client.post("/items/", content=yaml_data) assert response.status_code == 422, response.text assert response.json() == { "detail": [ diff --git a/tests/test_tutorial/test_websockets/test_tutorial002.py b/tests/test_tutorial/test_websockets/test_tutorial002.py index a8523c9c4fcfa..bb5ccbf8ef61e 100644 --- a/tests/test_tutorial/test_websockets/test_tutorial002.py +++ b/tests/test_tutorial/test_websockets/test_tutorial002.py @@ -4,20 +4,18 @@ from docs_src.websockets.tutorial002 import app -client = TestClient(app) - def test_main(): + client = TestClient(app) response = client.get("/") assert response.status_code == 200, response.text assert b"" in response.content def test_websocket_with_cookie(): + client = TestClient(app, cookies={"session": "fakesession"}) with pytest.raises(WebSocketDisconnect): - with client.websocket_connect( - "/items/foo/ws", cookies={"session": "fakesession"} - ) as websocket: + with client.websocket_connect("/items/foo/ws") as websocket: message = "Message one" websocket.send_text(message) data = websocket.receive_text() @@ -33,6 +31,7 @@ def test_websocket_with_cookie(): def test_websocket_with_header(): + client = TestClient(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/items/bar/ws?token=some-token") as websocket: message = "Message one" @@ -50,6 +49,7 @@ def test_websocket_with_header(): def test_websocket_with_header_and_query(): + client = TestClient(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/items/2/ws?q=3&token=some-token") as websocket: message = "Message one" @@ -71,6 +71,7 @@ def test_websocket_with_header_and_query(): def test_websocket_no_credentials(): + client = TestClient(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/items/foo/ws"): pytest.fail( @@ -79,6 +80,7 @@ def test_websocket_no_credentials(): def test_websocket_invalid_data(): + client = TestClient(app) with pytest.raises(WebSocketDisconnect): with client.websocket_connect("/items/foo/ws?q=bar&token=some-token"): pytest.fail(