Skip to content

Commit

Permalink
Add support for connection termination listeners (#525)
Browse files Browse the repository at this point in the history
The new `Connection.add_termination_listener()` method can be used
to register callbacks to be invoked when a connection has been terminated.

Co-authored-by: Elvis Pranskevichus <elvis@magic.io>
  • Loading branch information
ioistired and elprans committed Jul 18, 2020
1 parent b081320 commit 8141b93
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 12 deletions.
28 changes: 18 additions & 10 deletions asyncpg/_testbase/fuzzer.py
Expand Up @@ -145,6 +145,10 @@ def _close_connection(self, connection):
if conn_task is not None:
conn_task.cancel()

def close_all_connections(self):
for conn in list(self.connections):
self.loop.call_soon_threadsafe(self._close_connection, conn)


class Connection:
def __init__(self, client_sock, backend_sock, proxy):
Expand Down Expand Up @@ -215,10 +219,11 @@ async def _read(self, sock, n):
else:
return read_task.result()
finally:
if not read_task.done():
read_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
if not self.loop.is_closed():
if not read_task.done():
read_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()

async def _write(self, sock, data):
write_task = asyncio.ensure_future(
Expand All @@ -236,10 +241,11 @@ async def _write(self, sock, data):
else:
return write_task.result()
finally:
if not write_task.done():
write_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
if not self.loop.is_closed():
if not write_task.done():
write_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()

async def proxy_to_backend(self):
buf = None
Expand All @@ -264,7 +270,8 @@ async def proxy_to_backend(self):
pass

finally:
self.loop.call_soon(self.close)
if not self.loop.is_closed():
self.loop.call_soon(self.close)

async def proxy_from_backend(self):
buf = None
Expand All @@ -289,4 +296,5 @@ async def proxy_from_backend(self):
pass

finally:
self.loop.call_soon(self.close)
if not self.loop.is_closed():
self.loop.call_soon(self.close)
47 changes: 45 additions & 2 deletions asyncpg/connection.py
Expand Up @@ -46,8 +46,8 @@ class Connection(metaclass=ConnectionMeta):
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_cancellations', '_source_traceback',
'__weakref__')
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')

def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(self, protocol, transport, loop,
self._listeners = {}
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -178,6 +179,28 @@ def remove_log_listener(self, callback):
"""
self._log_listeners.discard(callback)

def add_termination_listener(self, callback):
"""Add a listener that will be called when the connection is closed.
:param callable callback:
A callable receiving one argument:
**connection**: a Connection the callback is registered with.
.. versionadded:: 0.21.0
"""
self._termination_listeners.add(callback)

def remove_termination_listener(self, callback):
"""Remove a listening callback for connection termination.
:param callable callback:
The callable that was passed to
:meth:`Connection.add_termination_listener`.
.. versionadded:: 0.21.0
"""
self._termination_listeners.discard(callback)

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -1120,6 +1143,7 @@ def _abort(self):
self._protocol = None

def _cleanup(self):
self._call_termination_listeners()
# Free the resources associated with this connection.
# This must be called when a connection is terminated.

Expand Down Expand Up @@ -1237,6 +1261,25 @@ def _call_log_listener(self, cb, con_ref, message):
'exception': ex
})

def _call_termination_listeners(self):
if not self._termination_listeners:
return

con_ref = self._unwrap()
for cb in self._termination_listeners:
try:
cb(con_ref)
except Exception as ex:
self._loop.call_exception_handler({
'message': (
'Unhandled exception in asyncpg connection '
'termination listener callback {!r}'.format(cb)
),
'exception': ex
})

self._termination_listeners.clear()

def _process_notification(self, pid, channel, payload):
if channel not in self._listeners:
return
Expand Down
42 changes: 42 additions & 0 deletions tests/test_listeners.py
Expand Up @@ -6,6 +6,10 @@


import asyncio
import os
import platform
import sys
import unittest

from asyncpg import _testbase as tb
from asyncpg import exceptions
Expand Down Expand Up @@ -272,3 +276,41 @@ def listener1(*args):
pass

con.add_log_listener(listener1)


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
@unittest.skipIf(
platform.system() == 'Windows' and
sys.version_info >= (3, 8),
'not compatible with ProactorEventLoop which is default in Python 3.8')
class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):

async def test_connection_termination_callback_called_on_remote(self):

called = False

def close_cb(con):
nonlocal called
called = True

con = await self.connect()
con.add_termination_listener(close_cb)
self.proxy.close_all_connections()
try:
await con.fetchval('SELECT 1')
except Exception:
pass
self.assertTrue(called)

async def test_connection_termination_callback_called_on_local(self):

called = False

def close_cb(con):
nonlocal called
called = True

con = await self.connect()
con.add_termination_listener(close_cb)
await con.close()
self.assertTrue(called)

0 comments on commit 8141b93

Please sign in to comment.