Skip to content

Commit

Permalink
Raise RuntimeError on queue.join() after queue closing. (#295)
Browse files Browse the repository at this point in the history
Co-authored-by: Andrew Svetlov <andrew.svetlov@gmail.com>
  • Loading branch information
linw1995 and asvetlov committed Oct 26, 2020
1 parent d8803b6 commit df9c4f8
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 1 deletion.
7 changes: 6 additions & 1 deletion janus/__init__.py
Expand Up @@ -69,6 +69,8 @@ def close(self) -> None:
self._closing = True
for fut in self._pending:
fut.cancel()
self._finished.set() # unblocks all async_q.join()
self._all_tasks_done.notify_all() # unblocks all sync_q.join()

async def wait_closed(self) -> None:
# should be called from loop after close().
Expand Down Expand Up @@ -172,7 +174,7 @@ def task_maker() -> None:

def _check_closing(self) -> None:
if self._closing:
raise RuntimeError("Modification of closed queue is forbidden")
raise RuntimeError("Operation on the closed queue is forbidden")


class _SyncQueueProxy(Generic[T]):
Expand Down Expand Up @@ -225,9 +227,11 @@ def join(self) -> None:
When the count of unfinished tasks drops to zero, join() unblocks.
"""
self._parent._check_closing()
with self._parent._all_tasks_done:
while self._parent._unfinished_tasks:
self._parent._all_tasks_done.wait()
self._parent._check_closing()

def qsize(self) -> int:
"""Return the approximate size of the queue (not reliable!)."""
Expand Down Expand Up @@ -513,6 +517,7 @@ async def join(self) -> None:
"""
while True:
with self._parent._sync_mutex:
self._parent._check_closing()
if self._parent._unfinished_tasks == 0:
break
await self._parent._finished.wait()
Expand Down
74 changes: 74 additions & 0 deletions tests/test_mixed.py
@@ -1,5 +1,7 @@
import asyncio
import contextlib
import sys
import threading

import pytest

Expand Down Expand Up @@ -231,3 +233,75 @@ async def test_closed(self):
assert q.closed
assert q.async_q.closed
assert q.sync_q.closed

@pytest.mark.asyncio
async def test_async_join_after_closing(self):
q = janus.Queue()
q.close()
with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(q.async_q.join(), timeout=0.1)

await q.wait_closed()

@pytest.mark.asyncio
async def test_close_after_async_join(self):
q = janus.Queue()
q.sync_q.put(1)

task = asyncio.ensure_future(q.async_q.join())
await asyncio.sleep(0.1) # ensure tasks are blocking

q.close()
with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(task, timeout=0.1)

await q.wait_closed()

@pytest.mark.asyncio
async def test_sync_join_after_closing(self):
q = janus.Queue()
q.sync_q.put(1)

q.close()

loop = asyncio.get_event_loop()
fut = asyncio.Future()

def sync_join():
try:
q.sync_q.join()
except Exception as exc:
loop.call_soon_threadsafe(fut.set_exception, exc)

thr = threading.Thread(target=sync_join, daemon=True)
thr.start()

with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(fut, timeout=0.1)

await q.wait_closed()

@pytest.mark.asyncio
async def test_close_after_sync_join(self):
q = janus.Queue()
q.sync_q.put(1)

loop = asyncio.get_event_loop()
fut = asyncio.Future()

def sync_join():
try:
q.sync_q.join()
except Exception as exc:
loop.call_soon_threadsafe(fut.set_exception, exc)

thr = threading.Thread(target=sync_join, daemon=True)
thr.start()
thr.join(0.1) # ensure tasks are blocking

q.close()

with pytest.raises(RuntimeError), contextlib.suppress(asyncio.TimeoutError):
await asyncio.wait_for(fut, timeout=0.1)

await q.wait_closed()

0 comments on commit df9c4f8

Please sign in to comment.