diff --git a/asyncpg/_testbase/fuzzer.py b/asyncpg/_testbase/fuzzer.py index 88f6e5c1..5c0b870c 100644 --- a/asyncpg/_testbase/fuzzer.py +++ b/asyncpg/_testbase/fuzzer.py @@ -145,6 +145,10 @@ def _close_connection(self, connection): if conn_task is not None: conn_task.cancel() + def close_all_connections(self): + for conn in list(self.connections): + self.loop.call_soon_threadsafe(self._close_connection, conn) + class Connection: def __init__(self, client_sock, backend_sock, proxy): @@ -215,10 +219,11 @@ async def _read(self, sock, n): else: return read_task.result() finally: - if not read_task.done(): - read_task.cancel() - if not conn_event_task.done(): - conn_event_task.cancel() + if not self.loop.is_closed(): + if not read_task.done(): + read_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() async def _write(self, sock, data): write_task = asyncio.ensure_future( @@ -236,10 +241,11 @@ async def _write(self, sock, data): else: return write_task.result() finally: - if not write_task.done(): - write_task.cancel() - if not conn_event_task.done(): - conn_event_task.cancel() + if not self.loop.is_closed(): + if not write_task.done(): + write_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() async def proxy_to_backend(self): buf = None @@ -264,7 +270,8 @@ async def proxy_to_backend(self): pass finally: - self.loop.call_soon(self.close) + if not self.loop.is_closed(): + self.loop.call_soon(self.close) async def proxy_from_backend(self): buf = None @@ -289,4 +296,5 @@ async def proxy_from_backend(self): pass finally: - self.loop.call_soon(self.close) + if not self.loop.is_closed(): + self.loop.call_soon(self.close) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 3511b633..f311d1e3 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -185,11 +185,16 @@ def add_close_listener(self, callback): :param callable callback: A callable receiving one argument: **connection**: a Connection the callback is registered with. + + .. versionadded:: 0.21.0 """ self._close_listeners.add(callback) def remove_close_listener(self, callback): - """Remove a listening callback for the connection closing.""" + """Remove a listening callback for the connection closing. + + .. versionadded:: 0.21.0 + """ self._close_listeners.discard(callback) def get_server_pid(self): diff --git a/tests/test_listeners.py b/tests/test_listeners.py index 4879cd88..ebf698f3 100644 --- a/tests/test_listeners.py +++ b/tests/test_listeners.py @@ -6,6 +6,10 @@ import asyncio +import os +import platform +import sys +import unittest from asyncpg import _testbase as tb from asyncpg import exceptions @@ -272,3 +276,41 @@ def listener1(*args): pass con.add_log_listener(listener1) + + +@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') +@unittest.skipIf( + platform.system() == 'Windows' and + sys.version_info >= (3, 8), + 'not compatible with ProactorEventLoop which is default in Python 3.8') +class TestConnectionCloseListener(tb.ProxiedClusterTestCase): + + async def test_connection_close_callback_called_on_remote(self): + + called = False + + def close_cb(con): + nonlocal called + called = True + + con = await self.connect() + con.add_close_listener(close_cb) + self.proxy.close_all_connections() + try: + await con.fetchval('SELECT 1') + except Exception: + pass + self.assertTrue(called) + + async def test_connection_close_callback_called_on_local(self): + + called = False + + def close_cb(con): + nonlocal called + called = True + + con = await self.connect() + con.add_close_listener(close_cb) + await con.close() + self.assertTrue(called)