Skip to content

Commit

Permalink
Add query logging callbacks and context manager (#1043)
Browse files Browse the repository at this point in the history
  • Loading branch information
dcwatson committed Oct 9, 2023
1 parent 93a6f79 commit b2697ff
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 10 deletions.
128 changes: 118 additions & 10 deletions asyncpg/connection.py
Expand Up @@ -9,6 +9,7 @@
import asyncpg
import collections
import collections.abc
import contextlib
import functools
import itertools
import inspect
Expand Down Expand Up @@ -53,7 +54,7 @@ class Connection(metaclass=ConnectionMeta):
'_intro_query', '_reset_query', '_proxy',
'_stmt_exclusive_section', '_config', '_params', '_addr',
'_log_listeners', '_termination_listeners', '_cancellations',
'_source_traceback', '__weakref__')
'_source_traceback', '_query_loggers', '__weakref__')

def __init__(self, protocol, transport, loop,
addr,
Expand Down Expand Up @@ -87,6 +88,7 @@ def __init__(self, protocol, transport, loop,
self._log_listeners = set()
self._cancellations = set()
self._termination_listeners = set()
self._query_loggers = set()

settings = self._protocol.get_settings()
ver_string = settings.server_version
Expand Down Expand Up @@ -224,6 +226,30 @@ def remove_termination_listener(self, callback):
"""
self._termination_listeners.discard(_Callback.from_callable(callback))

def add_query_logger(self, callback):
"""Add a logger that will be called when queries are executed.
:param callable callback:
A callable or a coroutine function receiving one argument:
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and
`conn_params`.
.. versionadded:: 0.29.0
"""
self._query_loggers.add(_Callback.from_callable(callback))

def remove_query_logger(self, callback):
"""Remove a query logger callback.
:param callable callback:
The callable or coroutine function that was passed to
:meth:`Connection.add_query_logger`.
.. versionadded:: 0.29.0
"""
self._query_loggers.discard(_Callback.from_callable(callback))

def get_server_pid(self):
"""Return the PID of the Postgres server the connection is bound to."""
return self._protocol.get_server_pid()
Expand Down Expand Up @@ -317,7 +343,12 @@ async def execute(self, query: str, *args, timeout: float=None) -> str:
self._check_open()

if not args:
return await self._protocol.query(query, timeout)
if self._query_loggers:
with self._time_and_log(query, args, timeout):
result = await self._protocol.query(query, timeout)
else:
result = await self._protocol.query(query, timeout)
return result

_, status, _ = await self._execute(
query,
Expand Down Expand Up @@ -1487,6 +1518,7 @@ def _cleanup(self):
self._mark_stmts_as_closed()
self._listeners.clear()
self._log_listeners.clear()
self._query_loggers.clear()
self._clean_tasks()

def _clean_tasks(self):
Expand Down Expand Up @@ -1770,6 +1802,63 @@ async def _execute(
)
return result

@contextlib.contextmanager
def query_logger(self, callback):
"""Context manager that adds `callback` to the list of query loggers,
and removes it upon exit.
:param callable callback:
A callable or a coroutine function receiving one argument:
**record**: a LoggedQuery containing `query`, `args`, `timeout`,
`elapsed`, `exception`, `conn_addr`, and
`conn_params`.
Example:
.. code-block:: pycon
>>> class QuerySaver:
def __init__(self):
self.queries = []
def __call__(self, record):
self.queries.append(record.query)
>>> with con.query_logger(QuerySaver()):
>>> await con.execute("SELECT 1")
>>> print(log.queries)
['SELECT 1']
.. versionadded:: 0.29.0
"""
self.add_query_logger(callback)
yield
self.remove_query_logger(callback)

@contextlib.contextmanager
def _time_and_log(self, query, args, timeout):
start = time.monotonic()
exception = None
try:
yield
except BaseException as ex:
exception = ex
raise
finally:
elapsed = time.monotonic() - start
record = LoggedQuery(
query=query,
args=args,
timeout=timeout,
elapsed=elapsed,
exception=exception,
conn_addr=self._addr,
conn_params=self._params,
)
for cb in self._query_loggers:
if cb.is_async:
self._loop.create_task(cb.cb(record))
else:
self._loop.call_soon(cb.cb, record)

async def __execute(
self,
query,
Expand All @@ -1790,13 +1879,24 @@ async def __execute(
timeout=timeout,
)
timeout = self._protocol._get_timeout(timeout)
return await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
if self._query_loggers:
with self._time_and_log(query, args, timeout):
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
else:
result, stmt = await self._do_execute(
query,
executor,
timeout,
record_class=record_class,
ignore_custom_codec=ignore_custom_codec,
)
return result, stmt

async def _executemany(self, query, args, timeout):
executor = lambda stmt, timeout: self._protocol.bind_execute_many(
Expand All @@ -1807,7 +1907,8 @@ async def _executemany(self, query, args, timeout):
)
timeout = self._protocol._get_timeout(timeout)
with self._stmt_exclusive_section:
result, _ = await self._do_execute(query, executor, timeout)
with self._time_and_log(query, args, timeout):
result, _ = await self._do_execute(query, executor, timeout)
return result

async def _do_execute(
Expand Down Expand Up @@ -2440,6 +2541,13 @@ class _ConnectionProxy:
__slots__ = ()


LoggedQuery = collections.namedtuple(
'LoggedQuery',
['query', 'args', 'timeout', 'elapsed', 'exception', 'conn_addr',
'conn_params'])
LoggedQuery.__doc__ = 'Log record of an executed query.'


ServerCapabilities = collections.namedtuple(
'ServerCapabilities',
['advisory_locks', 'notifications', 'plpgsql', 'sql_reset',
Expand Down
51 changes: 51 additions & 0 deletions tests/test_logging.py
@@ -0,0 +1,51 @@
import asyncio

from asyncpg import _testbase as tb
from asyncpg import exceptions


class LogCollector:
def __init__(self):
self.records = []

def __call__(self, record):
self.records.append(record)


class TestQueryLogging(tb.ConnectedTestCase):

async def test_logging_context(self):
queries = asyncio.Queue()

def query_saver(record):
queries.put_nowait(record)

log = LogCollector()

with self.con.query_logger(query_saver):
self.assertEqual(len(self.con._query_loggers), 1)
await self.con.execute("SELECT 1")
with self.con.query_logger(log):
self.assertEqual(len(self.con._query_loggers), 2)
await self.con.execute("SELECT 2")

r1 = await queries.get()
r2 = await queries.get()
self.assertEqual(r1.query, "SELECT 1")
self.assertEqual(r2.query, "SELECT 2")
self.assertEqual(len(log.records), 1)
self.assertEqual(log.records[0].query, "SELECT 2")
self.assertEqual(len(self.con._query_loggers), 0)

async def test_error_logging(self):
log = LogCollector()
with self.con.query_logger(log):
with self.assertRaises(exceptions.UndefinedColumnError):
await self.con.execute("SELECT x")

await asyncio.sleep(0) # wait for logging
self.assertEqual(len(log.records), 1)
self.assertEqual(
type(log.records[0].exception),
exceptions.UndefinedColumnError
)

0 comments on commit b2697ff

Please sign in to comment.