Skip to content

Commit

Permalink
naming is hard
Browse files Browse the repository at this point in the history
  • Loading branch information
Zsailer committed Mar 23, 2022
1 parent 2b19487 commit 6267330
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 66 deletions.
54 changes: 27 additions & 27 deletions jupyter_server/services/sessions/sessionmanager.py
Expand Up @@ -27,14 +27,14 @@
from dataclasses import fields


class KernelRecordConflict(Exception):
class KernelSessionRecordConflict(Exception):
"""An exception raised when"""

pass


@dataclass
class KernelRecord:
class KernelSessionRecord:
"""A record object for tracking a Jupyter Server Kernel Session.
Two records are equal if they share the
Expand All @@ -43,8 +43,8 @@ class KernelRecord:
session_id: Union[None, str] = None
kernel_id: Union[None, str] = None

def __eq__(self, other: "KernelRecord") -> bool:
if isinstance(other, KernelRecord):
def __eq__(self, other: "KernelSessionRecord") -> bool:
if isinstance(other, KernelSessionRecord):
condition1 = self.kernel_id and self.kernel_id == other.kernel_id
condition2 = all(
[
Expand All @@ -64,22 +64,22 @@ def __eq__(self, other: "KernelRecord") -> bool:
self.kernel_id != other.kernel_id,
]
):
raise KernelRecordConflict(
raise KernelSessionRecordConflict(
"A single session_id can only have one kernel_id "
"associated with. These two KernelRecords share the same "
"associated with. These two KernelSessionRecords share the same "
"session_id but have different kernel_ids. This should "
"not be possible and is likely an issue with the session "
"records."
)
return False

def update(self, other: "KernelRecord") -> None:
def update(self, other: "KernelSessionRecord") -> None:
"""Updates in-place a kernel from other (only accepts positive updates"""
if not isinstance(other, KernelRecord):
raise TypeError("'other' must be an instance of KernelRecord.")
if not isinstance(other, KernelSessionRecord):
raise TypeError("'other' must be an instance of KernelSessionRecord.")

if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
raise KernelRecordConflict(
raise KernelSessionRecordConflict(
"Could not update the record from 'other' because the two records conflict."
)

Expand All @@ -88,8 +88,8 @@ def update(self, other: "KernelRecord") -> None:
setattr(self, field.name, getattr(other, field.name))


class KernelRecordList:
"""Handy object for storing and managing a list of KernelRecords.
class KernelSessionRecordList:
"""Handy object for storing and managing a list of KernelSessionRecords.
When adding a record to the list, first checks if the record
already exists. If it does, the record will be updated with
Expand All @@ -104,9 +104,9 @@ def __init__(self, *records):
def __str__(self):
return str(self._records)

def __contains__(self, record: Union[KernelRecord, str]):
def __contains__(self, record: Union[KernelSessionRecord, str]):
"""Search for records by kernel_id and session_id"""
if isinstance(record, KernelRecord) and record in self._records:
if isinstance(record, KernelSessionRecord) and record in self._records:
return True

if isinstance(record, str):
Expand All @@ -118,26 +118,26 @@ def __contains__(self, record: Union[KernelRecord, str]):
def __len__(self):
return len(self._records)

def get(self, record: Union[KernelRecord, str]) -> KernelRecord:
def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
if isinstance(record, str):
for r in self._records:
if record == r.kernel_id or record == r.session_id:
return r
elif isinstance(record, KernelRecord):
elif isinstance(record, KernelSessionRecord):
for r in self._records:
if record == r:
return record
raise ValueError(f"{record} not found in KernelRecordList.")
raise ValueError(f"{record} not found in KernelSessionRecordList.")

def update(self, record: KernelRecord) -> None:
def update(self, record: KernelSessionRecord) -> None:
"""Update a record in-place or append it if not in the list."""
try:
idx = self._records.index(record)
self._records[idx].update(record)
except ValueError:
self._records.append(record)

def remove(self, record: KernelRecord) -> None:
def remove(self, record: KernelSessionRecord) -> None:
"""Remove a record if its found in the list. If it's not found,
do nothing.
"""
Expand Down Expand Up @@ -187,7 +187,7 @@ def _validate_database_filepath(self, proposal):

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pending_kernels = KernelRecordList()
self._pending_sessions = KernelSessionRecordList()

# Session database initialized below
_cursor = None
Expand Down Expand Up @@ -249,20 +249,20 @@ async def create_session(
):
"""Creates a session and returns its model"""
session_id = self.new_session_id()
record = KernelRecord(session_id=session_id)
self._pending_kernels.update(record)
record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
if kernel_id is not None and kernel_id in self.kernel_manager:
pass
else:
kernel_id = await self.start_kernel_for_session(
session_id, path, name, type, kernel_name
)
record.kernel_id = kernel_id
self._pending_kernels.update(record)
self._pending_sessions.update(record)
result = await self.save_session(
session_id, path=path, name=name, type=type, kernel_id=kernel_id
)
self._pending_kernels.remove(record)
self._pending_sessions.remove(record)
return result

async def start_kernel_for_session(self, session_id, path, name, type, kernel_name):
Expand Down Expand Up @@ -441,9 +441,9 @@ async def list_sessions(self):

async def delete_session(self, session_id):
"""Deletes the row in the session database with given session_id"""
record = KernelRecord(session_id=session_id)
self._pending_kernels.update(record)
record = KernelSessionRecord(session_id=session_id)
self._pending_sessions.update(record)
session = await self.get_session(session_id=session_id)
await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"]))
self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,))
self._pending_kernels.remove(record)
self._pending_sessions.remove(record)
78 changes: 39 additions & 39 deletions tests/services/sessions/test_manager.py
Expand Up @@ -8,9 +8,9 @@
from jupyter_server._tz import utcnow
from jupyter_server.services.contents.manager import ContentsManager
from jupyter_server.services.kernels.kernelmanager import MappingKernelManager
from jupyter_server.services.sessions.sessionmanager import KernelRecord
from jupyter_server.services.sessions.sessionmanager import KernelRecordConflict
from jupyter_server.services.sessions.sessionmanager import KernelRecordList
from jupyter_server.services.sessions.sessionmanager import KernelSessionRecord
from jupyter_server.services.sessions.sessionmanager import KernelSessionRecordConflict
from jupyter_server.services.sessions.sessionmanager import KernelSessionRecordList
from jupyter_server.services.sessions.sessionmanager import SessionManager


Expand Down Expand Up @@ -63,66 +63,66 @@ def session_manager():


def test_kernel_record_equals():
record1 = KernelRecord(session_id="session1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel1")
record3 = KernelRecord(session_id="session2", kernel_id="kernel1")
record4 = KernelRecord(session_id="session1", kernel_id="kernel2")
record1 = KernelSessionRecord(session_id="session1")
record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
record3 = KernelSessionRecord(session_id="session2", kernel_id="kernel1")
record4 = KernelSessionRecord(session_id="session1", kernel_id="kernel2")

assert record1 == record2
assert record2 == record3
assert record3 != record4
assert record1 != record3
assert record3 != record4

with pytest.raises(KernelRecordConflict):
with pytest.raises(KernelSessionRecordConflict):
assert record2 == record4


def test_kernel_record_update():
record1 = KernelRecord(session_id="session1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel1")
record1 = KernelSessionRecord(session_id="session1")
record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
record1.update(record2)
assert record1.kernel_id == "kernel1"

record1 = KernelRecord(session_id="session1")
record2 = KernelRecord(kernel_id="kernel1")
record1 = KernelSessionRecord(session_id="session1")
record2 = KernelSessionRecord(kernel_id="kernel1")
record1.update(record2)
assert record1.kernel_id == "kernel1"

record1 = KernelRecord(kernel_id="kernel1")
record2 = KernelRecord(session_id="session1")
record1 = KernelSessionRecord(kernel_id="kernel1")
record2 = KernelSessionRecord(session_id="session1")
record1.update(record2)
assert record1.session_id == "session1"

record1 = KernelRecord(kernel_id="kernel1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel1")
record1 = KernelSessionRecord(kernel_id="kernel1")
record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
record1.update(record2)
assert record1.session_id == "session1"

record1 = KernelRecord(kernel_id="kernel1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel2")
with pytest.raises(KernelRecordConflict):
record1 = KernelSessionRecord(kernel_id="kernel1")
record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel2")
with pytest.raises(KernelSessionRecordConflict):
record1.update(record2)

record1 = KernelRecord(kernel_id="kernel1", session_id="session1")
record2 = KernelRecord(kernel_id="kernel2")
with pytest.raises(KernelRecordConflict):
record1 = KernelSessionRecord(kernel_id="kernel1", session_id="session1")
record2 = KernelSessionRecord(kernel_id="kernel2")
with pytest.raises(KernelSessionRecordConflict):
record1.update(record2)

record1 = KernelRecord(kernel_id="kernel1", session_id="session1")
record2 = KernelRecord(kernel_id="kernel2", session_id="session1")
with pytest.raises(KernelRecordConflict):
record1 = KernelSessionRecord(kernel_id="kernel1", session_id="session1")
record2 = KernelSessionRecord(kernel_id="kernel2", session_id="session1")
with pytest.raises(KernelSessionRecordConflict):
record1.update(record2)

record1 = KernelRecord(session_id="session1", kernel_id="kernel1")
record2 = KernelRecord(session_id="session2", kernel_id="kernel1")
record1 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
record2 = KernelSessionRecord(session_id="session2", kernel_id="kernel1")
record1.update(record2)
assert record1.session_id == "session2"


def test_kernel_record_list():
records = KernelRecordList()
r = KernelRecord(kernel_id="kernel1")
records = KernelSessionRecordList()
r = KernelSessionRecord(kernel_id="kernel1")
records.update(r)
assert r in records
assert "kernel1" in records
Expand All @@ -137,12 +137,12 @@ def test_kernel_record_list():
with pytest.raises(ValueError):
records.get("badkernel")

r_update = KernelRecord(kernel_id="kernel1", session_id="session1")
r_update = KernelSessionRecord(kernel_id="kernel1", session_id="session1")
records.update(r_update)
assert len(records) == 1
assert "session1" in records

r2 = KernelRecord(kernel_id="kernel2")
r2 = KernelSessionRecord(kernel_id="kernel2")
records.update(r2)
assert r2 in records
assert len(records) == 2
Expand Down Expand Up @@ -482,23 +482,23 @@ async def test_pending_kernel():
)
task = asyncio.create_task(fut)
await asyncio.sleep(0.1)
assert len(session_manager._pending_kernels) == 1
assert len(session_manager._pending_sessions) == 1
# Get a handle on the record
record = session_manager._pending_kernels._records[0]
record = session_manager._pending_sessions._records[0]
session = await task
# Check that record is cleared after the task has completed.
assert record not in session_manager._pending_kernels
assert record not in session_manager._pending_sessions

# Check pending kernel list when sessions are
fut = session_manager.delete_session(session_id=session["id"])
task = asyncio.create_task(fut)
await asyncio.sleep(0.1)
assert len(session_manager._pending_kernels) == 1
assert len(session_manager._pending_sessions) == 1
# Get a handle on the record
record = session_manager._pending_kernels._records[0]
record = session_manager._pending_sessions._records[0]
session = await task
# Check that record is cleared after the task has completed.
assert record not in session_manager._pending_kernels
assert record not in session_manager._pending_sessions

# Test multiple, parallel pending kernels
fut1 = session_manager.create_session(
Expand All @@ -511,9 +511,9 @@ async def test_pending_kernel():
await asyncio.sleep(0.1)
task2 = asyncio.create_task(fut2)
await asyncio.sleep(0.1)
assert len(session_manager._pending_kernels) == 2
assert len(session_manager._pending_sessions) == 2

await task1
await task2
session1, session2 = await asyncio.gather(task1, task2)
assert len(session_manager._pending_kernels) == 0
assert len(session_manager._pending_sessions) == 0

0 comments on commit 6267330

Please sign in to comment.