diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e262bdf981..735678f109 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -28,6 +28,13 @@ repos: files: \.py$ args: [--profile=black] + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v0.942 + hooks: + - id: mypy + exclude: examples/simple/setup.py + additional_dependencies: [types-requests] + - repo: https://github.com/pre-commit/mirrors-prettier rev: v2.6.2 hooks: diff --git a/docs/source/conf.py b/docs/source/conf.py index 3e28778284..6b1be76e2a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -16,7 +16,7 @@ import shutil import sys -from pkg_resources import parse_version +from packaging.version import parse as parse_version HERE = osp.abspath(osp.dirname(__file__)) diff --git a/examples/authorization/jupyter_nbclassic_readonly_config.py b/examples/authorization/jupyter_nbclassic_readonly_config.py index 292644c284..7fa83b17be 100644 --- a/examples/authorization/jupyter_nbclassic_readonly_config.py +++ b/examples/authorization/jupyter_nbclassic_readonly_config.py @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource): return True -c.ServerApp.authorizer_class = ReadOnly +c.ServerApp.authorizer_class = ReadOnly # type:ignore[name-defined] diff --git a/examples/authorization/jupyter_nbclassic_rw_config.py b/examples/authorization/jupyter_nbclassic_rw_config.py index 261efcf984..c56c6dcc8f 100644 --- a/examples/authorization/jupyter_nbclassic_rw_config.py +++ b/examples/authorization/jupyter_nbclassic_rw_config.py @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource): return True -c.ServerApp.authorizer_class = ReadWriteOnly +c.ServerApp.authorizer_class = ReadWriteOnly # type:ignore[name-defined] diff --git a/examples/authorization/jupyter_temporary_config.py b/examples/authorization/jupyter_temporary_config.py index e1bd2fb507..d19b5f74df 100644 --- a/examples/authorization/jupyter_temporary_config.py +++ b/examples/authorization/jupyter_temporary_config.py @@ -11,4 +11,4 @@ def is_authorized(self, handler, user, action, resource): return True -c.ServerApp.authorizer_class = TemporaryServerPersonality +c.ServerApp.authorizer_class = TemporaryServerPersonality # type:ignore[name-defined] diff --git a/examples/simple/jupyter_server_config.py b/examples/simple/jupyter_server_config.py index 723d6cdadb..4e3a70049e 100644 --- a/examples/simple/jupyter_server_config.py +++ b/examples/simple/jupyter_server_config.py @@ -3,4 +3,6 @@ # Application(SingletonConfigurable) configuration # ------------------------------------------------------------------------------ # The date format used by logging formatters for %(asctime)s -c.Application.log_datefmt = "%Y-%m-%d %H:%M:%S Simple_Extensions_Example" +c.Application.log_datefmt = ( # type:ignore[name-defined] + "%Y-%m-%d %H:%M:%S Simple_Extensions_Example" +) diff --git a/examples/simple/jupyter_simple_ext11_config.py b/examples/simple/jupyter_simple_ext11_config.py index d2baa1360a..b1035b8746 100644 --- a/examples/simple/jupyter_simple_ext11_config.py +++ b/examples/simple/jupyter_simple_ext11_config.py @@ -1 +1 @@ -c.SimpleApp11.ignore_js = True +c.SimpleApp11.ignore_js = True # type:ignore[name-defined] diff --git a/examples/simple/jupyter_simple_ext1_config.py b/examples/simple/jupyter_simple_ext1_config.py index f40b66afaf..5e32346335 100644 --- a/examples/simple/jupyter_simple_ext1_config.py +++ b/examples/simple/jupyter_simple_ext1_config.py @@ -1,4 +1,4 @@ -c.SimpleApp1.configA = "ConfigA from file" -c.SimpleApp1.configB = "ConfigB from file" -c.SimpleApp1.configC = "ConfigC from file" -c.SimpleApp1.configD = "ConfigD from file" +c.SimpleApp1.configA = "ConfigA from file" # type:ignore[name-defined] +c.SimpleApp1.configB = "ConfigB from file" # type:ignore[name-defined] +c.SimpleApp1.configC = "ConfigC from file" # type:ignore[name-defined] +c.SimpleApp1.configD = "ConfigD from file" # type:ignore[name-defined] diff --git a/examples/simple/jupyter_simple_ext2_config.py b/examples/simple/jupyter_simple_ext2_config.py index f145cbb87a..d5faa9e942 100644 --- a/examples/simple/jupyter_simple_ext2_config.py +++ b/examples/simple/jupyter_simple_ext2_config.py @@ -1 +1 @@ -c.SimpleApp2.configD = "ConfigD from file" +c.SimpleApp2.configD = "ConfigD from file" # type:ignore[name-defined] diff --git a/jupyter_server/auth/login.py b/jupyter_server/auth/login.py index 6eb07e5748..ca48767f58 100644 --- a/jupyter_server/auth/login.py +++ b/jupyter_server/auth/login.py @@ -83,7 +83,7 @@ def post(self): elif self.token and self.token == typed_password: self.set_login_cookie(self, uuid.uuid4().hex) if new_password and self.settings.get("allow_password_change"): - config_dir = self.settings.get("config_dir") + config_dir = self.settings.get("config_dir", "") config_file = os.path.join(config_dir, "jupyter_server_config.json") set_password(new_password, config_file=config_file) self.log.info("Wrote hashed password to %s" % config_file) diff --git a/jupyter_server/auth/security.py b/jupyter_server/auth/security.py index fa7dded7fb..219687e1ae 100644 --- a/jupyter_server/auth/security.py +++ b/jupyter_server/auth/security.py @@ -64,9 +64,9 @@ def passwd(passphrase=None, algorithm="argon2"): time_cost=10, parallelism=8, ) - h = ph.hash(passphrase) + h_ph = ph.hash(passphrase) - return ":".join((algorithm, h)) + return ":".join((algorithm, h_ph)) h = hashlib.new(algorithm) salt = ("%0" + str(salt_len) + "x") % random.getrandbits(4 * salt_len) diff --git a/jupyter_server/auth/utils.py b/jupyter_server/auth/utils.py index 3f129dce63..d89127ac88 100644 --- a/jupyter_server/auth/utils.py +++ b/jupyter_server/auth/utils.py @@ -39,9 +39,9 @@ def get_regex_to_resource_map(): from jupyter_server.serverapp import JUPYTER_SERVICE_HANDLERS modules = [] - for mod in JUPYTER_SERVICE_HANDLERS.values(): - if mod: - modules.extend(mod) + for mod_name in JUPYTER_SERVICE_HANDLERS.values(): + if mod_name: + modules.extend(mod_name) resource_map = {} for handler_module in modules: mod = importlib.import_module(handler_module) diff --git a/jupyter_server/base/handlers.py b/jupyter_server/base/handlers.py index b3c827bc6d..dc9d85018b 100644 --- a/jupyter_server/base/handlers.py +++ b/jupyter_server/base/handlers.py @@ -115,7 +115,7 @@ def force_clear_cookie(self, name, path="/", domain=None): name = escape.native_str(name) expires = datetime.datetime.utcnow() - datetime.timedelta(days=365) - morsel = Morsel() + morsel: Morsel = Morsel() morsel.set(name, "", '""') morsel["expires"] = httputil.format_timestamp(expires) morsel["path"] = path @@ -292,7 +292,7 @@ def mathjax_config(self): return self.settings.get("mathjax_config", "TeX-AMS-MML_HTMLorMML-full,Safe") @property - def base_url(self): + def base_url(self) -> str: return self.settings.get("base_url", "/") @property @@ -537,7 +537,9 @@ def check_host(self): return True # Remove port (e.g. ':8888') from host - host = re.match(r"^(.*?)(:\d+)?$", self.request.host).group(1) + match = re.match(r"^(.*?)(:\d+)?$", self.request.host) + assert match is not None + host = match.group(1) # Browsers format IPv6 addresses like [::1]; we need to remove the [] if host.startswith("[") and host.endswith("]"): @@ -574,10 +576,10 @@ async def prepare(self): from jupyter_server.auth import IdentityProvider - if ( - type(self.identity_provider) is IdentityProvider - and inspect.getmodule(self.get_current_user).__name__ != __name__ - ): + mod_obj = inspect.getmodule(self.get_current_user) + assert mod_obj is not None + + if type(self.identity_provider) is IdentityProvider and mod_obj.__name__ != __name__: # check for overridden get_current_user + default IdentityProvider # deprecated way to override auth (e.g. JupyterHub < 3.0) # allow deprecated, overridden get_current_user @@ -659,7 +661,7 @@ def write_error(self, status_code, **kwargs): exc_info = kwargs.get("exc_info") message = "" status_message = responses.get(status_code, "Unknown HTTP Error") - exception = "(unknown)" + if exc_info: exception = exc_info[1] # get the custom message, if defined @@ -672,6 +674,8 @@ def write_error(self, status_code, **kwargs): reason = getattr(exception, "reason", "") if reason: status_message = reason + else: + exception = "(unknown)" # build template namespace ns = dict( @@ -703,7 +707,7 @@ def write_error(self, status_code, **kwargs): """APIHandler errors are JSON, not human pages""" self.set_header("Content-Type", "application/json") message = responses.get(status_code, "Unknown HTTP Error") - reply = { + reply: dict = { "message": message, } exc_info = kwargs.get("exc_info") @@ -817,13 +821,14 @@ def head(self, path): @web.authenticated def get(self, path): - if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", False): + if os.path.splitext(path)[1] == ".ipynb" or self.get_argument("download", None): name = path.rsplit("/", 1)[-1] self.set_attachment_header(name) return web.StaticFileHandler.get(self, path) def get_content_type(self): + assert self.absolute_path is not None path = self.absolute_path.strip("/") if "/" in path: _, name = path.rsplit("/", 1) @@ -902,7 +907,8 @@ class FileFindHandler(JupyterHandler, web.StaticFileHandler): """subclass of StaticFileHandler for serving files from a search path""" # cache search results, don't search for files more than once - _static_paths = {} + _static_paths: dict = {} + root: tuple def set_headers(self): super().set_headers() @@ -966,6 +972,7 @@ class TrailingSlashHandler(web.RequestHandler): """ def get(self): + assert self.request.uri is not None path, *rest = self.request.uri.partition("?") # trim trailing *and* leading / # to avoid misinterpreting repeated '//' diff --git a/jupyter_server/base/zmqhandlers.py b/jupyter_server/base/zmqhandlers.py index ad6342af85..8600ed33a0 100644 --- a/jupyter_server/base/zmqhandlers.py +++ b/jupyter_server/base/zmqhandlers.py @@ -5,6 +5,7 @@ import re import struct import sys +from typing import Optional, no_type_check from urllib.parse import urlparse import tornado @@ -17,6 +18,7 @@ from jupyter_client.jsonutil import extract_dates from jupyter_client.session import Session from tornado import ioloop, web +from tornado.iostream import IOStream from tornado.websocket import WebSocketHandler from .handlers import JupyterHandler @@ -91,7 +93,7 @@ def serialize_msg_to_ws_v1(msg_or_list, channel, pack=None): else: msg_list = msg_or_list channel = channel.encode("utf-8") - offsets = [] + offsets: list = [] offsets.append(8 * (1 + 1 + len(msg_list) + 1)) offsets.append(len(channel) + offsets[-1]) for msg in msg_list: @@ -120,9 +122,9 @@ class WebSocketMixin: """Mixin for common websocket options""" ping_callback = None - last_ping = 0 - last_pong = 0 - stream = None + last_ping = 0.0 + last_pong = 0.0 + stream = None # type: Optional[IOStream] @property def ping_interval(self): @@ -130,7 +132,7 @@ def ping_interval(self): Set ws_ping_interval = 0 to disable pings. """ - return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) + return self.settings.get("ws_ping_interval", WS_PING_INTERVAL) # type:ignore[attr-defined] @property def ping_timeout(self): @@ -138,9 +140,12 @@ def ping_timeout(self): close the websocket connection (VPNs, etc. can fail to cleanly close ws connections). Default is max of 3 pings or 30 seconds. """ - return self.settings.get("ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL)) + return self.settings.get( # type:ignore[attr-defined] + "ws_ping_timeout", max(3 * self.ping_interval, WS_PING_INTERVAL) + ) - def check_origin(self, origin=None): + @no_type_check + def check_origin(self, origin: Optional[str] = None) -> bool: """Check Origin == Host or Access-Control-Allow-Origin. Tornado >= 4 calls this method automatically, raising 403 if it returns False. @@ -186,6 +191,7 @@ def clear_cookie(self, *args, **kwargs): """meaningless for websockets""" pass + @no_type_check def open(self, *args, **kwargs): self.log.debug("Opening websocket %s", self.request.path) @@ -201,6 +207,7 @@ def open(self, *args, **kwargs): self.ping_callback.start() return super().open(*args, **kwargs) + @no_type_check def send_ping(self): """send a ping to keep the websocket alive""" if self.ws_connection is None and self.ping_callback is not None: @@ -322,7 +329,7 @@ def pre_get(self): if not self.authorizer.is_authorized(self, user, "execute", "kernels"): raise web.HTTPError(403) - if self.get_argument("session_id", False): + if self.get_argument("session_id", None): self.session.session = self.get_argument("session_id") else: self.log.warning("No session ID specified") diff --git a/jupyter_server/config_manager.py b/jupyter_server/config_manager.py index 25c5efd28f..6b4c47c164 100644 --- a/jupyter_server/config_manager.py +++ b/jupyter_server/config_manager.py @@ -95,7 +95,7 @@ def get(self, section_name, include_root=True): section_name, "\n\t".join(paths), ) - data = {} + data: dict = {} for path in paths: if os.path.isfile(path): with open(path, encoding="utf-8") as f: diff --git a/jupyter_server/extension/application.py b/jupyter_server/extension/application.py index c8d951c367..d1194a0a5d 100644 --- a/jupyter_server/extension/application.py +++ b/jupyter_server/extension/application.py @@ -137,7 +137,7 @@ class method. This method can be set as a entry_point in # A useful class property that subclasses can override to # configure the underlying Jupyter Server when this extension # is launched directly (using its `launch_instance` method). - serverapp_config = {} + serverapp_config: dict = {} # Some subclasses will likely override this trait to flip # the default value to False if they don't offer a browser @@ -165,7 +165,7 @@ def config_file_paths(self): # file, jupyter_{name}_config. # This should also match the jupyter subcommand used to launch # this extension from the CLI, e.g. `jupyter {name}`. - name = None + name = "ExtensionApp" @classmethod def get_extension_package(cls): @@ -318,7 +318,7 @@ def _prepare_handlers(self): handler = handler_items[1] # Get handler kwargs, if given - kwargs = {} + kwargs: dict = {} if issubclass(handler, ExtensionHandlerMixin): kwargs["name"] = self.name diff --git a/jupyter_server/extension/handler.py b/jupyter_server/extension/handler.py index 8ea326465a..f872126f77 100644 --- a/jupyter_server/extension/handler.py +++ b/jupyter_server/extension/handler.py @@ -1,3 +1,5 @@ +from typing import no_type_check + from jinja2.exceptions import TemplateNotFound from jupyter_server.base.handlers import FileFindHandler @@ -8,6 +10,7 @@ class ExtensionHandlerJinjaMixin: template rendering. """ + @no_type_check def get_template(self, name): """Return the jinja template object for a given name""" try: @@ -31,23 +34,23 @@ class ExtensionHandlerMixin: def initialize(self, name, *args, **kwargs): self.name = name try: - super().initialize(*args, **kwargs) + super().initialize(*args, **kwargs) # type:ignore[misc] except TypeError: pass @property def extensionapp(self): - return self.settings[self.name] + return self.settings[self.name] # type:ignore[attr-defined] @property def serverapp(self): key = "serverapp" - return self.settings[key] + return self.settings[key] # type:ignore[attr-defined] @property def log(self): if not hasattr(self, "name"): - return super().log + return super().log # type:ignore[misc] # Attempt to pull the ExtensionApp's log, otherwise fall back to ServerApp. try: return self.extensionapp.log @@ -56,15 +59,15 @@ def log(self): @property def config(self): - return self.settings[f"{self.name}_config"] + return self.settings[f"{self.name}_config"] # type:ignore[attr-defined] @property def server_config(self): - return self.settings["config"] + return self.settings["config"] # type:ignore[attr-defined] @property - def base_url(self): - return self.settings.get("base_url", "/") + def base_url(self) -> str: + return self.settings.get("base_url", "/") # type:ignore[attr-defined] @property def static_url_prefix(self): @@ -72,7 +75,7 @@ def static_url_prefix(self): @property def static_path(self): - return self.settings[f"{self.name}_static_paths"] + return self.settings[f"{self.name}_static_paths"] # type:ignore[attr-defined] def static_url(self, path, include_host=None, **kwargs): """Returns a static URL for the given relative static file path. @@ -93,9 +96,9 @@ def static_url(self, path, include_host=None, **kwargs): """ key = f"{self.name}_static_paths" try: - self.require_setting(key, "static_url") + self.require_setting(key, "static_url") # type:ignore[attr-defined] except Exception as e: - if key in self.settings: + if key in self.settings: # type:ignore[attr-defined] raise Exception( "This extension doesn't have any static paths listed. Check that the " "extension's `static_paths` trait is set." @@ -103,13 +106,15 @@ def static_url(self, path, include_host=None, **kwargs): else: raise e - get_url = self.settings.get("static_handler_class", FileFindHandler).make_static_url + get_url = self.settings.get( # type:ignore[attr-defined] + "static_handler_class", FileFindHandler + ).make_static_url if include_host is None: include_host = getattr(self, "include_host", False) if include_host: - base = self.request.protocol + "://" + self.request.host + base = self.request.protocol + "://" + self.request.host # type:ignore[attr-defined] else: base = "" diff --git a/jupyter_server/extension/manager.py b/jupyter_server/extension/manager.py index 2d000cbc21..0bb6b68cfb 100644 --- a/jupyter_server/extension/manager.py +++ b/jupyter_server/extension/manager.py @@ -167,7 +167,7 @@ def __init__(self, *args, **kwargs): self._linked_points = {} super().__init__(*args, **kwargs) - _linked_points = {} + _linked_points: dict = {} @validate_trait("name") def _validate_name(self, proposed): diff --git a/jupyter_server/files/handlers.py b/jupyter_server/files/handlers.py index c76fdc28d3..de60117324 100644 --- a/jupyter_server/files/handlers.py +++ b/jupyter_server/files/handlers.py @@ -4,6 +4,7 @@ import json import mimetypes from base64 import decodebytes +from typing import List from tornado import web @@ -57,7 +58,7 @@ async def get(self, path, include_body=True): model = await ensure_async(cm.get(path, type="file", content=include_body)) - if self.get_argument("download", False): + if self.get_argument("download", None): self.set_attachment_header(name) # get mimetype from filename @@ -91,4 +92,4 @@ async def get(self, path, include_body=True): self.flush() -default_handlers = [] +default_handlers: List[JupyterHandler] = [] diff --git a/jupyter_server/gateway/handlers.py b/jupyter_server/gateway/handlers.py index e31302c464..17878bb8ba 100644 --- a/jupyter_server/gateway/handlers.py +++ b/jupyter_server/gateway/handlers.py @@ -52,7 +52,8 @@ def authenticate(self): self.log.warning("Couldn't authenticate WebSocket connection") raise web.HTTPError(403) - if self.get_argument("session_id", False): + if self.get_argument("session_id", None): + assert self.session is not None self.session.session = self.get_argument("session_id") else: self.log.warning("No session ID specified") @@ -79,6 +80,7 @@ def open(self, kernel_id, *args, **kwargs): self.ping_callback = PeriodicCallback(self.send_ping, GATEWAY_WS_PING_INTERVAL_SECS * 1000) self.ping_callback.start() + assert self.gateway is not None self.gateway.on_open( kernel_id=kernel_id, message_callback=self.write_message, @@ -87,6 +89,7 @@ def open(self, kernel_id, *args, **kwargs): def on_message(self, message): """Forward message to gateway web socket handler.""" + assert self.gateway is not None self.gateway.on_message(message) def write_message(self, message, binary=False): @@ -105,6 +108,7 @@ def write_message(self, message, binary=False): def on_close(self): self.log.debug("Closing websocket connection %s", self.request.path) + assert self.gateway is not None self.gateway.on_close() super().on_close() @@ -137,7 +141,7 @@ def __init__(self, **kwargs): super().__init__(**kwargs) self.kernel_id = None self.ws = None - self.ws_future = Future() + self.ws_future: Future = Future() self.disconnected = False self.retry = 0 @@ -152,7 +156,7 @@ async def _connect(self, kernel_id, message_callback): "channels", ) self.log.info(f"Connecting to {ws_url}") - kwargs = {} + kwargs: dict = {} kwargs = GatewayClient.instance().load_connection_args(**kwargs) request = HTTPRequest(ws_url, **kwargs) @@ -269,7 +273,8 @@ async def get(self, kernel_name, path, include_body=True): " resource serving.".format(path, kernel_name) ) else: - self.set_header("Content-Type", mimetypes.guess_type(path)[0]) + mimetype = mimetypes.guess_type(path)[0] or "text/plain" + self.set_header("Content-Type", mimetype) self.finish(kernel_spec_res) diff --git a/jupyter_server/gateway/managers.py b/jupyter_server/gateway/managers.py index 4645429cf1..5efad8e72e 100644 --- a/jupyter_server/gateway/managers.py +++ b/jupyter_server/gateway/managers.py @@ -6,7 +6,7 @@ from logging import Logger from queue import Queue from threading import Thread -from typing import Dict +from typing import Any, Dict, Optional import websocket from jupyter_client.asynchronous.client import AsyncKernelClient @@ -454,6 +454,7 @@ async def shutdown_kernel(self, now=False, restart=False): async def restart_kernel(self, **kw): """Restarts a kernel via HTTP.""" if self.has_kernel: + assert self.kernel_url is not None kernel_url = self.kernel_url + "/restart" self.log.debug("Request restart kernel at: %s", kernel_url) response = await gateway_request(kernel_url, method="POST", body=json_encode({})) @@ -462,6 +463,7 @@ async def restart_kernel(self, **kw): async def interrupt_kernel(self): """Interrupts the kernel via an HTTP request.""" if self.has_kernel: + assert self.kernel_url is not None kernel_url = self.kernel_url + "/interrupt" self.log.debug("Request interrupt kernel at: %s", kernel_url) response = await gateway_request(kernel_url, method="POST", body=json_encode({})) @@ -486,7 +488,7 @@ def cleanup_resources(self, restart=False): class ChannelQueue(Queue): - channel_name: str = None + channel_name: Optional[str] = None def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: Logger): super().__init__() @@ -494,7 +496,7 @@ def __init__(self, channel_name: str, channel_socket: websocket.WebSocket, log: self.channel_socket = channel_socket self.log = log - async def get_msg(self, *args, **kwargs) -> dict: + async def get_msg(self, *args: Any, **kwargs: Any) -> dict: timeout = kwargs.get("timeout", 1) msg = self.get(timeout=timeout) self.log.debug( @@ -516,7 +518,7 @@ def send(self, msg: dict) -> None: @staticmethod def serialize_datetime(dt): - if isinstance(dt, (datetime.date, datetime.datetime)): + if isinstance(dt, (datetime.datetime)): return dt.timestamp() return None @@ -572,13 +574,18 @@ class GatewayKernelClient(AsyncKernelClient): # flag for whether execute requests should be allowed to call raw_input: allow_stdin = False _channels_stopped = False - _channel_queues = {} + _channel_queues: Optional[dict] = {} + _control_channel: Optional[ChannelQueue] + _hb_channel: Optional[ChannelQueue] + _stdin_channel: Optional[ChannelQueue] + _iopub_channel: Optional[ChannelQueue] + _shell_channel: Optional[ChannelQueue] def __init__(self, **kwargs): super().__init__(**kwargs) self.kernel_id = kwargs["kernel_id"] - self.channel_socket = None - self.response_router = None + self.channel_socket: Optional[websocket.WebSocket] = None + self.response_router: Optional[Thread] = None # -------------------------------------------------------------------------- # Channel management methods @@ -627,7 +634,9 @@ def stop_channels(self): self._channels_stopped = True self.log.debug("Closing websocket connection") + assert self.channel_socket is not None self.channel_socket.close() + assert self.response_router is not None self.response_router.join() if self._channel_queues: @@ -641,7 +650,9 @@ def shell_channel(self): """Get the shell channel object for this kernel.""" if self._shell_channel is None: self.log.debug("creating shell channel queue") + assert self.channel_socket is not None self._shell_channel = ChannelQueue("shell", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["shell"] = self._shell_channel return self._shell_channel @@ -650,7 +661,9 @@ def iopub_channel(self): """Get the iopub channel object for this kernel.""" if self._iopub_channel is None: self.log.debug("creating iopub channel queue") + assert self.channel_socket is not None self._iopub_channel = ChannelQueue("iopub", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["iopub"] = self._iopub_channel return self._iopub_channel @@ -659,7 +672,9 @@ def stdin_channel(self): """Get the stdin channel object for this kernel.""" if self._stdin_channel is None: self.log.debug("creating stdin channel queue") + assert self.channel_socket is not None self._stdin_channel = ChannelQueue("stdin", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["stdin"] = self._stdin_channel return self._stdin_channel @@ -668,7 +683,9 @@ def hb_channel(self): """Get the hb channel object for this kernel.""" if self._hb_channel is None: self.log.debug("creating hb channel queue") + assert self.channel_socket is not None self._hb_channel = HBChannelQueue("hb", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["hb"] = self._hb_channel return self._hb_channel @@ -677,7 +694,9 @@ def control_channel(self): """Get the control channel object for this kernel.""" if self._control_channel is None: self.log.debug("creating control channel queue") + assert self.channel_socket is not None self._control_channel = ChannelQueue("control", self.channel_socket, self.log) + assert self._channel_queues is not None self._channel_queues["control"] = self._control_channel return self._control_channel @@ -691,11 +710,13 @@ def _route_responses(self): """ try: while not self._channels_stopped: + assert self.channel_socket is not None raw_message = self.channel_socket.recv() if not raw_message: break response_message = json_decode(utf8(raw_message)) channel = response_message["channel"] + assert self._channel_queues is not None self._channel_queues[channel].put_nowait(response_message) except websocket.WebSocketConnectionClosedException: diff --git a/jupyter_server/i18n/__init__.py b/jupyter_server/i18n/__init__.py index e44aa11393..8a791f4c58 100644 --- a/jupyter_server/i18n/__init__.py +++ b/jupyter_server/i18n/__init__.py @@ -15,7 +15,7 @@ # ... # } # }} -TRANSLATIONS_CACHE = {"nbjs": {}} +TRANSLATIONS_CACHE: dict = {"nbjs": {}} _accept_lang_re = re.compile( @@ -87,7 +87,7 @@ def combine_translations(accept_language, domain="nbjs"): Returns data re-packaged in jed1.x format. """ lang_codes = parse_accept_lang_header(accept_language) - combined = {} + combined: dict = {} for language in lang_codes: if language == "en": # en is default, all translations are in frontend. diff --git a/jupyter_server/py.typed b/jupyter_server/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jupyter_server/pytest_plugin.py b/jupyter_server/pytest_plugin.py index 8fd4dda943..e5d0d49907 100644 --- a/jupyter_server/pytest_plugin.py +++ b/jupyter_server/pytest_plugin.py @@ -34,7 +34,9 @@ import asyncio if os.name == "nt" and sys.version_info >= (3, 7): - asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + asyncio.set_event_loop_policy( + asyncio.WindowsSelectorEventLoopPolicy() # type:ignore[attr-defined] + ) # ============ Move to Jupyter Core ============= diff --git a/jupyter_server/serverapp.py b/jupyter_server/serverapp.py index 1a8d400df1..5e59a7dd71 100644 --- a/jupyter_server/serverapp.py +++ b/jupyter_server/serverapp.py @@ -32,7 +32,7 @@ import resource except ImportError: # Windows - resource = None + resource = None # type:ignore[assignment] from jinja2 import Environment, FileSystemLoader from jupyter_core.paths import secure_write @@ -282,7 +282,7 @@ def init_settings( _template_path = (_template_path,) template_path = [os.path.expanduser(path) for path in _template_path] - jenv_opt = {"autoescape": True} + jenv_opt: dict = {"autoescape": True} jenv_opt.update(jinja_env_options if jinja_env_options else {}) env = Environment( @@ -1041,7 +1041,7 @@ def _token_default(self): return os.getenv("JUPYTER_TOKEN") if os.getenv("JUPYTER_TOKEN_FILE"): self._token_generated = False - with open(os.getenv("JUPYTER_TOKEN_FILE")) as token_file: + with open(os.getenv("JUPYTER_TOKEN_FILE", "")) as token_file: return token_file.read() if self.password: # no token if password is enabled @@ -1196,10 +1196,10 @@ def _default_allow_remote(self): except ValueError: # Address is a hostname for info in socket.getaddrinfo(self.ip, self.port, 0, socket.SOCK_STREAM): - addr = info[4][0] + addr = info[4][0] # type:ignore[assignment] try: - parsed = ipaddress.ip_address(addr.split("%")[0]) + parsed = ipaddress.ip_address(addr.split("%")[0]) # type:ignore[union-attr] except ValueError: self.log.warning("Unrecognised IP address: %r", addr) continue @@ -1207,7 +1207,10 @@ def _default_allow_remote(self): # Macs map localhost to 'fe80::1%lo0', a link local address # scoped to the loopback interface. For now, we'll assume that # any scoped link-local address is effectively local. - if not (parsed.is_loopback or (("%" in addr) and parsed.is_link_local)): + if not ( + parsed.is_loopback + or (("%" in addr) and parsed.is_link_local) # type:ignore[operator] + ): return True return False else: @@ -2024,12 +2027,7 @@ def _get_urlparts(self, path=None, include_token=False): query = urllib.parse.urlencode({"token": token}) # Build the URL Parts to dump. urlparts = urllib.parse.ParseResult( - scheme=scheme, - netloc=netloc, - path=path, - params=None, - query=query, - fragment=None, + scheme=scheme, netloc=netloc, path=path, query=query or "", params="", fragment="" ) return urlparts @@ -2650,6 +2648,7 @@ def launch_browser(self): assembled_url, _ = self._prepare_browser_open() def target(): + assert browser is not None browser.open(assembled_url, new=self.webbrowser_open_new) threading.Thread(target=target).start() diff --git a/jupyter_server/services/config/manager.py b/jupyter_server/services/config/manager.py index 5f04925fe7..bc42deb645 100644 --- a/jupyter_server/services/config/manager.py +++ b/jupyter_server/services/config/manager.py @@ -22,7 +22,7 @@ class ConfigManager(LoggingConfigurable): def get(self, section_name): """Get the config from all config sections.""" - config = {} + config: dict = {} # step through back to front, to ensure front of the list is top priority for p in self.read_config_path[::-1]: cm = BaseJSONConfigManager(config_dir=p) diff --git a/jupyter_server/services/contents/fileio.py b/jupyter_server/services/contents/fileio.py index d01bfd16dc..e1d6ae66dc 100644 --- a/jupyter_server/services/contents/fileio.py +++ b/jupyter_server/services/contents/fileio.py @@ -204,11 +204,12 @@ def atomic_writing(self, os_path, *args, **kwargs): Depending on flag 'use_atomic_writing', the wrapper perform an actual atomic writing or simply writes the file (whatever an old exists or not)""" with self.perm_to_403(os_path): + kwargs["log"] = self.log if self.use_atomic_writing: - with atomic_writing(os_path, *args, log=self.log, **kwargs) as f: + with atomic_writing(os_path, *args, **kwargs) as f: yield f else: - with _simple_writing(os_path, *args, log=self.log, **kwargs) as f: + with _simple_writing(os_path, *args, **kwargs) as f: yield f @contextmanager diff --git a/jupyter_server/services/contents/filemanager.py b/jupyter_server/services/contents/filemanager.py index 88aa0e3620..7baaf842f6 100644 --- a/jupyter_server/services/contents/filemanager.py +++ b/jupyter_server/services/contents/filemanager.py @@ -331,7 +331,7 @@ def _notebook_model(self, path, content=True): os_path = self._get_os_path(path) if content: - validation_error = {} + validation_error: dict = {} nb = self._read_notebook( os_path, as_version=4, capture_validation_error=validation_error ) @@ -412,7 +412,7 @@ def save(self, model, path=""): os_path = self._get_os_path(path) self.log.debug("Saving %s", os_path) - validation_error = {} + validation_error: dict = {} try: if model["type"] == "notebook": nb = nbformat.from_dict(model["content"]) @@ -657,7 +657,7 @@ async def _notebook_model(self, path, content=True): os_path = self._get_os_path(path) if content: - validation_error = {} + validation_error: dict = {} nb = await self._read_notebook( os_path, as_version=4, capture_validation_error=validation_error ) @@ -738,7 +738,7 @@ async def save(self, model, path=""): os_path = self._get_os_path(path) self.log.debug("Saving %s", os_path) - validation_error = {} + validation_error: dict = {} try: if model["type"] == "notebook": nb = nbformat.from_dict(model["content"]) diff --git a/jupyter_server/services/contents/handlers.py b/jupyter_server/services/contents/handlers.py index 59c109ad84..6b98c5d6cf 100644 --- a/jupyter_server/services/contents/handlers.py +++ b/jupyter_server/services/contents/handlers.py @@ -54,7 +54,7 @@ def validate_model(model, expect_content): f"Keys unexpectedly None: {errors}", ) else: - errors = {key: model[key] for key in maybe_none_keys if model[key] is not None} + errors = {key: model[key] for key in maybe_none_keys if model[key] is not None} # type: ignore[assignment] if errors: raise web.HTTPError( 500, @@ -102,10 +102,10 @@ async def get(self, path=""): format = self.get_query_argument("format", default=None) if format not in {None, "text", "base64"}: raise web.HTTPError(400, "Format %r is invalid" % format) - content = self.get_query_argument("content", default="1") - if content not in {"0", "1"}: - raise web.HTTPError(400, "Content %r is invalid" % content) - content = int(content) + content_str = self.get_query_argument("content", default="1") + if content_str not in {"0", "1"}: + raise web.HTTPError(400, "Content %r is invalid" % content_str) + content = int(content_str or "") model = await ensure_async( self.contents_manager.get( diff --git a/jupyter_server/services/kernels/handlers.py b/jupyter_server/services/kernels/handlers.py index c5fd110fa9..8b3ff7dc62 100644 --- a/jupyter_server/services/kernels/handlers.py +++ b/jupyter_server/services/kernels/handlers.py @@ -114,7 +114,10 @@ class ZMQChannelsHandler(AuthenticatedZMQStreamHandler): # class-level registry of open sessions # allows checking for conflict on session-id, # which is used as a zmq identity and must be unique. - _open_sessions = {} + _open_sessions: dict = {} + + _kernel_info_future: Future + _close_future: Future @property def kernel_info_timeout(self): @@ -177,7 +180,7 @@ def nudge(self): # establishing its zmq subscriptions before processing the next request. if getattr(kernel, "execution_state", None) == "busy": self.log.debug("Nudge: not nudging busy kernel %s", self.kernel_id) - f = Future() + f: Future = Future() f.set_result(None) return f # Use a transient shell channel to prevent leaking @@ -189,8 +192,8 @@ def nudge(self): # The IOPub used by the client, whose subscriptions we are verifying. iopub_channel = self.channels["iopub"] - info_future = Future() - iopub_future = Future() + info_future: Future = Future() + iopub_future: Future = Future() both_done = gen.multi([info_future, iopub_future]) def finish(_=None): @@ -203,7 +206,7 @@ def finish(_=None): def cleanup(_=None): """Common cleanup""" - loop.remove_timeout(nudge_handle) + loop.remove_timeout(nudge_handle) # type:ignore[has-type] iopub_channel.stop_on_recv() if not shell_channel.closed(): shell_channel.close() @@ -271,7 +274,7 @@ def nudge(count): log(f"Nudge: attempt {count} on kernel {self.kernel_id}") self.session.send(shell_channel, "kernel_info_request") self.session.send(control_channel, "kernel_info_request") - nonlocal nudge_handle + nonlocal nudge_handle # type:ignore[misc] nudge_handle = loop.call_later(0.5, nudge, count) nudge_handle = loop.call_later(0, nudge, count=0) @@ -293,8 +296,9 @@ def request_kernel_info(self): self.log.debug("Requesting kernel info from %s", self.kernel_id) # Create a kernel_info channel to query the kernel protocol version. # This channel will be closed after the kernel_info reply is received. - if self.kernel_info_channel is None: + if self.kernel_info_channel is None: # type:ignore[has-type] self.kernel_info_channel = km.connect_shell(self.kernel_id) + assert self.kernel_info_channel is not None self.kernel_info_channel.on_recv(self._handle_kernel_info_reply) self.session.send(self.kernel_info_channel, "kernel_info_request") # store the future on the kernel, so only one request is sent @@ -512,6 +516,7 @@ def on_message(self, ws_msg): ignore_msg = False if am: msg["header"] = self.get_part("header", msg["header"], msg_list) + assert msg["header"] is not None if msg["header"]["msg_type"] not in am: self.log.warning( 'Received message of type "%s", which is not allowed. Ignoring.' diff --git a/jupyter_server/services/kernels/kernelmanager.py b/jupyter_server/services/kernels/kernelmanager.py index 190d60eeb4..1372232d00 100644 --- a/jupyter_server/services/kernels/kernelmanager.py +++ b/jupyter_server/services/kernels/kernelmanager.py @@ -411,7 +411,7 @@ async def restart_kernel(self, kernel_id, now=False): kernel = self.get_kernel(kernel_id) # return a Future that will resolve when the kernel has successfully restarted channel = kernel.connect_shell() - future = Future() + future: Future = Future() def finish(): """Common cleanup when restart finishes/fails for any reason.""" diff --git a/jupyter_server/services/sessions/sessionmanager.py b/jupyter_server/services/sessions/sessionmanager.py index 5ea14af5ac..51489116a0 100644 --- a/jupyter_server/services/sessions/sessionmanager.py +++ b/jupyter_server/services/sessions/sessionmanager.py @@ -8,7 +8,7 @@ import sqlite3 except ImportError: # fallback on pysqlite2 if Python was build without sqlite - from pysqlite2 import dbapi2 as sqlite3 + from pysqlite2 import dbapi2 as sqlite3 # type:ignore[no-redef] from dataclasses import dataclass, fields from typing import Union @@ -41,7 +41,7 @@ class KernelSessionRecord: session_id: Union[None, str] = None kernel_id: Union[None, str] = None - def __eq__(self, other: "KernelSessionRecord") -> bool: + def __eq__(self, other: object) -> bool: if isinstance(other, KernelSessionRecord): condition1 = self.kernel_id and self.kernel_id == other.kernel_id condition2 = all( @@ -103,7 +103,7 @@ def __init__(self, *records): def __str__(self): return str(self._records) - def __contains__(self, record: Union[KernelSessionRecord, str]): + def __contains__(self, record: Union[KernelSessionRecord, str]) -> bool: """Search for records by kernel_id and session_id""" if isinstance(record, KernelSessionRecord) and record in self._records: return True diff --git a/jupyter_server/traittypes.py b/jupyter_server/traittypes.py index cad8b4e204..1034f6935c 100644 --- a/jupyter_server/traittypes.py +++ b/jupyter_server/traittypes.py @@ -8,6 +8,8 @@ class TypeFromClasses(ClassBasedTraitType): """A trait whose value must be a subclass of a class in a specified list of classes.""" + default_value: Undefined + def __init__(self, default_value=Undefined, klasses=None, **kwargs): """Construct a Type trait A Type trait specifies that its values must be subclasses of @@ -181,6 +183,7 @@ def validate(self, obj, value): def info(self): result = "an instance of " + assert self.klasses is not None for klass in self.klasses: if isinstance(klass, str): result += klass @@ -199,6 +202,7 @@ def instance_init(self, obj): def _resolve_classes(self): # Resolve all string names to actual classes. self.importable_klasses = [] + assert self.klasses is not None for klass in self.klasses: if isinstance(klass, str): # Try importing the classes to compare. Silently, ignore if not importable. diff --git a/jupyter_server/utils.py b/jupyter_server/utils.py index 714bd9836c..33365c501c 100644 --- a/jupyter_server/utils.py +++ b/jupyter_server/utils.py @@ -149,7 +149,7 @@ def _check_pid_win32(pid): # OpenProcess returns 0 if no such process (of ours) exists # positive int otherwise - return bool(ctypes.windll.kernel32.OpenProcess(1, 0, pid)) + return bool(ctypes.windll.kernel32.OpenProcess(1, 0, pid)) # type:ignore[attr-defined] def _check_pid_posix(pid): diff --git a/pyproject.toml b/pyproject.toml index 4cc5a56e50..770723662b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -123,3 +123,46 @@ default = "" [[tool.tbump.field]] name = "release" default = "" + +[tool.mypy] +check_untyped_defs = true +disallow_incomplete_defs = true +no_implicit_optional = true +pretty = true +show_error_context = true +show_error_codes = true +strict_equality = true +warn_unused_configs = true +warn_unused_ignores = true +warn_redundant_casts = true +exclude = [ + "examples/simple/setup.py", +] + + +[[tool.mypy.overrides]] +module = [ + "traitlets", + "traitlets.config", + "traitlets.config.application", + "traitlets.config.configurable", + "traitlets.tests.utils", + "traitlets.utils.importstring", + "traitlets.traitlets", + "traitlets.utils.descriptions", + "jupyter_core", + "jupyter_core.application", + "jupyter_core.paths", + "jupyter_core.utils", + "nbconvert.exporters", + "nbconvert.exporters.base", + "nbformat", + "nbformat.sign", + "nbformat.v4", + "pysqlite2", + "_frozen_importlib_external", + "send2trash", + "terminado", + "websocket" +] +ignore_missing_imports = true diff --git a/tests/auth/test_authorizer.py b/tests/auth/test_authorizer.py index 1fc694d3c4..ba00bf33a2 100644 --- a/tests/auth/test_authorizer.py +++ b/tests/auth/test_authorizer.py @@ -18,7 +18,7 @@ class AuthorizerforTesting(Authorizer): # Set these class attributes from within a test # to verify that they match the arguments passed # by the REST API. - permissions = {} + permissions: dict = {} def normalize_url(self, path): """Drop the base URL and make sure path leads with a /""" diff --git a/tests/auth/test_identity.py b/tests/auth/test_identity.py index e60997f2fd..0b7c3f8364 100644 --- a/tests/auth/test_identity.py +++ b/tests/auth/test_identity.py @@ -46,7 +46,7 @@ def test_identity_model(old_user, expected): idp = IdentityProvider() identity = idp.identity_model(user) print(identity) - identity_subset = {key: identity[key] for key in expected} + identity_subset = {key: identity[key] for key in expected} # type:ignore[union-attr] print(type(identity), type(identity_subset), type(expected)) assert identity_subset == expected @@ -92,8 +92,8 @@ def test_user_defaults(fields, expected): user = User(**fields) # check expected fields - for key in expected: - assert getattr(user, key) == expected[key] + for key in expected: # type:ignore[union-attr] + assert getattr(user, key) == expected[key] # type:ignore[index] # check types for key in ("username", "name", "display_name"): diff --git a/tests/auth/test_login.py b/tests/auth/test_login.py index 9f120c13ab..1d185d3b79 100644 --- a/tests/auth/test_login.py +++ b/tests/auth/test_login.py @@ -49,6 +49,7 @@ async def _login(jp_serverapp, http_server_client, jp_base_url, next): except HTTPClientError as e: if e.code != 302: raise + assert e.response is not None return e.response.headers["Location"] else: assert resp.code == 302, "Should have returned a redirect!" diff --git a/tests/extension/test_app.py b/tests/extension/test_app.py index 006c3e09b0..b7030d146e 100644 --- a/tests/extension/test_app.py +++ b/tests/extension/test_app.py @@ -1,3 +1,5 @@ +from typing import Any + import pytest from traitlets.config import Config @@ -78,7 +80,7 @@ def test_extensionapp_no_parent(): assert app.serverapp is not None -OPEN_BROWSER_COMBINATIONS = ( +OPEN_BROWSER_COMBINATIONS: Any = ( (True, {}), (True, {"ServerApp": {"open_browser": True}}), (False, {"ServerApp": {"open_browser": False}}), diff --git a/tests/nbconvert/test_handlers.py b/tests/nbconvert/test_handlers.py index 809f0ba3ec..f14fde35a2 100644 --- a/tests/nbconvert/test_handlers.py +++ b/tests/nbconvert/test_handlers.py @@ -3,9 +3,9 @@ from shutil import which import pytest -import tornado from nbformat import writes from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook, new_output +from tornado.httpclient import HTTPClientError from ..utils import expected_http_error @@ -75,7 +75,7 @@ async def test_from_file(jp_fetch, notebook): async def test_from_file_404(jp_fetch, notebook): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch( "nbconvert", "html", diff --git a/tests/services/contents/test_api.py b/tests/services/contents/test_api.py index 988dcdb603..31b07e4137 100644 --- a/tests/services/contents/test_api.py +++ b/tests/services/contents/test_api.py @@ -54,7 +54,7 @@ def contents_dir(tmp_path, jp_serverapp): @pytest.fixture def contents(contents_dir): # Create files in temporary directory - paths = { + paths: dict = { "notebooks": [], "textfiles": [], "blobs": [], diff --git a/tests/services/contents/test_manager.py b/tests/services/contents/test_manager.py index 6765cbbe54..e3b37642b3 100644 --- a/tests/services/contents/test_manager.py +++ b/tests/services/contents/test_manager.py @@ -77,7 +77,7 @@ def add_invalid_cell(notebook): async def prepare_notebook( - jp_contents_manager, make_invalid: Optional[bool] = False + jp_contents_manager: FileContentsManager, make_invalid: Optional[bool] = False ) -> Tuple[Dict, str]: cm = jp_contents_manager model = await ensure_async(cm.new_untitled(type="notebook")) @@ -756,7 +756,7 @@ async def test_validate_notebook_model(jp_contents_manager): with patch("jupyter_server.services.contents.manager.validate_nb") as mock_validate_nb: # Valid notebook and a non-None dictionary, no validate call expected - validation_error = {} + validation_error: dict = {} cm.validate_notebook_model(model, validation_error) assert mock_validate_nb.call_count == 0 mock_validate_nb.reset_mock() diff --git a/tests/services/kernels/test_api.py b/tests/services/kernels/test_api.py index b7c43453f5..8e70ba5f49 100644 --- a/tests/services/kernels/test_api.py +++ b/tests/services/kernels/test_api.py @@ -27,7 +27,7 @@ async def _(kernel_id): return _ -configs = [ +configs: list = [ { "ServerApp": { "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.MappingKernelManager" diff --git a/tests/services/kernelspecs/test_api.py b/tests/services/kernelspecs/test_api.py index ee14d6afb0..461cc40e3e 100644 --- a/tests/services/kernelspecs/test_api.py +++ b/tests/services/kernelspecs/test_api.py @@ -1,8 +1,8 @@ import json import pytest -import tornado from jupyter_client.kernelspec import NATIVE_KERNEL_NAME +from tornado.httpclient import HTTPClientError from ...utils import expected_http_error, some_resource @@ -51,7 +51,7 @@ async def test_get_kernelspecs(jp_fetch, jp_kernelspecs): async def test_get_nonexistant_kernelspec(jp_fetch, jp_kernelspecs): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("api", "kernelspecs", "nonexistant", method="GET") assert expected_http_error(e, 404) @@ -63,10 +63,10 @@ async def test_get_kernel_resource_file(jp_fetch, jp_kernelspecs): async def test_get_nonexistant_resource(jp_fetch, jp_kernelspecs): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("kernelspecs", "nonexistant", "resource.txt", method="GET") assert expected_http_error(e, 404) - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("kernelspecs", "sample", "nonexistant.txt", method="GET") assert expected_http_error(e, 404) diff --git a/tests/services/sessions/test_api.py b/tests/services/sessions/test_api.py index d2a234e4b2..1bd4f58dfd 100644 --- a/tests/services/sessions/test_api.py +++ b/tests/services/sessions/test_api.py @@ -2,6 +2,7 @@ import os import shutil import time +from typing import Any import jupyter_client import pytest @@ -29,7 +30,7 @@ class NewPortsKernelManager(AsyncIOLoopKernelManager): def _default_cache_ports(self) -> bool: return False - async def restart_kernel(self, now: bool = False, newports: bool = True, **kw) -> None: + async def restart_kernel(self, now: bool = False, newports: bool = True, **kw: Any) -> None: self.log.debug(f"DEBUG**** calling super().restart_kernel with newports={newports}") return await super().restart_kernel(now=now, newports=newports, **kw) @@ -41,7 +42,7 @@ def _default_kernel_manager_class(self): return "tests.services.sessions.test_api.NewPortsKernelManager" -configs = [ +configs: list = [ { "ServerApp": { "kernel_manager_class": "jupyter_server.services.kernels.kernelmanager.MappingKernelManager" @@ -65,7 +66,7 @@ def _default_kernel_manager_class(self): # See https://github.com/jupyter-server/jupyter_server/issues/672 if os.name != "nt" and jupyter_client._version.version_info >= (7, 1): # Add a pending kernels condition - c = { + c: dict = { "ServerApp": { "kernel_manager_class": "tests.services.sessions.test_api.NewPortsMappingKernelManager" }, diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index a67dd6398e..48d5761746 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -16,6 +16,9 @@ class DummyKernel: + execution_state: str + last_activity: str + def __init__(self, kernel_name="python"): self.kernel_name = kernel_name @@ -132,7 +135,7 @@ def test_kernel_record_list(): # Test .get() r_ = records.get(r) assert r == r_ - r_ = records.get(r.kernel_id) + r_ = records.get(r.kernel_id or "") assert r == r_ with pytest.raises(ValueError): diff --git a/tests/test_files.py b/tests/test_files.py index 7fac8419d4..06f1932591 100644 --- a/tests/test_files.py +++ b/tests/test_files.py @@ -3,9 +3,9 @@ from pathlib import Path import pytest -import tornado from nbformat import writes from nbformat.v4 import new_code_cell, new_markdown_cell, new_notebook, new_output +from tornado.httpclient import HTTPClientError from .utils import expected_http_error @@ -28,7 +28,7 @@ async def fetch_expect_200(jp_fetch, *path_parts): async def fetch_expect_404(jp_fetch, *path_parts): - with pytest.raises(tornado.httpclient.HTTPClientError) as e: + with pytest.raises(HTTPClientError) as e: await jp_fetch("files", *path_parts, method="GET") assert expected_http_error(e, 404), [path_parts, e] diff --git a/tests/test_gateway.py b/tests/test_gateway.py index d040999558..4eb3c4a9bf 100644 --- a/tests/test_gateway.py +++ b/tests/test_gateway.py @@ -3,7 +3,7 @@ import os import uuid from datetime import datetime -from io import StringIO +from io import BytesIO from unittest.mock import patch import pytest @@ -34,7 +34,7 @@ def generate_kernelspec(name): # We'll mock up two kernelspecs - kspec_foo and kspec_bar -kernelspecs = { +kernelspecs: dict = { "default": "kspec_foo", "kernelspecs": { "kspec_foo": generate_kernelspec("kspec_foo"), @@ -72,16 +72,17 @@ async def mock_gateway_request(url, **kwargs): # Fetch all kernelspecs if endpoint.endswith("/api/kernelspecs") and method == "GET": - response_buf = StringIO(json.dumps(kernelspecs)) + response_buf = BytesIO(json.dumps(kernelspecs).encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response # Fetch named kernelspec if endpoint.rfind("/api/kernelspecs/") >= 0 and method == "GET": requested_kernelspec = endpoint.rpartition("/")[2] - kspecs = kernelspecs.get("kernelspecs") + kspecs: dict = kernelspecs["kernelspecs"] if requested_kernelspec in kspecs: - response_buf = StringIO(json.dumps(kspecs.get(requested_kernelspec))) + response_str = json.dumps(kspecs.get(requested_kernelspec)) + response_buf = BytesIO(response_str.encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response else: @@ -96,7 +97,7 @@ async def mock_gateway_request(url, **kwargs): assert name == kspec_name # Ensure that KERNEL_ env values get propagated model = generate_model(name) running_kernels[model.get("id")] = model # Register model as a running kernel - response_buf = StringIO(json.dumps(model)) + response_buf = BytesIO(json.dumps(model).encode("utf-8")) response = await ensure_async(HTTPResponse(request, 201, buffer=response_buf)) return response @@ -106,7 +107,7 @@ async def mock_gateway_request(url, **kwargs): for kernel_id in running_kernels.keys(): model = running_kernels.get(kernel_id) kernels.append(model) - response_buf = StringIO(json.dumps(kernels)) + response_buf = BytesIO(json.dumps(kernels).encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response @@ -122,7 +123,8 @@ async def mock_gateway_request(url, **kwargs): raise HTTPError(404, message="Kernel does not exist: %s" % requested_kernel_id) elif action == "restart": if requested_kernel_id in running_kernels: - response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response_str = json.dumps(running_kernels.get(requested_kernel_id)) + response_buf = BytesIO(response_str.encode("utf-8")) response = await ensure_async(HTTPResponse(request, 204, buffer=response_buf)) return response else: @@ -143,7 +145,8 @@ async def mock_gateway_request(url, **kwargs): if endpoint.rfind("/api/kernels/") >= 0 and method == "GET": requested_kernel_id = endpoint.rpartition("/")[2] if requested_kernel_id in running_kernels: - response_buf = StringIO(json.dumps(running_kernels.get(requested_kernel_id))) + response_str = json.dumps(running_kernels.get(requested_kernel_id)) + response_buf = BytesIO(response_str.encode("utf-8")) response = await ensure_async(HTTPResponse(request, 200, buffer=response_buf)) return response else: @@ -313,6 +316,7 @@ async def create_session(root_dir, jp_fetch, kernel_name): kernel_id = model.get("kernel").get("id") # ensure its in the running_kernels and name matches. running_kernel = running_kernels.get(kernel_id) + assert running_kernel is not None assert kernel_id == running_kernel.get("id") assert model.get("kernel").get("name") == running_kernel.get("name") session_id = model.get("id") @@ -359,6 +363,7 @@ async def create_kernel(jp_fetch, kernel_name): kernel_id = model.get("id") # ensure its in the running_kernels and name matches. running_kernel = running_kernels.get(kernel_id) + assert running_kernel is not None assert kernel_id == running_kernel.get("id") assert model.get("name") == kernel_name @@ -398,6 +403,7 @@ async def restart_kernel(jp_fetch, kernel_id): restarted_kernel_id = model.get("id") # ensure its in the running_kernels and name matches. running_kernel = running_kernels.get(restarted_kernel_id) + assert running_kernel is not None assert restarted_kernel_id == running_kernel.get("id") assert model.get("name") == running_kernel.get("name") diff --git a/tests/test_paths.py b/tests/test_paths.py index 0789be4ded..9a6a41b3ba 100644 --- a/tests/test_paths.py +++ b/tests/test_paths.py @@ -63,6 +63,7 @@ async def test_trailing_slash( ) # Capture the response from the raised exception value. response = err.value.response + assert response is not None assert response.code == 302 assert "Location" in response.headers assert response.headers["Location"] == url_path_join(jp_base_url, expected) diff --git a/tests/test_serverapp.py b/tests/test_serverapp.py index 145eaf1de7..1ab792d1e5 100644 --- a/tests/test_serverapp.py +++ b/tests/test_serverapp.py @@ -328,7 +328,7 @@ def test_preferred_dir_validation( config_file.write_text("\n".join(config_lines)) if argv: - kwargs["argv"] = argv + kwargs["argv"] = argv # type:ignore if root_dir_loc == "default" and preferred_dir_loc != "default": # error expected with pytest.raises(SystemExit): diff --git a/tests/unix_sockets/test_serverapp_integration.py b/tests/unix_sockets/test_serverapp_integration.py index 5bb1038234..9661539d7e 100644 --- a/tests/unix_sockets/test_serverapp_integration.py +++ b/tests/unix_sockets/test_serverapp_integration.py @@ -35,6 +35,7 @@ def test_shutdown_sock_server_integration(jp_unix_socket_file): ) complete = False + assert p.stderr is not None for line in iter(p.stderr.readline, b""): if url in line: complete = True diff --git a/tests/utils.py b/tests/utils.py index 6e6649af42..4eabcdceaa 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,6 +1,7 @@ import json -import tornado +from tornado.httpclient import HTTPClientError +from tornado.web import HTTPError some_resource = "The very model of a modern major general" @@ -20,7 +21,7 @@ def mkdir(tmp_path, *parts): def expected_http_error(error, expected_code, expected_message=None): """Check that the error matches the expected output error.""" e = error.value - if isinstance(e, tornado.web.HTTPError): + if isinstance(e, HTTPError): if expected_code != e.status_code: return False if expected_message is not None and expected_message != str(e): @@ -28,8 +29,8 @@ def expected_http_error(error, expected_code, expected_message=None): return True elif any( [ - isinstance(e, tornado.httpclient.HTTPClientError), - isinstance(e, tornado.httpclient.HTTPError), + isinstance(e, HTTPClientError), + isinstance(e, HTTPError), ] ): if expected_code != e.code: