diff --git a/janus/__init__.py b/janus/__init__.py index 4c064a5..d2aea67 100644 --- a/janus/__init__.py +++ b/janus/__init__.py @@ -69,6 +69,8 @@ def close(self) -> None: self._closing = True for fut in self._pending: fut.cancel() + self._finished.set() # unblocks all async_q.join() + self._all_tasks_done.notify_all() # unblocks all sync_q.join() async def wait_closed(self) -> None: # should be called from loop after close(). @@ -172,7 +174,7 @@ def task_maker() -> None: def _check_closing(self) -> None: if self._closing: - raise RuntimeError("Modification of closed queue is forbidden") + raise RuntimeError("Operation on the closed queue is forbidden") class _SyncQueueProxy(Generic[T]): @@ -225,9 +227,11 @@ def join(self) -> None: When the count of unfinished tasks drops to zero, join() unblocks. """ + self._parent._check_closing() with self._parent._all_tasks_done: while self._parent._unfinished_tasks: self._parent._all_tasks_done.wait() + self._parent._check_closing() def qsize(self) -> int: """Return the approximate size of the queue (not reliable!).""" @@ -513,6 +517,7 @@ async def join(self) -> None: """ while True: with self._parent._sync_mutex: + self._parent._check_closing() if self._parent._unfinished_tasks == 0: break await self._parent._finished.wait() diff --git a/tests/test_mixed.py b/tests/test_mixed.py index 35c2d21..d98c884 100644 --- a/tests/test_mixed.py +++ b/tests/test_mixed.py @@ -1,5 +1,7 @@ import asyncio +import contextlib import sys +import threading import pytest @@ -231,3 +233,75 @@ async def test_closed(self): assert q.closed assert q.async_q.closed assert q.sync_q.closed + + @pytest.mark.asyncio + async def test_async_join_after_closing(self): + q = janus.Queue() + q.close() + with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(q.async_q.join(), timeout=0.1) + + await q.wait_closed() + + @pytest.mark.asyncio + async def test_close_after_async_join(self): + q = janus.Queue() + q.sync_q.put(1) + + task = asyncio.ensure_future(q.async_q.join()) + await asyncio.sleep(0.1) # ensure tasks are blocking + + q.close() + with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(task, timeout=0.1) + + await q.wait_closed() + + @pytest.mark.asyncio + async def test_sync_join_after_closing(self): + q = janus.Queue() + q.sync_q.put(1) + + q.close() + + loop = asyncio.get_event_loop() + fut = asyncio.Future() + + def sync_join(): + try: + q.sync_q.join() + except Exception as exc: + loop.call_soon_threadsafe(fut.set_exception, exc) + + thr = threading.Thread(target=sync_join, daemon=True) + thr.start() + + with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(fut, timeout=0.1) + + await q.wait_closed() + + @pytest.mark.asyncio + async def test_close_after_sync_join(self): + q = janus.Queue() + q.sync_q.put(1) + + loop = asyncio.get_event_loop() + fut = asyncio.Future() + + def sync_join(): + try: + q.sync_q.join() + except Exception as exc: + loop.call_soon_threadsafe(fut.set_exception, exc) + + thr = threading.Thread(target=sync_join, daemon=True) + thr.start() + thr.join(0.1) # ensure tasks are blocking + + q.close() + + with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError): + await asyncio.wait_for(fut, timeout=0.1) + + await q.wait_closed()