Skip to content

Commit

Permalink
add mypy check
Browse files Browse the repository at this point in the history
add py.typed file

use typings from jupyter_client

fix typing

fix gateway tests

use updated mypy config

fix 3.7 compat

config
  • Loading branch information
blink1073 committed Apr 12, 2022
1 parent cc8bdf4 commit b3a3336
Show file tree
Hide file tree
Showing 59 changed files with 299 additions and 159 deletions.
9 changes: 9 additions & 0 deletions .pre-commit-config.yaml
Expand Up @@ -28,6 +28,15 @@ repos:
files: \.py$
args: [--profile=black]

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.942
hooks:
- id: mypy
args: ["--config-file", "pyproject.toml"]
exclude: examples/simple/setup.py
additional_dependencies: [types-requests]
stages: [manual]

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v2.6.2
hooks:
Expand Down
1 change: 1 addition & 0 deletions MANIFEST.in
Expand Up @@ -4,6 +4,7 @@ include README.md
include RELEASE.md
include CHANGELOG.md
include package.json
include jupyter_server/py.typed

# include everything in package_data
recursive-include jupyter_server *
Expand Down
2 changes: 1 addition & 1 deletion docs/source/conf.py
Expand Up @@ -16,7 +16,7 @@
import shutil
import sys

from pkg_resources import parse_version
from packaging.version import parse as parse_version

HERE = osp.abspath(osp.dirname(__file__))

Expand Down
Expand Up @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource):
return True


c.ServerApp.authorizer_class = ReadOnly
c.ServerApp.authorizer_class = ReadOnly # type:ignore[name-defined]
2 changes: 1 addition & 1 deletion examples/authorization/jupyter_nbclassic_rw_config.py
Expand Up @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource):
return True


c.ServerApp.authorizer_class = ReadWriteOnly
c.ServerApp.authorizer_class = ReadWriteOnly # type:ignore[name-defined]
2 changes: 1 addition & 1 deletion examples/authorization/jupyter_temporary_config.py
Expand Up @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource):
return True


c.ServerApp.authorizer_class = TemporaryServerPersonality
c.ServerApp.authorizer_class = TemporaryServerPersonality # type:ignore[name-defined]
4 changes: 3 additions & 1 deletion examples/simple/jupyter_server_config.py
Expand Up @@ -3,4 +3,6 @@
# Application(SingletonConfigurable) configuration
# ------------------------------------------------------------------------------
# The date format used by logging formatters for %(asctime)s
c.Application.log_datefmt = "%Y-%m-%d %H:%M:%S Simple_Extensions_Example"
c.Application.log_datefmt = ( # type:ignore[name-defined]
"%Y-%m-%d %H:%M:%S Simple_Extensions_Example"
)
2 changes: 1 addition & 1 deletion examples/simple/jupyter_simple_ext11_config.py
@@ -1 +1 @@
c.SimpleApp11.ignore_js = True
c.SimpleApp11.ignore_js = True # type:ignore[name-defined]
8 changes: 4 additions & 4 deletions examples/simple/jupyter_simple_ext1_config.py
@@ -1,4 +1,4 @@
c.SimpleApp1.configA = "ConfigA from file"
c.SimpleApp1.configB = "ConfigB from file"
c.SimpleApp1.configC = "ConfigC from file"
c.SimpleApp1.configD = "ConfigD from file"
c.SimpleApp1.configA = "ConfigA from file" # type:ignore[name-defined]
c.SimpleApp1.configB = "ConfigB from file" # type:ignore[name-defined]
c.SimpleApp1.configC = "ConfigC from file" # type:ignore[name-defined]
c.SimpleApp1.configD = "ConfigD from file" # type:ignore[name-defined]
2 changes: 1 addition & 1 deletion examples/simple/jupyter_simple_ext2_config.py
@@ -1 +1 @@
c.SimpleApp2.configD = "ConfigD from file"
c.SimpleApp2.configD = "ConfigD from file" # type:ignore[name-defined]
4 changes: 3 additions & 1 deletion jupyter_server/__init__.py
Expand Up @@ -13,7 +13,9 @@

