diff --git a/ipykernel/comm/comm.py b/ipykernel/comm/comm.py index 77f9f072f..6ec6ff3d3 100644 --- a/ipykernel/comm/comm.py +++ b/ipykernel/comm/comm.py @@ -70,13 +70,16 @@ def _default_comm_id(self): def __init__(self, target_name='', data=None, metadata=None, buffers=None, **kwargs): # Handle differing arguments between base classes. + had_kernel = 'kernel' in kwargs kernel = kwargs.pop('kernel', None) if target_name: kwargs['target_name'] = target_name BaseComm.__init__( self, data=data, metadata=metadata, buffers=buffers, **kwargs ) # type:ignore[call-arg] - kwargs['kernel'] = kernel + # only re-add kernel if explicitly provided + if had_kernel: + kwargs['kernel'] = kernel traitlets.config.LoggingConfigurable.__init__(self, **kwargs) diff --git a/ipykernel/tests/test_comm.py b/ipykernel/tests/test_comm.py index 405841f0c..5ec0d455a 100644 --- a/ipykernel/tests/test_comm.py +++ b/ipykernel/tests/test_comm.py @@ -2,15 +2,18 @@ from ipykernel.comm import Comm, CommManager from ipykernel.ipkernel import IPythonKernel +from ipykernel.kernelbase import Kernel -def test_comm(kernel): +def test_comm(kernel: Kernel) -> None: manager = CommManager(kernel=kernel) - kernel.comm_manager = manager + kernel.comm_manager = manager # type:ignore c = Comm(kernel=kernel, target_name="bar") msgs = [] + assert kernel is c.kernel # type:ignore + def on_close(msg): msgs.append(msg) @@ -28,7 +31,7 @@ def on_message(msg): assert c.target_name == "bar" -def test_comm_manager(kernel): +def test_comm_manager(kernel: Kernel) -> None: manager = CommManager(kernel=kernel) msgs = [] @@ -48,7 +51,7 @@ def on_msg(msg): manager.register_target("foo", foo) manager.register_target("fizz", fizz) - kernel.comm_manager = manager + kernel.comm_manager = manager # type:ignore with unittest.mock.patch.object(Comm, "publish_msg") as publish_msg: comm = Comm() comm.on_msg(on_msg) @@ -56,6 +59,11 @@ def on_msg(msg): manager.register_comm(comm) assert publish_msg.call_count == 1 + # make sure that when we don't pass a kernel, the 'default' kernel is taken + Kernel._instance = kernel # type:ignore + assert comm.kernel is kernel # type:ignore + Kernel.clear_instance() + assert manager.get_comm(comm.comm_id) == comm assert manager.get_comm('foo') is None