Skip to content

Commit

Permalink
add hook to observe pending sessions (#751)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
Zsailer and pre-commit-ci[bot] committed Mar 29, 2022
1 parent 57c0676 commit d32b887
Show file tree
Hide file tree
Showing 2 changed files with 304 additions and 7 deletions.
141 changes: 141 additions & 0 deletions jupyter_server/services/sessions/sessionmanager.py
Expand Up @@ -10,6 +10,9 @@
# fallback on pysqlite2 if Python was build without sqlite
from pysqlite2 import dbapi2 as sqlite3

from dataclasses import dataclass, fields
from typing import Union

from tornado import web
from traitlets import Instance, TraitError, Unicode, validate
from traitlets.config.configurable import LoggingConfigurable
Expand All @@ -18,6 +21,132 @@
from jupyter_server.utils import ensure_async


class KernelSessionRecordConflict(Exception):
"""Exception class to use when two KernelSessionRecords cannot
merge because of conflicting data.
"""

pass


@dataclass
class KernelSessionRecord:
"""A record object for tracking a Jupyter Server Kernel Session.
Two records that share a session_id must also share a kernel_id, while
kernels can have multiple session (and thereby) session_ids
associated with them.
"""

session_id: Union[None, str] = None
kernel_id: Union[None, str] = None

def __eq__(self, other: "KernelSessionRecord") -> bool:
if isinstance(other, KernelSessionRecord):
condition1 = self.kernel_id and self.kernel_id == other.kernel_id
condition2 = all(
[
self.session_id == other.session_id,
self.kernel_id is None or other.kernel_id is None,
]
)
if any([condition1, condition2]):
return True
# If two records share session_id but have different kernels, this is
# and ill-posed expression. This should never be true. Raise an exception
# to inform the user.
if all(
[
self.session_id,
self.session_id == other.session_id,
self.kernel_id != other.kernel_id,
]
):
raise KernelSessionRecordConflict(
"A single session_id can only have one kernel_id "
"associated with. These two KernelSessionRecords share the same "
"session_id but have different kernel_ids. This should "
"not be possible and is likely an issue with the session "
"records."
)
return False

def update(self, other: "KernelSessionRecord") -> None:
"""Updates in-place a kernel from other (only accepts positive updates"""
if not isinstance(other, KernelSessionRecord):
raise TypeError("'other' must be an instance of KernelSessionRecord.")

if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
raise KernelSessionRecordConflict(
"Could not update the record from 'other' because the two records conflict."
)

for field in fields(self):
if hasattr(other, field.name) and getattr(other, field.name):
setattr(self, field.name, getattr(other, field.name))


class KernelSessionRecordList:
"""An object for storing and managing a list of KernelSessionRecords.
When adding a record to the list, the KernelSessionRecordList
first checks if the record already exists in the list. If it does,
the record will be updated with the new information; otherwise,
it will be appended.
"""

def __init__(self, *records):
self._records = []
for record in records:
self.update(record)

def __str__(self):
return str(self._records)

def __contains__(self, record: Union[KernelSessionRecord, str]):
"""Search for records by kernel_id and session_id"""
if isinstance(record, KernelSessionRecord) and record in self._records:
return True

if isinstance(record, str):
for r in self._records:
if record in [r.session_id, r.kernel_id]:
return True
return False

def __len__(self):
return len(self._records)

def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
"""Return a full KernelSessionRecord from a session_id, kernel_id, or
incomplete KernelSessionRecord.
"""
if isinstance(record, str):
for r in self._records:
if record == r.kernel_id or record == r.session_id:
return r
elif isinstance(record, KernelSessionRecord):
for r in self._records:
if record == r:
return record
raise ValueError(f"{record} not found in KernelSessionRecordList.")

def update(self, record: KernelSessionRecord) -> None:
"""Update a record in-place or append it if not in the list."""
try:
idx = self._records.index(record)
self._records[idx].update(record)
except ValueError:
self._records.append(record)

def remove(self, record: KernelSessionRecord) -> None:
"""Remove a record if its found in the list. If it's not found,
do nothing.
"""
if record in self._records:
self._records.remove(record)


class SessionManager(LoggingConfigurable):

database_filepath = Unicode(
Expand Down Expand Up @@ -58,6 +187,10 @@ def _validate_database_filepath(self, proposal):
]
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pending_sessions = KernelSessionRecordList()

# Session database initialized below
_cursor = None
_connection = None
Expand Down Expand Up @@ -118,15 +251,20 @@ async def create_session(
):
"""Creates a session and returns its model"""
session_id = self.new_session_id()
record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
if kernel_id is not None and kernel_id in self.kernel_manager:
pass
else:
kernel_id = await self.start_kernel_for_session(
session_id, path, name, type, kernel_name
)
record.kernel_id = kernel_id
self._pending_sessions.update(record)
result = await self.save_session(
session_id, path=path, name=name, type=type, kernel_id=kernel_id
)
self._pending_sessions.remove(record)
return result

async def start_kernel_for_session(self, session_id, path, name, type, kernel_name):
Expand Down Expand Up @@ -305,6 +443,9 @@ async def list_sessions(self):

async def delete_session(self, session_id):
"""Deletes the row in the session database with given session_id"""
record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
session = await self.get_session(session_id=session_id)
await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
self._pending_sessions.remove(record)

0 comments on commit d32b887

Please sign in to comment.