from dataclasses import dataclass, fields
from typing import Union


class KernelSessionRecordConflict(Exception):
    """Exception class to use when two KernelSessionRecords cannot
    merge because of conflicting data.
    """


@dataclass
class KernelSessionRecord:
    """A record object for tracking a Jupyter Server Kernel Session.

    Two records that share a session_id must also share a kernel_id, while
    a kernel can have multiple sessions (and thereby multiple session_ids)
    associated with it.
    """

    session_id: Union[None, str] = None
    kernel_id: Union[None, str] = None

    def __eq__(self, other: object) -> bool:
        """Compare two records by kernel_id and session_id.

        Records are equal when they share a non-empty kernel_id, or when
        they share a session_id and at least one of them has no kernel_id
        yet.

        Raises
        ------
        KernelSessionRecordConflict
            If both records carry the same non-empty session_id but
            different kernel_ids.
        """
        if not isinstance(other, KernelSessionRecord):
            # Let Python try the reflected comparison; `==` still evaluates
            # to False when the other operand does not know this type.
            return NotImplemented

        same_kernel = bool(self.kernel_id) and self.kernel_id == other.kernel_id
        same_session = self.session_id == other.session_id
        kernel_unknown = self.kernel_id is None or other.kernel_id is None
        if same_kernel or (same_session and kernel_unknown):
            return True

        # If two records share session_id but have different kernels, this is
        # an ill-posed expression. This should never be true. Raise an exception
        # to inform the user.
        if self.session_id and same_session and self.kernel_id != other.kernel_id:
            raise KernelSessionRecordConflict(
                "A single session_id can only have one kernel_id "
                "associated with. These two KernelSessionRecords share the same "
                "session_id but have different kernel_ids. This should "
                "not be possible and is likely an issue with the session "
                "records."
            )
        return False

    def update(self, other: "KernelSessionRecord") -> None:
        """Update this record in place from ``other`` (positive updates only).

        Only truthy fields of ``other`` are copied, so an unset field on
        ``other`` never erases a value already stored on this record.

        Raises
        ------
        TypeError
            If ``other`` is not a KernelSessionRecord.
        KernelSessionRecordConflict
            If both records have a kernel_id and the two differ.
        """
        if not isinstance(other, KernelSessionRecord):
            raise TypeError("'other' must be an instance of KernelSessionRecord.")

        if other.kernel_id and self.kernel_id and other.kernel_id != self.kernel_id:
            raise KernelSessionRecordConflict(
                "Could not update the record from 'other' because the two records conflict."
            )

        for field in fields(self):
            value = getattr(other, field.name, None)
            if value:
                setattr(self, field.name, value)


class KernelSessionRecordList:
    """An object for storing and managing a list of KernelSessionRecords.

    When adding a record to the list, the KernelSessionRecordList
    first checks if the record already exists in the list. If it does,
    the record will be updated with the new information; otherwise,
    it will be appended.
    """

    def __init__(self, *records):
        self._records = []
        for record in records:
            self.update(record)

    def __str__(self):
        return str(self._records)

    def __contains__(self, record: Union[KernelSessionRecord, str]):
        """Search for records by kernel_id and session_id"""
        if isinstance(record, KernelSessionRecord):
            return record in self._records
        if isinstance(record, str):
            return any(record in (r.session_id, r.kernel_id) for r in self._records)
        return False

    def __len__(self):
        return len(self._records)

    def get(self, record: Union[KernelSessionRecord, str]) -> KernelSessionRecord:
        """Return a full KernelSessionRecord from a session_id, kernel_id, or
        incomplete KernelSessionRecord.

        Raises
        ------
        ValueError
            If no stored record matches ``record``.
        """
        if isinstance(record, str):
            for stored in self._records:
                if record in (stored.kernel_id, stored.session_id):
                    return stored
        elif isinstance(record, KernelSessionRecord):
            for stored in self._records:
                if record == stored:
                    # Hand back the stored record, not the (possibly
                    # incomplete) one the caller passed in.
                    return stored
        raise ValueError(f"{record} not found in KernelSessionRecordList.")

    def update(self, record: KernelSessionRecord) -> None:
        """Update a record in-place or append it if not in the list."""
        # Only the lookup is expected to raise ValueError; keep the merge
        # outside the try so its own errors are never mistaken for a miss.
        try:
            idx = self._records.index(record)
        except ValueError:
            self._records.append(record)
        else:
            self._records[idx].update(record)

    def remove(self, record: KernelSessionRecord) -> None:
        """Remove a record if it's found in the list. If it's not found,
        do nothing.
        """
        try:
            self._records.remove(record)
        except ValueError:
            # Not present: removal is a deliberate no-op.
            pass