del os

from ._version import __version__, version_info # noqa
from ._version import __version__, version_info

__all__ = ["__version__", "version_info"]


def _cleanup():
Expand Down
2 changes: 2 additions & 0 deletions jupyter_server/auth/__init__.py
@@ -1,3 +1,5 @@
from .authorizer import * # noqa
from .decorator import authorized # noqa
from .security import passwd # noqa

__all__ = ["authorized", "passwd"]
12 changes: 7 additions & 5 deletions jupyter_server/auth/decorator.py
Expand Up @@ -3,19 +3,21 @@
# Copyright (c) Jupyter Development Team.
# Distributed under the terms of the Modified BSD License.
from functools import wraps
from typing import Callable, Optional, Union
from typing import Any, Callable, Optional, TypeVar, Union, cast

from tornado.log import app_log
from tornado.web import HTTPError

from .utils import HTTP_METHOD_TO_AUTH_ACTION, warn_disabled_authorization

T = TypeVar("T", bound=Callable[..., Any])


def authorized(
action: Optional[Union[str, Callable]] = None,
action: Optional[Union[str, Callable[..., Any]]] = None,
resource: Optional[str] = None,
message: Optional[str] = None,
) -> Callable:
) -> Callable[..., Any]:
"""A decorator for tornado.web.RequestHandler methods
that verifies whether the current user is authorized
to make the following request.
Expand All @@ -38,7 +40,7 @@ def authorized(
a message for the unauthorized action.
"""

def wrapper(method):
def wrapper(method: T) -> T:
@wraps(method)
def inner(self, *args, **kwargs):
# default values for action, resource
Expand Down Expand Up @@ -70,7 +72,7 @@ def inner(self, *args, **kwargs):
# Raise an exception if the method wasn't returned (i.e. not authorized)
raise HTTPError(status_code=403, log_message=message)

return inner
return cast(T, inner)

if callable(action):
method = action
Expand Down
2 changes: 1 addition & 1 deletion jupyter_server/auth/login.py
Expand Up @@ -83,7 +83,7 @@ def post(self):
elif self.token and self.token == typed_password:
self.set_login_cookie(self, uuid.uuid4().hex)
if new_password and self.settings.get("allow_password_change"):
config_dir = self.settings.get("config_dir")
config_dir = self.settings.get("config_dir", "")
config_file = os.path.join(config_dir, "jupyter_server_config.json")
set_password(new_password, config_file=config_file)
self.log.info("Wrote hashed password to %s" % config_file)
Expand Down
4 changes: 2 additions & 2 deletions jupyter_server/auth/security.py
Expand Up @@ -64,9 +64,9 @@ def passwd(passphrase=None, algorithm="argon2"):
time_cost=10,
parallelism=8,
)
h = ph.hash(passphrase)
h_ph = ph.hash(passphrase)

return ":".join((algorithm, h))
return ":".join((algorithm, h_ph))

h = hashlib.new(algorithm)
salt = ("%0" + str(salt_len) + "x") % random.getrandbits(4 * salt_len)
Expand Down
6 changes: 3 additions & 3 deletions jupyter_server/auth/utils.py
Expand Up @@ -44,9 +44,9 @@ def get_regex_to_resource_map():
from jupyter_server.serverapp import JUPYTER_SERVICE_HANDLERS

modules = []
for mod in JUPYTER_SERVICE_HANDLERS.values():
if mod:
modules.extend(mod)
for mod_name in JUPYTER_SERVICE_HANDLERS.values():
if mod_name:
modules.extend(mod_name)
resource_map = {}
for handler_module in modules:
mod = importlib.import_module(handler_module)
Expand Down
30 changes: 20 additions & 10 deletions jupyter_server/base/handlers.py
Expand Up @@ -10,6 +10,7 @@
import re
import traceback
import types
import typing as t
import warnings
from http.client import responses
from http.cookies import Morsel
Expand Down Expand Up @@ -114,7 +115,7 @@ def force_clear_cookie(self, name, path="/", domain=None):
name = escape.native_str(name)
expires = datetime.datetime.utcnow() - datetime.timedelta(days=365)

