Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add hook to observe pending sessions #751

Merged
merged 8 commits into from Mar 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
141 changes: 141 additions & 0 deletions jupyter_server/services/sessions/sessionmanager.py
Expand Up @@ -10,6 +10,9 @@
# fallback on pysqlite2 if Python was build without sqlite
from pysqlite2 import dbapi2 as sqlite3

from dataclasses import dataclass, fields
from typing import Union

from tornado import web
from traitlets import Instance, TraitError, Unicode, validate
from traitlets.config.configurable import LoggingConfigurable
Expand All @@ -18,6 +21,132 @@
from jupyter_server.utils import ensure_async


class KernelSessionRecordConflict(Exception):
"""Exception class to use when two KernelSessionRecords cannot
merge because of conflicting data.
"""

pass
Zsailer marked this conversation as resolved.
Show resolved Hide resolved


@dataclass
class KernelSessionRecord:
"""A record object for tracking a Jupyter Server Kernel Session.

Two records that share a session_id must also share a kernel_id, while
kernels can have multiple session (and thereby) session_ids
associated with them.
"""

session_id: Union[None, str] = None
kernel_id: Union[None, str] = None

def __eq__(self, other: "KernelSessionRecord") -> bool:
if isinstance(other, KernelSessionRecord):
condition1 = self.kernel_id and self.kernel_id == other.kernel_id
condition2 = all(
[
self.session_id == other.session_id,
self.kernel_id is None or other.kernel_id is None,
]
)
if any([condition1, condition2]):
return True
# If two records share session_id but have different kernels, this is
# and ill-posed expression. This should never be true. Raise an exception
# to inform the user.
if all(
[
self.session_id,
self.session_id == other.session_id,
self.kernel_id != other.kernel_id,
]
):
raise KernelSessionRecordConflict(
"A single session_id can only have one kernel_id "
"associated with. These two KernelSessionRecords share the same "
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
"associated with. These two KernelSessionRecords share the same "
"associated with it. These two KernelSessionRecords share the same "

"session_id but have different kernel_ids. This should "
"not be possible and is likely an issue with the session "
"records."
)
return False

def update(self, other: "KernelSessionRecord") -> None:
"""Updates in-place a kernel from other (only accepts positive updates"""
Zsailer marked this conversation as resolved.
Show resolved Hide resolved
if not isinstance(other, KernelSessionRecord):
raise TypeError("'other' must be an instance of KernelSessionRecord.")

if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
kevin-bates marked this conversation as resolved.
Show resolved Hide resolved
raise KernelSessionRecordConflict(
"Could not update the record from 'other' because the two records conflict."
)

for field in fields(self):
if hasattr(other, field.name) and getattr(other, field.name):
setattr(self, field.name, getattr(other, field.name))


class KernelSessionRecordList:
"""An object for storing and managing a list of KernelSessionRecords.

When adding a record to the list, the KernelSessionRecordList
first checks if the record already exists in the list. If it does,
the record will be updated with the new information; otherwise,
it will be appended.
"""

def __init__(self, *records):
self._records = []
for record in records:
self.update(record)

def __str__(self):
return str(self._records)

def __contains__(self, record: Union[KernelSessionRecord, str]):
"""Search for records by kernel_id and session_id"""
if isinstance(record, KernelSessionRecord) and record in self._records:
return True

if isinstance(record, str):
for r in self._records:
if record in [r.session_id, r.kernel_id]:
return True
return False

def __len__(self):
return len(self._records)

def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
"""Return a full KernelSessionRecord from a session_id, kernel_id, or
incomplete KernelSessionRecord.
"""
if isinstance(record, str):
for r in self._records:
if record == r.kernel_id or record == r.session_id:
return r
elif isinstance(record, KernelSessionRecord):
for r in self._records:
if record == r:
return record
raise ValueError(f"{record} not found in KernelSessionRecordList.")

def update(self, record: KernelSessionRecord) -> None:
"""Update a record in-place or append it if not in the list."""
try:
idx = self._records.index(record)
self._records[idx].update(record)
except ValueError:
self._records.append(record)

def remove(self, record: KernelSessionRecord) -> None:
"""Remove a record if its found in the list. If it's not found,
do nothing.
"""
if record in self._records:
self._records.remove(record)


class SessionManager(LoggingConfigurable):

database_filepath = Unicode(
Expand Down Expand Up @@ -58,6 +187,10 @@ def _validate_database_filepath(self, proposal):
]
)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pending_sessions = KernelSessionRecordList()

# Session database initialized below
_cursor = None
_connection = None
Expand Down Expand Up @@ -118,15 +251,20 @@ async def create_session(
):
"""Creates a session and returns its model"""
session_id = self.new_session_id()
record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
if kernel_id is not None and kernel_id in self.kernel_manager:
pass
else:
kernel_id = await self.start_kernel_for_session(
session_id, path, name, type, kernel_name
)
record.kernel_id = kernel_id
self._pending_sessions.update(record)
result = await self.save_session(
session_id, path=path, name=name, type=type, kernel_id=kernel_id
)
self._pending_sessions.remove(record)
return result

async def start_kernel_for_session(self, session_id, path, name, type, kernel_name):
Expand Down Expand Up @@ -305,6 +443,9 @@ async def list_sessions(self):

async def delete_session(self, session_id):
"""Deletes the row in the session database with given session_id"""
record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
session = await self.get_session(session_id=session_id)
await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
self._pending_sessions.remove(record)