Skip to content

Commit

Permalink
Add more typing
Browse files Browse the repository at this point in the history
  • Loading branch information
blink1073 committed Nov 7, 2023
1 parent f282873 commit 3b46276
Show file tree
Hide file tree
Showing 26 changed files with 186 additions and 162 deletions.
4 changes: 2 additions & 2 deletions jupyter_server/_tz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ZERO = timedelta(0)


class tzUTC(tzinfo): # noqa
class tzUTC(tzinfo): # noqa: N801
"""tzinfo object for UTC (zero offset)"""

def utcoffset(self, d: datetime | None) -> timedelta:
Expand All @@ -30,7 +30,7 @@ def utcnow() -> datetime:
return datetime.now(timezone.utc)


def utcfromtimestamp(timestamp):
def utcfromtimestamp(timestamp: float) -> datetime:
return datetime.fromtimestamp(timestamp, timezone.utc)


Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/auth/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,6 @@ def inner(self, *args, **kwargs):
method = action
action = None
# no-arguments `@authorized` decorator called
return wrapper(method)
return cast(FuncT, wrapper(method))

return cast(FuncT, wrapper)
54 changes: 28 additions & 26 deletions jupyter_server/auth/identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@
import os
import re
import sys
import typing as t
import uuid
from dataclasses import asdict, dataclass
from http.cookies import Morsel
from typing import TYPE_CHECKING, Any, Awaitable

from tornado import escape, httputil, web
from traitlets import Bool, Dict, Type, Unicode, default
Expand All @@ -28,7 +28,7 @@
from .utils import get_anonymous_username

# circular imports for type checking
if TYPE_CHECKING:
if t.TYPE_CHECKING:
from jupyter_server.base.handlers import AuthenticatedHandler, JupyterHandler
from jupyter_server.serverapp import ServerApp

Expand Down Expand Up @@ -82,7 +82,7 @@ def fill_defaults(self):
self.display_name = self.name


def _backward_compat_user(got_user: Any) -> User:
def _backward_compat_user(got_user: t.Any) -> User:
"""Backward-compatibility for LoginHandler.get_user
Prior to 2.0, LoginHandler.get_user could return anything truthy.
Expand Down Expand Up @@ -128,7 +128,7 @@ class IdentityProvider(LoggingConfigurable):
.. versionadded:: 2.0
"""

cookie_name: str | Unicode = Unicode(
cookie_name: str | Unicode[str, str | bytes] = Unicode(
"",
config=True,
help=_i18n("Name of the cookie to set for persisting login. Default: username-${Host}."),
Expand All @@ -142,7 +142,7 @@ class IdentityProvider(LoggingConfigurable):
),
)

secure_cookie: bool | Bool = Bool(
secure_cookie: bool | Bool[bool | None, bool | int | None] = Bool(
None,
allow_none=True,
config=True,
Expand All @@ -160,7 +160,7 @@ class IdentityProvider(LoggingConfigurable):
),
)