KernelSessionRecord(session_id=session_id) + self._pending_sessions.update(record) session = await self.get_session(session_id=session_id) await ensure_async(self.kernel_manager.shutdown_kernel(session["kernel"]["id"])) self.cursor.execute("DELETE FROM session WHERE session_id=?", (session_id,)) + self._pending_sessions.remove(record) diff --git a/tests/services/sessions/test_manager.py b/tests/services/sessions/test_manager.py index 7191cfae4c..bcd65af666 100644 --- a/tests/services/sessions/test_manager.py +++ b/tests/services/sessions/test_manager.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from tornado import web from traitlets import TraitError @@ -5,7 +7,12 @@ from jupyter_server._tz import isoformat, utcnow from jupyter_server.services.contents.manager import ContentsManager from jupyter_server.services.kernels.kernelmanager import MappingKernelManager -from jupyter_server.services.sessions.sessionmanager import SessionManager +from jupyter_server.services.sessions.sessionmanager import ( + KernelSessionRecord, + KernelSessionRecordConflict, + KernelSessionRecordList, + SessionManager, +) class DummyKernel(object): @@ -17,11 +24,11 @@ def __init__(self, kernel_name="python"): dummy_date_s = isoformat(dummy_date) -class DummyMKM(MappingKernelManager): +class MockMKM(MappingKernelManager): """MappingKernelManager interface that doesn't start kernels, for testing""" def __init__(self, *args, **kwargs): - super(DummyMKM, self).__init__(*args, **kwargs) + super(MockMKM, self).__init__(*args, **kwargs) self.id_letters = iter("ABCDEFGHIJK") def _new_id(self): @@ -39,9 +46,111 @@ async def shutdown_kernel(self, kernel_id, now=False): del self._kernels[kernel_id] +class SlowStartingKernelsMKM(MockMKM): + async def start_kernel(self, kernel_id=None, path=None, kernel_name="python", **kwargs): + await asyncio.sleep(1.0) + return await super().start_kernel( + kernel_id=kernel_id, path=path, kernel_name=kernel_name, **kwargs + ) + + async def shutdown_kernel(self, 
class SlowStartingKernelsMKM(MockMKM):
    """MockMKM whose kernel start and shutdown are artificially delayed.

    Each call sleeps for the configured delay before delegating to the
    parent implementation, which gives tests a window in which a session
    is still pending.
    """

    # Seconds to sleep before delegating to MockMKM. Subclasses or tests may
    # override these; the defaults match the original hard-coded 1.0 s.
    start_delay = 1.0
    shutdown_delay = 1.0

    async def start_kernel(self, kernel_id=None, path=None, kernel_name="python", **kwargs):
        await asyncio.sleep(self.start_delay)
        return await super().start_kernel(
            kernel_id=kernel_id, path=path, kernel_name=kernel_name, **kwargs
        )

    async def shutdown_kernel(self, kernel_id, now=False):
        await asyncio.sleep(self.shutdown_delay)
        await super().shutdown_kernel(kernel_id, now=now)


@pytest.fixture
def session_manager():
    """Provide a SessionManager backed by the non-starting MockMKM."""
    return SessionManager(kernel_manager=MockMKM(), contents_manager=ContentsManager())
def test_kernel_record_equals():
    """Check KernelSessionRecord equality, inequality and the conflict error."""
    record1 = KernelSessionRecord(session_id="session1")
    record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
    record3 = KernelSessionRecord(session_id="session2", kernel_id="kernel1")
    record4 = KernelSessionRecord(session_id="session1", kernel_id="kernel2")

    # Same session, one kernel still unknown.
    assert record1 == record2
    # Same kernel, different sessions.
    assert record2 == record3
    # Nothing in common.
    assert record3 != record4
    assert record1 != record3
    # A record never equals a bare identifier string.
    assert record1 != "session1"

    # Same session but different kernels is a conflict, not a plain mismatch.
    with pytest.raises(KernelSessionRecordConflict):
        assert record2 == record4


def test_kernel_record_update():
    """Check in-place merging of records, including the conflicting cases."""
    # Fill in a missing kernel_id from a record sharing the session_id.
    record1 = KernelSessionRecord(session_id="session1")
    record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
    record1.update(record2)
    assert record1.kernel_id == "kernel1"

    # Fill in a missing kernel_id from a kernel-only record.
    record1 = KernelSessionRecord(session_id="session1")
    record2 = KernelSessionRecord(kernel_id="kernel1")
    record1.update(record2)
    assert record1.kernel_id == "kernel1"

    # Fill in a missing session_id from a session-only record.
    record1 = KernelSessionRecord(kernel_id="kernel1")
    record2 = KernelSessionRecord(session_id="session1")
    record1.update(record2)
    assert record1.session_id == "session1"

    # Fill in a missing session_id from a record sharing the kernel_id.
    record1 = KernelSessionRecord(kernel_id="kernel1")
    record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
    record1.update(record2)
    assert record1.session_id == "session1"

    # Differing kernel_ids always conflict, whatever the session_ids are.
    record1 = KernelSessionRecord(kernel_id="kernel1")
    record2 = KernelSessionRecord(session_id="session1", kernel_id="kernel2")
    with pytest.raises(KernelSessionRecordConflict):
        record1.update(record2)

    record1 = KernelSessionRecord(kernel_id="kernel1", session_id="session1")
    record2 = KernelSessionRecord(kernel_id="kernel2")
    with pytest.raises(KernelSessionRecordConflict):
        record1.update(record2)

    record1 = KernelSessionRecord(kernel_id="kernel1", session_id="session1")
    record2 = KernelSessionRecord(kernel_id="kernel2", session_id="session1")
    with pytest.raises(KernelSessionRecordConflict):
        record1.update(record2)

    # Same kernel, new session: the session_id is overwritten.
    record1 = KernelSessionRecord(session_id="session1", kernel_id="kernel1")
    record2 = KernelSessionRecord(session_id="session2", kernel_id="kernel1")
    record1.update(record2)
    assert record1.session_id == "session2"


def test_kernel_record_list():
    """Check membership, lookup, merge-on-update and removal on the list."""
    records = KernelSessionRecordList()
    r = KernelSessionRecord(kernel_id="kernel1")
    records.update(r)
    assert r in records
    assert "kernel1" in records
    assert len(records) == 1

    # Test .get(): both lookups must hand back the stored object itself.
    assert records.get(r) is r
    assert records.get(r.kernel_id) is r

    with pytest.raises(ValueError):
        records.get("badkernel")

    # Updating with a matching record merges into the stored one.
    r_update = KernelSessionRecord(kernel_id="kernel1", session_id="session1")
    records.update(r_update)
    assert len(records) == 1
    assert "session1" in records
    assert records.get("session1") is r

    r2 = KernelSessionRecord(kernel_id="kernel2")
    records.update(r2)
    assert r2 in records
    assert len(records) == 2

    records.remove(r2)
    assert r2 not in records
    assert len(records) == 1

    # Removing a record that is not in the list is a no-op.
    records.remove(r2)
    assert len(records) == 1
async def test_pending_kernel():
    """Check that sessions are tracked as pending while their kernel is
    starting or shutting down, and cleared once the operation finishes.
    """
    session_manager = SessionManager(
        kernel_manager=SlowStartingKernelsMKM(), contents_manager=ContentsManager()
    )
    # Create a session with a slow starting kernel
    task = asyncio.create_task(
        session_manager.create_session(
            path="/path/to/test.ipynb", kernel_name="python", type="notebook"
        )
    )
    # Yield long enough for create_session to register its pending record.
    await asyncio.sleep(0.1)
    assert len(session_manager._pending_sessions) == 1
    # Get a handle on the record
    record = session_manager._pending_sessions._records[0]
    session = await task
    # Check that record is cleared after the task has completed.
    assert record not in session_manager._pending_sessions

    # Check the pending list while a session is being deleted
    task = asyncio.create_task(session_manager.delete_session(session_id=session["id"]))
    await asyncio.sleep(0.1)
    assert len(session_manager._pending_sessions) == 1
    # Get a handle on the record
    record = session_manager._pending_sessions._records[0]
    await task
    # Check that record is cleared after the task has completed.
    assert record not in session_manager._pending_sessions

    # Test multiple, parallel pending kernels
    task1 = asyncio.create_task(
        session_manager.create_session(
            path="/path/to/test.ipynb", kernel_name="python", type="notebook"
        )
    )
    await asyncio.sleep(0.1)
    task2 = asyncio.create_task(
        session_manager.create_session(
            path="/path/to/test.ipynb", kernel_name="python", type="notebook"
        )
    )
    await asyncio.sleep(0.1)
    assert len(session_manager._pending_sessions) == 2

    # Wait for both sessions; afterwards nothing may remain pending.
    await asyncio.gather(task1, task2)
    assert len(session_manager._pending_sessions) == 0