Skip to content

Commit

Permalink
add unit tests for pending kernels in sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
Zsailer committed Mar 23, 2022
1 parent 4abb13a commit 2b19487
Show file tree
Hide file tree
Showing 2 changed files with 236 additions and 10 deletions.
92 changes: 82 additions & 10 deletions jupyter_server/services/sessions/sessionmanager.py
Expand Up @@ -27,9 +27,15 @@
from dataclasses import fields


class KernelRecordConflict(Exception):
"""An exception raised when"""

pass


@dataclass
class KernelRecord:
"""A temporary record.
"""A record object for tracking a Jupyter Server Kernel Session.
Two records are equal if they share the
"""
Expand All @@ -39,39 +45,102 @@ class KernelRecord:

def __eq__(self, other: "KernelRecord") -> bool:
if isinstance(other, KernelRecord):
if any(
condition1 = self.kernel_id and self.kernel_id == other.kernel_id
condition2 = all(
[
# Check if the session_id matches
self.session_id and other.session_id and self.session_id == other.session_id,
# Check if the kernel_id matches.
self.kernel_id and other.kernel_id and self.kernel_id == other.kernel_id,
self.session_id == other.session_id,
self.kernel_id is None or other.kernel_id is None,
]
):
)
if any([condition1, condition2]):
return True
# If two records share session_id but have different kernels, this is
# and ill-posed expression. This should never be true. Raise an exception
# to inform the user.
if all(
[
self.session_id,
self.session_id == other.session_id,
self.kernel_id != other.kernel_id,
]
):
raise KernelRecordConflict(
"A single session_id can only have one kernel_id "
"associated with. These two KernelRecords share the same "
"session_id but have different kernel_ids. This should "
"not be possible and is likely an issue with the session "
"records."
)
return False

def update(self, other: "KernelRecord") -> None:
"""Updates in-place a kernel from other (only accepts positive updates"""
if not isinstance(other, KernelRecord):
raise TypeError("'other' must be an instance of KernelRecord.")

if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
raise KernelRecordConflict(
"Could not update the record from 'other' because the two records conflict."
)

for field in fields(self):
if hasattr(other, field.name) and getattr(other, field.name):
setattr(self, field.name, getattr(other, field.name))


class KernelRecordList:
"""Handy object for storing and managing a list of KernelRecords.
_records = []
When adding a record to the list, first checks if the record
already exists. If it does, the record will be updated with
the new information.
"""

def __init__(self, *records):
self._records = []
for record in records:
self.update(record)

def __str__(self):
return str(self._records)

def __contains__(self, record: Union[KernelRecord, str]):
"""Search for records by kernel_id and session_id"""
if isinstance(record, KernelRecord) and record in self._records:
return True

if isinstance(record, str):
for r in self._records:
if record in [r.session_id, r.kernel_id]:
return True
return False

def __len__(self):
return len(self._records)

def get(self, record: Union[KernelRecord, str]) -> KernelRecord:
if isinstance(record, str):
for r in self._records:
if record == r.kernel_id or record == r.session_id:
return r
elif isinstance(record, KernelRecord):
for r in self._records:
if record == r:
return record
raise ValueError(f"{record} not found in KernelRecordList.")

def update(self, record: KernelRecord) -> None:
"""Update a record in-place or append it if not in the list."""
try:
idx = self._records.index(record)
self._records[idx].update(record)
except ValueError:
self.append(record)
self._records.append(record)

def remove(self, record: KernelRecord) -> None:
"""Remove a record if its found in the list. If it's not found,
do nothing.
"""
if record in self._records:
self._records.remove(record)

Expand Down Expand Up @@ -116,7 +185,9 @@ def _validate_database_filepath(self, proposal):
]
)

_pending_kernels = KernelRecordList()
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._pending_kernels = KernelRecordList()

# Session database initialized below
_cursor = None
Expand Down Expand Up @@ -186,6 +257,7 @@ async def create_session(
kernel_id = await self.start_kernel_for_session(
session_id, path, name, type, kernel_name
)
record.kernel_id = kernel_id
self._pending_kernels.update(record)
result = await self.save_session(
session_id, path=path, name=name, type=type, kernel_id=kernel_id
Expand Down
154 changes: 154 additions & 0 deletions tests/services/sessions/test_manager.py
@@ -1,3 +1,5 @@
import asyncio

import pytest
from tornado import web
from traitlets import TraitError
Expand All @@ -6,6 +8,9 @@
from jupyter_server._tz import utcnow
from jupyter_server.services.contents.manager import ContentsManager
from jupyter_server.services.kernels.kernelmanager import MappingKernelManager
from jupyter_server.services.sessions.sessionmanager import KernelRecord
from jupyter_server.services.sessions.sessionmanager import KernelRecordConflict
from jupyter_server.services.sessions.sessionmanager import KernelRecordList
from jupyter_server.services.sessions.sessionmanager import SessionManager


Expand Down Expand Up @@ -40,11 +45,113 @@ async def shutdown_kernel(self, kernel_id, now=False):
del self._kernels[kernel_id]


class SlowDummyMKM(DummyMKM):
async def start_kernel(self, kernel_id=None, path=None, kernel_name="python", **kwargs):
await asyncio.sleep(1.0)
return await super().start_kernel(
kernel_id=kernel_id, path=path, kernel_name=kernel_name, **kwargs
)

async def shutdown_kernel(self, kernel_id, now=False):
await asyncio.sleep(1.0)
await super().shutdown_kernel(kernel_id, now=now)


@pytest.fixture
def session_manager():
return SessionManager(kernel_manager=DummyMKM(), contents_manager=ContentsManager())


def test_kernel_record_equals():
record1 = KernelRecord(session_id="session1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel1")
record3 = KernelRecord(session_id="session2", kernel_id="kernel1")
record4 = KernelRecord(session_id="session1", kernel_id="kernel2")

assert record1 == record2
assert record2 == record3
assert record3 != record4
assert record1 != record3
assert record3 != record4

with pytest.raises(KernelRecordConflict):
assert record2 == record4


def test_kernel_record_update():
record1 = KernelRecord(session_id="session1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel1")
record1.update(record2)
assert record1.kernel_id == "kernel1"

record1 = KernelRecord(session_id="session1")
record2 = KernelRecord(kernel_id="kernel1")
record1.update(record2)
assert record1.kernel_id == "kernel1"

record1 = KernelRecord(kernel_id="kernel1")
record2 = KernelRecord(session_id="session1")
record1.update(record2)
assert record1.session_id == "session1"

record1 = KernelRecord(kernel_id="kernel1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel1")
record1.update(record2)
assert record1.session_id == "session1"

record1 = KernelRecord(kernel_id="kernel1")
record2 = KernelRecord(session_id="session1", kernel_id="kernel2")
with pytest.raises(KernelRecordConflict):
record1.update(record2)

record1 = KernelRecord(kernel_id="kernel1", session_id="session1")
record2 = KernelRecord(kernel_id="kernel2")
with pytest.raises(KernelRecordConflict):
record1.update(record2)

record1 = KernelRecord(kernel_id="kernel1", session_id="session1")
record2 = KernelRecord(kernel_id="kernel2", session_id="session1")
with pytest.raises(KernelRecordConflict):
record1.update(record2)

record1 = KernelRecord(session_id="session1", kernel_id="kernel1")
record2 = KernelRecord(session_id="session2", kernel_id="kernel1")
record1.update(record2)
assert record1.session_id == "session2"


def test_kernel_record_list():
records = KernelRecordList()
r = KernelRecord(kernel_id="kernel1")
records.update(r)
assert r in records
assert "kernel1" in records
assert len(records) == 1

# Test .get()
r_ = records.get(r)
assert r == r_
r_ = records.get(r.kernel_id)
assert r == r_

with pytest.raises(ValueError):
records.get("badkernel")

r_update = KernelRecord(kernel_id="kernel1", session_id="session1")
records.update(r_update)
assert len(records) == 1
assert "session1" in records

r2 = KernelRecord(kernel_id="kernel2")
records.update(r2)
assert r2 in records
assert len(records) == 2

records.remove(r2)
assert r2 not in records
assert len(records) == 1


async def create_multiple_sessions(session_manager, *kwargs_list):
sessions = []
for kwargs in kwargs_list:
Expand Down Expand Up @@ -363,3 +470,50 @@ async def test_session_persistence(jp_runtime_dir):

# Assert that the session database persists.
session = await session_manager.get_session(session_id=session["id"])


async def test_pending_kernel():
session_manager = SessionManager(
kernel_manager=SlowDummyMKM(), contents_manager=ContentsManager()
)
# Create a session with a slow starting kernel
fut = session_manager.create_session(
path="/path/to/test.ipynb", kernel_name="python", type="notebook"
)
task = asyncio.create_task(fut)
await asyncio.sleep(0.1)
assert len(session_manager._pending_kernels) == 1
# Get a handle on the record
record = session_manager._pending_kernels._records[0]
session = await task
# Check that record is cleared after the task has completed.
assert record not in session_manager._pending_kernels

# Check pending kernel list when sessions are
fut = session_manager.delete_session(session_id=session["id"])
task = asyncio.create_task(fut)
await asyncio.sleep(0.1)
assert len(session_manager._pending_kernels) == 1
# Get a handle on the record
record = session_manager._pending_kernels._records[0]
session = await task
# Check that record is cleared after the task has completed.
assert record not in session_manager._pending_kernels

# Test multiple, parallel pending kernels
fut1 = session_manager.create_session(
path="/path/to/test.ipynb", kernel_name="python", type="notebook"
)
fut2 = session_manager.create_session(
path="/path/to/test.ipynb", kernel_name="python", type="notebook"
)
task1 = asyncio.create_task(fut1)
await asyncio.sleep(0.1)
task2 = asyncio.create_task(fut2)
await asyncio.sleep(0.1)
assert len(session_manager._pending_kernels) == 2

await task1
await task2
session1, session2 = await asyncio.gather(task1, task2)
assert len(session_manager._pending_kernels) == 0

0 comments on commit 2b19487

Please sign in to comment.