Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for connection closed listeners #525

Merged
merged 3 commits into from Jul 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 18 additions & 10 deletions asyncpg/_testbase/fuzzer.py
Expand Up @@ -145,6 +145,10 @@ def _close_connection(self, connection):
if conn_task is not None:
conn_task.cancel()

def close_all_connections(self):
for conn in list(self.connections):
self.loop.call_soon_threadsafe(self._close_connection, conn)


class Connection:
def __init__(self, client_sock, backend_sock, proxy):
Expand Down Expand Up @@ -215,10 +219,11 @@ async def _read(self, sock, n):
else:
return read_task.result()
finally:
if not read_task.done():
read_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
if not self.loop.is_closed():
if not read_task.done():
read_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()

async def _write(self, sock, data):
write_task = asyncio.ensure_future(
Expand All @@ -236,10 +241,11 @@ async def _write(self, sock, data):
else:
return write_task.result()
finally:
if not write_task.done():
write_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()
if not self.loop.is_closed():
if not write_task.done():
write_task.cancel()
if not conn_event_task.done():
conn_event_task.cancel()

async def proxy_to_backend(self):
buf = None
Expand All @@ -264,7 +270,8 @@ async def proxy_to_backend(self):
pass

finally:
self.loop.call_soon(self.close)
if not self.loop.is_closed():
self.loop.call_soon(self.close)

async def proxy_from_backend(self):
buf = None
Expand All @@ -289,4 +296,5 @@ async def proxy_from_backend(self):
pass

finally:
self.loop.call_soon(self.close)
if not self.loop.is_closed():
self.loop.call_soon(self.close)
47 changes: 45 additions & 2 deletions asyncpg/connection.py
Expand Up @@ -46,8 +46,8 @@ class Connection(metaclass=ConnectionMeta):
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_cancellations', '_source_traceback',
'__weakref__')
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')

def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
Expand Down Expand Up @@ -78,6 +78,7 @@ def __init__(self, protocol, transport, loop,
self._listeners = {}
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -178,6 +179,28 @@ def remove_log_listener(self, callback):
"""
self._log_listeners.discard(callback)

def add_termination_listener(self, callback):
"""Add a listener that will be called when the connection is closed.

:param callable callback:
A callable receiving one argument:
**connection**: a Connection the callback is registered with.

.. versionadded:: 0.21.0
"""
self._termination_listeners.add(callback)

def remove_termination_listener(self, callback):
"""Remove a listening callback for connection termination.

:param callable callback:
The callable that was passed to
:meth:`Connection.add_termination_listener`.

.. versionadded:: 0.21.0
"""
self._termination_listeners.discard(callback)

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -1120,6 +1143,7 @@ def _abort(self):
self._protocol = None

def _cleanup(self):
self._call_termination_listeners()
# Free the resources associated with this connection.
# This must be called when a connection is terminated.

Expand Down Expand Up @@ -1237,6 +1261,25 @@ def _call_log_listener(self, cb, con_ref, message):
'exception': ex
})

def _call_termination_listeners(self):
if not self._termination_listeners:
return

con_ref = self._unwrap()
for cb in self._termination_listeners:
try:
cb(con_ref)
except Exception as ex:
self._loop.call_exception_handler({
'message': (
'Unhandled exception in asyncpg connection '
'termination listener callback {!r}'.format(cb)
),
'exception': ex
})

self._termination_listeners.clear()

def _process_notification(self, pid, channel, payload):
if channel not in self._listeners:
return
Expand Down
42 changes: 42 additions & 0 deletions tests/test_listeners.py
Expand Up @@ -6,6 +6,10 @@


import asyncio
import os
import platform
import sys
import unittest

from asyncpg import _testbase as tb
from asyncpg import exceptions
Expand Down Expand Up @@ -272,3 +276,41 @@ def listener1(*args):
pass

con.add_log_listener(listener1)


@unittest.skipIf(os.environ.get('PGHOST'), 'using remote cluster for testing')
@unittest.skipIf(
platform.system() == 'Windows' and
sys.version_info >= (3, 8),
'not compatible with ProactorEventLoop which is default in Python 3.8')
class TestConnectionTerminationListener(tb.ProxiedClusterTestCase):

async def test_connection_termination_callback_called_on_remote(self):

called = False

def close_cb(con):
nonlocal called
called = True

con = await self.connect()
con.add_termination_listener(close_cb)
self.proxy.close_all_connections()
try:
await con.fetchval('SELECT 1')
except Exception:
pass
self.assertTrue(called)

async def test_connection_termination_callback_called_on_local(self):

called = False

def close_cb(con):
nonlocal called
called = True

con = await self.connect()
con.add_termination_listener(close_cb)
await con.close()
self.assertTrue(called)