token: str | Unicode = Unicode(
token: str | Unicode[str, str | bytes] = Unicode(
"<generated>",
help=_i18n(
"""Token used for authenticating first-time connections to the server.
Expand Down Expand Up @@ -211,9 +211,9 @@ def _token_default(self):
self.token_generated = True
return binascii.hexlify(os.urandom(24)).decode("ascii")

need_token: bool | Bool = Bool(True)
need_token: bool | Bool[bool, t.Union[bool, int]] = Bool(True)

def get_user(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
def get_user(self, handler: JupyterHandler) -> User | None | t.Awaitable[User | None]:
"""Get the authenticated user for a request
Must return a :class:`jupyter_server.auth.User`,
Expand All @@ -233,12 +233,12 @@ async def _get_user(self, handler: JupyterHandler) -> User | None:
if getattr(handler, "_jupyter_current_user", None):
# already authenticated
return handler._jupyter_current_user
_token_user: User | None | Awaitable[User | None] = self.get_user_token(handler)
if isinstance(_token_user, Awaitable):
_token_user: User | None | t.Awaitable[User | None] = self.get_user_token(handler)
if isinstance(_token_user, t.Awaitable):
_token_user = await _token_user
token_user: User | None = _token_user # need second variable name to collapse type
_cookie_user = self.get_user_cookie(handler)
if isinstance(_cookie_user, Awaitable):
if isinstance(_cookie_user, t.Awaitable):
_cookie_user = await _cookie_user
cookie_user: User | None = _cookie_user
# prefer token to cookie if both given,
Expand Down Expand Up @@ -273,12 +273,12 @@ async def _get_user(self, handler: JupyterHandler) -> User | None:

return user

def identity_model(self, user: User) -> dict:
def identity_model(self, user: User) -> dict[str, t.Any]:
"""Return a User as an Identity model"""
# TODO: validate?
return asdict(user)

def get_handlers(self) -> list:
def get_handlers(self) -> list[tuple[str, object]]:
"""Return list of additional handlers for this identity provider
For example, an OAuth callback handler.
Expand Down Expand Up @@ -368,7 +368,7 @@ def _force_clear_cookie(
name = escape.native_str(name)
expires = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=365)

morsel: Morsel = Morsel()
morsel: Morsel[object] = Morsel()
morsel.set(name, "", '""')
morsel["expires"] = httputil.format_timestamp(expires)
morsel["path"] = path
Expand All @@ -390,7 +390,7 @@ def clear_login_cookie(self, handler: AuthenticatedHandler) -> None:
# two cookies with the same name. See the method above.
self._force_clear_cookie(handler, cookie_name)

def get_user_cookie(self, handler: JupyterHandler) -> User | None | Awaitable[User | None]:
def get_user_cookie(self, handler: JupyterHandler) -> User | None | t.Awaitable[User | None]:
"""Get user from a cookie
Calls user_from_cookie to deserialize cookie value
Expand Down Expand Up @@ -455,7 +455,7 @@ async def get_user_token(self, handler: JupyterHandler) -> User | None:
# which is stored in a cookie.
# still check the cookie for the user id
_user = self.get_user_cookie(handler)
if isinstance(_user, Awaitable):
if isinstance(_user, t.Awaitable):
_user = await _user
user: User | None = _user
if user is None:
Expand Down Expand Up @@ -505,7 +505,7 @@ def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:
def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Check the application's security.
Expand Down Expand Up @@ -538,7 +538,7 @@ def process_login_form(self, handler: JupyterHandler) -> User | None:
return self.generate_anonymous_user(handler)

if self.token and self.token == typed_password:
return self.user_for_token(typed_password) # type:ignore[attr-defined]
return t.cast(User, self.user_for_token(typed_password)) # type:ignore[attr-defined]

return user

Expand Down Expand Up @@ -660,7 +660,7 @@ def process_login_form(self, handler: JupyterHandler) -> User | None:
def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Handle security validation."""
super().validate_security(app, ssl_options)
Expand Down Expand Up @@ -708,23 +708,25 @@ def get_user(self, handler: JupyterHandler) -> User | None:
return _backward_compat_user(user)

@property
def login_available(self):
return self.login_handler_class.get_login_available( # type:ignore[attr-defined]
self.settings
def login_available(self) -> bool:
return bool(
self.login_handler_class.get_login_available( # type:ignore[attr-defined]
self.settings
)
)

def should_check_origin(self, handler: AuthenticatedHandler) -> bool:
"""Whether we should check origin."""
return self.login_handler_class.should_check_origin(handler) # type:ignore[attr-defined]
return bool(self.login_handler_class.should_check_origin(handler)) # type:ignore[attr-defined]

def is_token_authenticated(self, handler: AuthenticatedHandler) -> bool:
"""Whether we are token authenticated."""
return self.login_handler_class.is_token_authenticated(handler) # type:ignore[attr-defined]
return bool(self.login_handler_class.is_token_authenticated(handler)) # type:ignore[attr-defined]

def validate_security(
self,
app: ServerApp,
ssl_options: dict | None = None,
ssl_options: dict[str, t.Any] | None = None,
) -> None:
"""Validate security."""
if self.password_required and (not self.hashed_password):
Expand All @@ -734,6 +736,6 @@ def validate_security(
self.log.critical(_i18n("Hint: run the following command to set a password"))
self.log.critical(_i18n("\t$ python -m jupyter_server.auth password"))
sys.exit(1)
return self.login_handler_class.validate_security( # type:ignore[attr-defined]
self.login_handler_class.validate_security( # type:ignore[attr-defined]
app, ssl_options
)

0 comments on commit 3b46276

Please sign in to comment.