Skip to content

Commit

Permalink
Add middleware support
Browse files Browse the repository at this point in the history
  • Loading branch information
Nick Humrich authored and nhumrich committed Jan 11, 2020
1 parent 851d586 commit 2ff5dad
Show file tree
Hide file tree
Showing 6 changed files with 110 additions and 12 deletions.
3 changes: 2 additions & 1 deletion asyncpg/_testbase/__init__.py
Expand Up @@ -264,6 +264,7 @@ def create_pool(dsn=None, *,
setup=None,
init=None,
loop=None,
middlewares=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
**connect_kwargs):
Expand All @@ -272,7 +273,7 @@ def create_pool(dsn=None, *,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
connection_class=connection_class, middlewares=middlewares,
**connect_kwargs)


Expand Down
7 changes: 4 additions & 3 deletions asyncpg/connect_utils.py
Expand Up @@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,


async def _connect_addr(*, addr, loop, timeout, params, config,
connection_class):
middlewares, connection_class):
assert loop is not None

if timeout <= 0:
Expand Down Expand Up @@ -633,12 +633,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
tr.close()
raise

con = connection_class(pr, tr, loop, addr, config, params)
con = connection_class(pr, tr, loop, addr, config, params, middlewares)
pr.set_connection(con)
return con


async def _connect(*, loop, timeout, connection_class, **kwargs):
async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -652,6 +652,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
con = await _connect_addr(
addr=addr, loop=loop, timeout=timeout,
params=params, config=config,
middlewares=middlewares,
connection_class=connection_class)
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
Expand Down
22 changes: 17 additions & 5 deletions asyncpg/connection.py
Expand Up @@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
"""

__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_top_xact', '_aborted', '_middlewares',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
Expand All @@ -52,7 +52,8 @@ class Connection(metaclass=ConnectionMeta):
def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
config: connect_utils._ClientConfiguration,
params: connect_utils._ConnectionParameters):
params: connect_utils._ConnectionParameters,
_middlewares=None):
self._protocol = protocol
self._transport = transport
self._loop = loop
Expand Down Expand Up @@ -91,7 +92,7 @@ def __init__(self, protocol, transport, loop,

self._reset_query = None
self._proxy = None

self._middlewares = _middlewares
# Used to serialize operations that might involve anonymous
# statements. Specifically, we want to make the following
# operation atomic:
Expand Down Expand Up @@ -1399,8 +1400,13 @@ async def reload_schema_state(self):

async def _execute(self, query, args, limit, timeout, return_status=False):
with self._stmt_exclusive_section:
result, _ = await self.__execute(
query, args, limit, timeout, return_status=return_status)
wrapped = self.__execute
if self._middlewares:
for m in reversed(self._middlewares):
wrapped = await m(connection=self, handler=wrapped)

result, _ = await wrapped(query, args, limit,
timeout, return_status=return_status)
return result

async def __execute(self, query, args, limit, timeout,
Expand Down Expand Up @@ -1491,6 +1497,7 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
middlewares=None,
connection_class=Connection,
server_settings=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
Expand Down Expand Up @@ -1607,6 +1614,10 @@ async def connect(dsn=None, *,
PostgreSQL documentation for
a `list of supported options <server settings>`_.
:param middlewares:
An optional list of middleware functions. Refer to documentation
on create_pool.
:param Connection connection_class:
Class of the returned connection object. Must be a subclass of
:class:`~asyncpg.connection.Connection`.
Expand Down Expand Up @@ -1672,6 +1683,7 @@ async def connect(dsn=None, *,
ssl=ssl, database=database,
server_settings=server_settings,
command_timeout=command_timeout,
middlewares=middlewares,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size)
Expand Down
46 changes: 44 additions & 2 deletions asyncpg/pool.py
Expand Up @@ -305,7 +305,7 @@ class Pool:
"""

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_queue', '_loop', '_minsize', '_maxsize', '_middlewares',
'_init', '_connect_args', '_connect_kwargs',
'_working_addr', '_working_config', '_working_params',
'_holders', '_initialized', '_initializing', '_closing',
Expand All @@ -320,6 +320,7 @@ def __init__(self, *connect_args,
max_inactive_connection_lifetime,
setup,
init,
middlewares,
loop,
connection_class,
**connect_kwargs):
Expand Down Expand Up @@ -377,6 +378,7 @@ def __init__(self, *connect_args,
self._closed = False
self._generation = 0
self._init = init
self._middlewares = middlewares
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs

Expand Down Expand Up @@ -469,6 +471,7 @@ async def _get_new_connection(self):
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
middlewares=self._middlewares,
**self._connect_kwargs)

