Skip to content

Commit

Permalink
add subprotocol for token-authenticated websockets
Browse files Browse the repository at this point in the history
follows kubernetes' example of smuggling the token in the subprotocol itself
  • Loading branch information
minrk committed Mar 14, 2024
1 parent 0adfb2a commit 9627c7c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 1 deletion.
13 changes: 13 additions & 0 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from http.cookies import Morsel

from tornado import escape, httputil, web
from tornado.websocket import WebSocketHandler
from traitlets import Bool, Dict, Type, Unicode, default
from traitlets.config import LoggingConfigurable

Expand Down Expand Up @@ -106,6 +107,9 @@ def _backward_compat_user(got_user: t.Any) -> User:
raise ValueError(msg)


_TOKEN_SUBPROTOCOL = "v1.token.websocket.jupyter.org"


class IdentityProvider(LoggingConfigurable):
"""
Interface for providing identity management and authentication.
Expand Down Expand Up @@ -424,6 +428,15 @@ def get_token(self, handler: web.RequestHandler) -> str | None:
m = self.auth_header_pat.match(handler.request.headers.get("Authorization", ""))
if m:
user_token = m.group(2)
if not user_token and isinstance(handler, WebSocketHandler):
subprotocol_header = handler.request.headers.get("Sec-WebSocket-Protocol")
if subprotocol_header:
subprotocols = [s.strip() for s in subprotocol_header.split(",")]
for subprotocol in subprotocols:
if subprotocol.startswith(_TOKEN_SUBPROTOCOL + "."):
user_token = subprotocol[len(_TOKEN_SUBPROTOCOL) + 1 :]
break

return user_token

async def get_user_token(self, handler: web.RequestHandler) -> User | None:
Expand Down
32 changes: 31 additions & 1 deletion tests/base/test_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from tornado.websocket import WebSocketClosedError, WebSocketHandler

from jupyter_server.auth import IdentityProvider, User
from jupyter_server.auth.decorator import allow_unauthenticated
from jupyter_server.auth.decorator import allow_unauthenticated, ws_authenticated
from jupyter_server.base.handlers import JupyterHandler
from jupyter_server.base.websocket import WebSocketMixin
from jupyter_server.serverapp import ServerApp
Expand Down Expand Up @@ -75,6 +75,12 @@ class NoAuthRulesWebsocketHandler(MockJupyterHandler):
pass


class AuthenticatedWebsocketHandler(MockJupyterHandler):
@ws_authenticated
def get(self, *args, **kwargs) -> None:
return super().get(*args, **kwargs)


class PermissiveWebsocketHandler(MockJupyterHandler):
@allow_unauthenticated
def get(self, *args, **kwargs) -> None:
Expand Down Expand Up @@ -126,6 +132,30 @@ async def test_websocket_auth_required(jp_serverapp, jp_ws_fetch):
assert exception.value.code == 403


async def test_websocket_token_subprotocol_auth(jp_serverapp, jp_ws_fetch):
app: ServerApp = jp_serverapp
app.web_app.add_handlers(
".*$",
[
(url_path_join(app.base_url, "ws"), AuthenticatedWebsocketHandler),
],
)

with pytest.raises(HTTPClientError) as exception:
ws = await jp_ws_fetch("ws", headers={"Authorization": ""})
assert exception.value.code == 403
token = jp_serverapp.identity_provider.token
ws = await jp_ws_fetch(
"ws",
headers={
"Authorization": "",
"Sec-WebSocket-Protocol": "v1.kernel.websocket.jupyter.org, v1.token.websocket.jupyter.org."
+ token,
},
)
ws.close()


class IndiscriminateIdentityProvider(IdentityProvider):
async def get_user(self, handler):
return User(username="test")
Expand Down

0 comments on commit 9627c7c

Please sign in to comment.