Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WebSocket handling support for HTTP security dependencies #10147

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
12 changes: 8 additions & 4 deletions fastapi/security/api_key.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from fastapi.openapi.models import APIKey, APIKeyIn
from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN


Expand All @@ -28,7 +29,8 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.query_params.get(self.model.name)
if not api_key:
if self.auto_error:
Expand Down Expand Up @@ -57,7 +59,8 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.headers.get(self.model.name)
if not api_key:
if self.auto_error:
Expand Down Expand Up @@ -86,7 +89,8 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
api_key = request.cookies.get(self.model.name)
if not api_key:
if self.auto_error:
Expand Down
16 changes: 10 additions & 6 deletions fastapi/security/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from fastapi.openapi.models import HTTPBase as HTTPBaseModel
from fastapi.openapi.models import HTTPBearer as HTTPBearerModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from pydantic import BaseModel
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN


Expand All @@ -35,8 +35,9 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

@handle_exc_for_ws
async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
Expand Down Expand Up @@ -64,8 +65,9 @@ def __init__(
self.realm = realm
self.auto_error = auto_error

@handle_exc_for_ws
async def __call__( # type: ignore
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPBasicCredentials]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
Expand Down Expand Up @@ -110,8 +112,9 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

@handle_exc_for_ws
async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
Expand Down Expand Up @@ -145,8 +148,9 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

@handle_exc_for_ws
async def __call__(
self, request: Request
self, request: HTTPConnection
) -> Optional[HTTPAuthorizationCredentials]:
authorization = request.headers.get("Authorization")
scheme, credentials = get_authorization_scheme_param(authorization)
Expand Down
13 changes: 8 additions & 5 deletions fastapi/security/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from fastapi.openapi.models import OAuthFlows as OAuthFlowsModel
from fastapi.param_functions import Form
from fastapi.security.base import SecurityBase
from fastapi.security.utils import get_authorization_scheme_param
from starlette.requests import Request
from fastapi.security.utils import get_authorization_scheme_param, handle_exc_for_ws
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN

# TODO: import from typing when deprecating Python 3.9
Expand Down Expand Up @@ -131,7 +131,8 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
Expand Down Expand Up @@ -164,7 +165,8 @@ def __init__(
auto_error=auto_error,
)

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
Expand Down Expand Up @@ -210,7 +212,8 @@ def __init__(
auto_error=auto_error,
)

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
scheme, param = get_authorization_scheme_param(authorization)
if not authorization or scheme.lower() != "bearer":
Expand Down
6 changes: 4 additions & 2 deletions fastapi/security/open_id_connect_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@

from fastapi.openapi.models import OpenIdConnect as OpenIdConnectModel
from fastapi.security.base import SecurityBase
from fastapi.security.utils import handle_exc_for_ws
from starlette.exceptions import HTTPException
from starlette.requests import Request
from starlette.requests import HTTPConnection
from starlette.status import HTTP_403_FORBIDDEN


Expand All @@ -22,7 +23,8 @@ def __init__(
self.scheme_name = scheme_name or self.__class__.__name__
self.auto_error = auto_error

async def __call__(self, request: Request) -> Optional[str]:
@handle_exc_for_ws
async def __call__(self, request: HTTPConnection) -> Optional[str]:
authorization = request.headers.get("Authorization")
if not authorization:
if self.auto_error:
Expand Down
30 changes: 29 additions & 1 deletion fastapi/security/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
from typing import Optional, Tuple
from functools import wraps
from typing import Any, Awaitable, Callable, Optional, Tuple, TypeVar

from fastapi.exceptions import HTTPException, WebSocketException
from starlette.requests import HTTPConnection
from starlette.status import WS_1008_POLICY_VIOLATION
from starlette.websockets import WebSocket


def get_authorization_scheme_param(
Expand All @@ -8,3 +14,25 @@ def get_authorization_scheme_param(
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
return scheme, param


_SecurityDepFunc = TypeVar(
"_SecurityDepFunc", bound=Callable[[Any, HTTPConnection], Awaitable[Any]]
)


def handle_exc_for_ws(func: _SecurityDepFunc) -> _SecurityDepFunc:
@wraps(func)
async def wrapper(self: Any, request: HTTPConnection) -> Any:
try:
return await func(self, request)
except HTTPException as e:
if not isinstance(request, WebSocket):
raise e
# close before accepted with result a HTTP 403 so the exception argument is ignored
# ref: https://asgi.readthedocs.io/en/latest/specs/www.html#close-send-event
raise WebSocketException(
code=WS_1008_POLICY_VIOLATION, reason=e.detail
) from None

return wrapper # type: ignore
mnixry marked this conversation as resolved.
Show resolved Hide resolved
29 changes: 28 additions & 1 deletion tests/test_security_http_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from fastapi import FastAPI, Security
import pytest
from fastapi import FastAPI, Security, WebSocket
from fastapi.security.http import HTTPAuthorizationCredentials, HTTPBase
from fastapi.testclient import TestClient
from starlette.websockets import WebSocketDisconnect

app = FastAPI()

Expand All @@ -12,6 +14,16 @@ def read_current_user(credentials: HTTPAuthorizationCredentials = Security(secur
return {"scheme": credentials.scheme, "credentials": credentials.credentials}


@app.websocket("/users/timeline")
async def read_user_timeline(
websocket: WebSocket, credentials: HTTPAuthorizationCredentials = Security(security)
):
await websocket.accept()
await websocket.send_json(
{"scheme": credentials.scheme, "credentials": credentials.credentials}
)


client = TestClient(app)


Expand All @@ -27,6 +39,21 @@ def test_security_http_base_no_credentials():
assert response.json() == {"detail": "Not authenticated"}


def test_security_http_base_with_ws():
with client.websocket_connect(
"/users/timeline", headers={"Authorization": "Other foobar"}
) as websocket:
data = websocket.receive_json()
assert data == {"scheme": "Other", "credentials": "foobar"}


def test_security_http_base_with_ws_no_credentials():
with pytest.raises(WebSocketDisconnect) as e:
with client.websocket_connect("/users/timeline"):
pass
assert e.value.reason == "Not authenticated"


def test_openapi_schema():
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
Expand Down