Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocket handling support for HTTP security dependencies #10147

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
8 changes: 4 additions & 4 deletions fastapi/security/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc

Expand Down Expand Up @@ -99,7 +99,7 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
if not api_key:
if self.auto_error:
Expand Down Expand Up @@ -194,7 +194,7 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
Expand Down Expand Up @@ -289,7 +289,7 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
if not api_key:
if self.auto_error:
Expand Down
10 changes: 5 additions & 5 deletions fastapi/security/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from pydantic import BaseModel
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc

Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
self.auto_error = auto_error

async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
Expand Down Expand Up @@ -185,7 +185,7 @@ def __init__(
self.auto_error = auto_error

async def __call__( # type: ignore
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
Expand Down Expand Up @@ -299,7 +299,7 @@ def __init__(
self.auto_error = auto_error

async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
Expand Down Expand Up @@ -401,7 +401,7 @@ def __init__(
self.auto_error = auto_error

async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
Expand Down
8 changes: 4 additions & 4 deletions fastapi/security/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

# TODO: import from typing when deprecating Python 3.9
Expand Down Expand Up @@ -376,7 +376,7 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
Expand Down Expand Up @@ -470,7 +470,7 @@ def __init__(
auto_error=auto_error,
)

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
Expand Down Expand Up @@ -580,7 +580,7 @@ def __init__(
auto_error=auto_error,
)

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
Expand Down
4 changes: 2 additions & 2 deletions fastapi/security/open_id_connect_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN
from typing_extensions import Annotated, Doc

Expand Down Expand Up @@ -72,7 +72,7 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
Expand Down
29 changes: 28 additions & 1 deletion tests/test_security_http_base_optional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from fastapi import FastAPI, Security
from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient

Expand All @@ -18,6 +18,19 @@ def read_current_user(
return {"scheme": credentials.scheme, "credentials": credentials.credentials}


@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket,
credentials: Optional[HTTPAuthorizationCredentials] = Security(security),
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
if credentials
else {"msg": "Create an account first"}
)


client = TestClient(app)


Expand All @@ -33,6 +46,20 @@ def test_security_http_base_no_credentials():
assert response.json() == {"msg": "Create an account first"}


def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}


def test_security_http_base_with_ws_no_credentials():
with client.websocket_connect("/users/timeline") as websocket:
data = websocket.receive_json()
assert data == {"msg": "Create an account first"}


def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
Expand Down