self._working_addr = con._addr
Expand All @@ -483,6 +486,7 @@ async def _get_new_connection(self):
addr=self._working_addr,
timeout=self._working_params.connect_timeout,
config=self._working_config,
middlewares=self._middlewares,
params=self._working_params,
connection_class=self._connection_class)

Expand Down Expand Up @@ -784,13 +788,37 @@ def __await__(self):
return self.pool._acquire(self.timeout).__await__()


def middleware(f):
"""Decorator for adding a middleware
Can be used like such
.. code-block:: python
@pool.middleware
async def my_middleware(query, args, limit,
timeout, return_status, *, handler, conn):
print('do something before')
result, stmt = await handler(query, args, limit,
timeout, return_status)
print('do something after')
return result, stmt
my_pool = await pool.create_pool(middlewares=[my_middleware])
"""
async def middleware_factory(connection, handler):
return functools.partial(f, connection=connection, handler=handler)
return middleware_factory


def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
setup=None,
init=None,
middlewares=None,
loop=None,
connection_class=connection.Connection,
**connect_kwargs):
Expand Down Expand Up @@ -866,6 +894,19 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.
:param middlewares:
A list of middleware functions to be middleware just
before a connection excecutes a statement.
Syntax of a middleware is as follows:
async def middleware_factory(connection, handler):
async def middleware(query, args, limit, timeout, return_status):
print('do something before')
result, stmt = await handler(query, args, limit,
timeout, return_status)
print('do something after')
return result, stmt
return middleware
:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -893,6 +934,7 @@ def create_pool(dsn=None, *,
dsn,
connection_class=connection_class,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_queries=max_queries, loop=loop, setup=setup,
middlewares=middlewares, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs)
2 changes: 1 addition & 1 deletion docs/installation.rst
Expand Up @@ -30,7 +30,7 @@ If you want to build **asyncpg** from a Git checkout you will need:
* CPython header files. These can usually be obtained by installing
the relevant Python development package: **python3-dev** on Debian/Ubuntu,
**python3-devel** on RHEL/Fedora.

* Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`)
Once the above requirements are satisfied, run the following command
in the root of the source checkout:

Expand Down
42 changes: 42 additions & 0 deletions tests/test_pool.py
Expand Up @@ -76,6 +76,48 @@ async def worker():
tasks = [worker() for _ in range(n)]
await asyncio.gather(*tasks)

async def test_pool_with_middleware(self):
called = False

async def my_middleware_factory(connection, handler):
async def middleware(query, args, limit, timeout, return_status):
nonlocal called
called = True
return await handler(query, args, limit,
timeout, return_status)
return middleware

pool = await self.create_pool(database='postgres',
min_size=1, max_size=1,
middlewares=[my_middleware_factory])

con = await pool.acquire(timeout=5)
await con.fetchval('SELECT 1')
assert called

pool.terminate()
del con

async def test_pool_with_middleware_decorator(self):
called = False

@pg_pool.middleware
async def my_middleware(query, args, limit, timeout, return_status,
*, connection, handler):
nonlocal called
called = True
return await handler(query, args, limit,
timeout, return_status)

pool = await self.create_pool(database='postgres', min_size=1,
max_size=1, middlewares=[my_middleware])
con = await pool.acquire(timeout=5)
await con.fetchval('SELECT 1')
assert called

pool.terminate()
del con

async def test_pool_03(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)
Expand Down

0 comments on commit 2ff5dad

Please sign in to comment.