morsel = Morsel()
morsel: Morsel[str] = Morsel()
morsel.set(name, "", '""')
morsel["expires"] = httputil.format_timestamp(expires)
morsel["path"] = path
Expand Down Expand Up @@ -241,8 +242,8 @@ def mathjax_config(self):
return self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe")

@property
def base_url(self):
return self.settings.get("base_url", "/")
def base_url(self) -> str:
return self.settings.get("base_url", "/") # type:ignore[no-any-return]

@property
def default_url(self):
Expand Down Expand Up @@ -476,7 +477,9 @@ def check_host(self):
return True

# Remove port (e.g. ':8888') from host
host = re.match(r"^(.*?)(:\d+)?$", self.request.host).group(1)
match = re.match(r"^(.*?)(:\d+)?$", self.request.host)
assert match is not None
host = match.group(1)

# Browsers format IPv6 addresses like [::1]; we need to remove the []
if host.startswith("[") and host.endswith("]"):
Expand Down Expand Up @@ -567,7 +570,7 @@ def write_error(self, status_code, **kwargs):
exc_info = kwargs.get("exc_info")
message = ""
status_message = responses.get(status_code, "Unknown HTTP Error")
exception = "(unknown)"

if exc_info:
exception = exc_info[1]
# get the custom message, if defined
Expand All @@ -580,6 +583,8 @@ def write_error(self, status_code, **kwargs):
reason = getattr(exception, "reason", "")
if reason:
status_message = reason
else:
exception = "(unknown)"

