From 8141b93c59cab25a68554e7704b9e2ef027db541 Mon Sep 17 00:00:00 2001 From: iomintz Date: Sat, 18 Jul 2020 11:31:49 -0500 Subject: [PATCH] Add support for connection termination listeners (#525) The new `Connection.add_termination_listener()` method can be used to register callbacks to be invoked when a connection has been terminated. Co-authored-by: Elvis Pranskevichus --- asyncpg/_testbase/fuzzer.py | 28 ++++++++++++++-------- asyncpg/connection.py | 47 +++++++++++++++++++++++++++++++++++-- tests/test_listeners.py | 42 +++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 12 deletions(-) diff --git a/asyncpg/_testbase/fuzzer.py b/asyncpg/_testbase/fuzzer.py index 88f6e5c1..5c0b870c 100644 --- a/asyncpg/_testbase/fuzzer.py +++ b/asyncpg/_testbase/fuzzer.py @@ -145,6 +145,10 @@ def _close_connection(self, connection): if conn_task is not None: conn_task.cancel() + def close_all_connections(self): + for conn in list(self.connections): + self.loop.call_soon_threadsafe(self._close_connection, conn) + class Connection: def __init__(self, client_sock, backend_sock, proxy): @@ -215,10 +219,11 @@ async def _read(self, sock, n): else: return read_task.result() finally: - if not read_task.done(): - read_task.cancel() - if not conn_event_task.done(): - conn_event_task.cancel() + if not self.loop.is_closed(): + if not read_task.done(): + read_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() async def _write(self, sock, data): write_task = asyncio.ensure_future( @@ -236,10 +241,11 @@ async def _write(self, sock, data): else: return write_task.result() finally: - if not write_task.done(): - write_task.cancel() - if not conn_event_task.done(): - conn_event_task.cancel() + if not self.loop.is_closed(): + if not write_task.done(): + write_task.cancel() + if not conn_event_task.done(): + conn_event_task.cancel() async def proxy_to_backend(self): buf = None @@ -264,7 +270,8 @@ async def proxy_to_backend(self): pass finally: - self.loop.call_soon(self.close) + if not self.loop.is_closed(): + self.loop.call_soon(self.close) async def proxy_from_backend(self): buf = None @@ -289,4 +296,5 @@ async def proxy_from_backend(self): pass finally: - self.loop.call_soon(self.close) + if not self.loop.is_closed(): + self.loop.call_soon(self.close) diff --git a/asyncpg/connection.py b/asyncpg/connection.py index a6941132..a78aafa7 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -46,8 +46,8 @@ class Connection(metaclass=ConnectionMeta): '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', '_stmt_exclusive_section', '_config', '_params', '_addr', - '_log_listeners', '_cancellations', '_source_traceback', - '__weakref__') + '_log_listeners', '_termination_listeners', '_cancellations', + '_source_traceback', '__weakref__') def __init__(self, protocol, transport, loop, addr: (str, int) or str, @@ -78,6 +78,7 @@ def __init__(self, protocol, transport, loop, self._listeners = {} self._log_listeners = set() self._cancellations = set() + self._termination_listeners = set() settings = self._protocol.get_settings() ver_string = settings.server_version @@ -178,6 +179,28 @@ def remove_log_listener(self, callback): """ self._log_listeners.discard(callback) + def add_termination_listener(self, callback): + """Add a listener that will be called when the connection is closed. + + :param callable callback: + A callable receiving one argument: + **connection**: a Connection the callback is registered with. + + .. versionadded:: 0.21.0 + """ + self._termination_listeners.add(callback) + + def remove_termination_listener(self, callback): + """Remove a listening callback for connection termination. + + :param callable callback: + The callable that was passed to + :meth:`Connection.add_termination_listener`. + + .. versionadded:: 0.21.0 + """ + self._termination_listeners.discard(callback) + def get_server_pid(self): """Return the PID of the Postgres server the connection is bound to.""" return self._protocol.get_server_pid() @@ -1120,6 +1143,7 @@ def _abort(self): self._protocol = None def _cleanup(self): + self._call_termination_listeners() # Free the resources associated with this connection. # This must be called when a connection is terminated. @@ -1237,6 +1261,25 @@ def _call_log_listener(self, cb, con_ref, message): 'exception': ex }) + def _call_termination_listeners(self): + if not self._termination_listeners: + return + + con_ref = self._unwrap() + for cb in self._termination_listeners: + try: + cb(con_ref) + except Exception as ex: + self._loop.call_exception_handler({ + 'message': ( + 'Unhandled exception in asyncpg connection ' + 'termination listener callback {!r}'.format(cb) + ), + 'exception': ex + }) + + self._termination_listeners.clear() + def _process_notification(self, pid, channel, payload): if channel not in self._listeners: return diff --git a/tests/test_listeners.py b/tests/test_listeners.py index 4879cd88..a4726e2d 100644 --- a/tests/test_listeners.py +++ b/tests/test_listeners.py @@ -6,6 +6,10 @@ import asyncio +import os +import platform +import sys +import unittest from asyncpg import _testbase as tb from asyncpg import exceptions @@ -272,3 +276,41 @@ def listener1(*args): pass con.add_log_listener(listener1) + + +@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing') +@unittest.skipIf( + platform.system() == 'Windows' and + sys.version_info >= (3, 8), + 'not compatible with ProactorEventLoop which is default in Python 3.8') +class TestConnectionTerminationListener(tb.ProxiedClusterTestCase): + + async def test_connection_termination_callback_called_on_remote(self): + + called = False + + def close_cb(con): + nonlocal called + called = True + + con = await self.connect() + con.add_termination_listener(close_cb) + self.proxy.close_all_connections() + try: + await con.fetchval('SELECT 1') + except Exception: + pass + self.assertTrue(called) + + async def test_connection_termination_callback_called_on_local(self): + + called = False + + def close_cb(con): + nonlocal called + called = True + + con = await self.connect() + con.add_termination_listener(close_cb) + await con.close() + self.assertTrue(called)