# build template namespace
ns = dict(
Expand All @@ -602,6 +607,8 @@ def write_error(self, status_code, **kwargs):
class APIHandler(JupyterHandler):
"""Base class for API handlers"""

_user_cache: str

def prepare(self):
if not self.check_origin():
raise web.HTTPError(404)
Expand All @@ -611,7 +618,7 @@ def write_error(self, status_code, **kwargs):
"""APIHandler errors are JSON, not human pages"""
self.set_header("Content-Type", "application/json")
message = responses.get(status_code, "Unknown HTTP Error")
reply = {
reply: t.Dict[str, t.Any] = {
"message": message,
}
exc_info = kwargs.get("exc_info")
Expand All @@ -627,13 +634,13 @@ def write_error(self, status_code, **kwargs):
self.log.warning(reply["message"])
self.finish(json.dumps(reply))

def get_current_user(self):
def get_current_user(self) -> str:
"""Raise 403 on API handlers instead of redirecting to human login page"""
# preserve _user_cache so we don't raise more than once
if hasattr(self, "_user_cache"):
return self._user_cache
self._user_cache = user = super().get_current_user()
return user
return t.cast(str, user)

def get_login_url(self):
# if get_login_url is invoked in an API handler,
Expand Down Expand Up @@ -733,13 +740,14 @@ def head(self, path):

@web.authenticated
def get(self, path):
if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", False):
if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", None):
name = path.rsplit("/", 1)[-1]
self.set_attachment_header(name)

return web.StaticFileHandler.get(self, path)

def get_content_type(self):
assert self.absolute_path is not None
path = self.absolute_path.strip("/")
if "/" in path:
_, name = path.rsplit("/", 1)
Expand Down Expand Up @@ -818,7 +826,8 @@ class FileFindHandler(JupyterHandler, web.StaticFileHandler):
"""subclass of StaticFileHandler for serving files from a search path"""

# cache search results, don't search for files more than once
_static_paths = {}
_static_paths: t.Dict[str, str] = {}
root: t.Any

def set_headers(self):
super().set_headers()
Expand Down Expand Up @@ -882,6 +891,7 @@ class TrailingSlashHandler(web.RequestHandler):
"""

def get(self):
assert self.request.uri is not None
path, *rest = self.request.uri.partition("?")
# trim trailing *and* leading /
# to avoid misinterpreting repeated '//'
Expand Down
23 changes: 15 additions & 8 deletions jupyter_server/base/zmqhandlers.py
Expand Up @@ -5,6 +5,7 @@
import re
import struct
import sys
import typing as t
from urllib.parse import urlparse

import tornado
Expand All @@ -17,6 +18,7 @@
from jupyter_client.jsonutil import extract_dates
from jupyter_client.session import Session
from tornado import ioloop, web
from tornado.iostream import IOStream
from tornado.websocket import WebSocketHandler

from jupyter_server.auth.utils import warn_disabled_authorization
Expand Down Expand Up @@ -93,7 +95,7 @@ def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None):
else:
msg_list = msg_or_list
channel = channel.encode("utf-8")
offsets = []
offsets: t.List[t.Any] = []
offsets.append(8 * (1 + 1 + len(msg_list) + 1))
offsets.append(len(channel) + offsets[-1])
for msg in msg_list:
Expand Down Expand Up @@ -122,27 +124,30 @@ class WebSocketMixin:
"""Mixin for common websocket options"""

ping_callback = None
last_ping = 0
last_pong = 0
stream = None
last_ping = 0.0
last_pong = 0.0
stream = None # type: t.Optional[IOStream]

@property
def ping_interval(self):
"""The interval for websocket keep-alive pings.
Set ws_ping_interval = 0 to disable pings.
"""
return self.settings.get("ws_ping_interval", WS_PING_INTERVAL)
return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined]

@property
def ping_timeout(self):
"""If no ping is received in this many milliseconds,
close the websocket connection (VPNs, etc. can fail to cleanly close ws connections).
Default is max of 3 pings or 30 seconds.
"""
return self.settings.get("ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL))
return self.settings.get( # type:ignore[attr-defined]
"ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)
)

def check_origin(self, origin=None):
@t.no_type_check
def check_origin(self, origin: t.Optional[str] = None) -> bool:
"""Check Origin == Host or Access-Control-Allow-Origin.
Tornado >= 4 calls this method automatically, raising 403 if it returns False.
Expand Down Expand Up @@ -188,6 +193,7 @@ def clear_cookie(self, *args, **kwargs):
"""meaningless for websockets"""
pass

@t.no_type_check
def open(self, *args, **kwargs):
self.log.debug("Opening websocket %s", self.request.path)

Expand All @@ -203,6 +209,7 @@ def open(self, *args, **kwargs):
self.ping_callback.start()
return super().open(*args, **kwargs)

@t.no_type_check
def send_ping(self):
"""send a ping to keep the websocket alive"""
if self.ws_connection is None and self.ping_callback is not None:
Expand Down Expand Up @@ -327,7 +334,7 @@ def pre_get(self):
elif not self.authorizer.is_authorized(self, user, "execute", "kernels"):
raise web.HTTPError(403)

if self.get_argument("session_id", False):
if self.get_argument("session_id", None):
self.session.session = self.get_argument("session_id")
else:
self.log.warning("No session ID specified")
Expand Down
3 changes: 2 additions & 1 deletion jupyter_server/config_manager.py
Expand Up @@ -6,6 +6,7 @@
import glob
import json
import os
import typing as t

from traitlets.config import LoggingConfigurable
from traitlets.traitlets import Bool, Unicode
Expand Down Expand Up @@ -95,7 +96,7 @@ def get(self, section_name, include_root=True):
section_name,
"\n\t".join(paths),
)
data = {}
data: t.Dict[str, t.Any] = {}
for path in paths:
if os.path.isfile(path):
with open(path, encoding="utf-8") as f:
Expand Down

0 comments on commit b3a3336

Please sign in to comment.