Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add middleware support #482

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion asyncpg/_testbase/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ def create_pool(dsn=None, *,
setup=None,
init=None,
loop=None,
middlewares=None,
pool_class=pg_pool.Pool,
connection_class=pg_connection.Connection,
**connect_kwargs):
Expand All @@ -272,7 +273,7 @@ def create_pool(dsn=None, *,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
connection_class=connection_class,
connection_class=connection_class, middlewares=middlewares,
**connect_kwargs)


Expand Down
7 changes: 4 additions & 3 deletions asyncpg/connect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ async def _create_ssl_connection(protocol_factory, host, port, *,


async def _connect_addr(*, addr, loop, timeout, params, config,
connection_class):
middlewares, connection_class):
assert loop is not None

if timeout <= 0:
Expand Down Expand Up @@ -633,12 +633,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config,
tr.close()
raise

con = connection_class(pr, tr, loop, addr, config, params)
con = connection_class(pr, tr, loop, addr, config, params, middlewares)
pr.set_connection(con)
return con


async def _connect(*, loop, timeout, connection_class, **kwargs):
async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs):
if loop is None:
loop = asyncio.get_event_loop()

Expand All @@ -652,6 +652,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs):
con = await _connect_addr(
addr=addr, loop=loop, timeout=timeout,
params=params, config=config,
middlewares=middlewares,
connection_class=connection_class)
except (OSError, asyncio.TimeoutError, ConnectionError) as ex:
last_error = ex
Expand Down
22 changes: 17 additions & 5 deletions asyncpg/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class Connection(metaclass=ConnectionMeta):
"""

__slots__ = ('_protocol', '_transport', '_loop',
'_top_xact', '_aborted',
'_top_xact', '_aborted', '_middlewares',
'_pool_release_ctr', '_stmt_cache', '_stmts_to_close',
'_listeners', '_server_version', '_server_caps',
'_intro_query', '_reset_query', '_proxy',
Expand All @@ -52,7 +52,8 @@ class Connection(metaclass=ConnectionMeta):
def __init__(self, protocol, transport, loop,
addr: (str, int) or str,
config: connect_utils._ClientConfiguration,
params: connect_utils._ConnectionParameters):
params: connect_utils._ConnectionParameters,
_middlewares=None):
self._protocol = protocol
self._transport = transport
self._loop = loop
Expand Down Expand Up @@ -91,7 +92,7 @@ def __init__(self, protocol, transport, loop,

self._reset_query = None
self._proxy = None

self._middlewares = _middlewares
# Used to serialize operations that might involve anonymous
# statements. Specifically, we want to make the following
# operation atomic:
Expand Down Expand Up @@ -1399,8 +1400,13 @@ async def reload_schema_state(self):

async def _execute(self, query, args, limit, timeout, return_status=False):
with self._stmt_exclusive_section:
result, _ = await self.__execute(
query, args, limit, timeout, return_status=return_status)
wrapped = self.__execute
if self._middlewares:
for m in reversed(self._middlewares):
wrapped = await m(connection=self, handler=wrapped)

result, _ = await wrapped(query, args, limit,
timeout, return_status=return_status)
return result

async def __execute(self, query, args, limit, timeout,
Expand Down Expand Up @@ -1491,6 +1497,7 @@ async def connect(dsn=None, *,
max_cacheable_statement_size=1024 * 15,
command_timeout=None,
ssl=None,
middlewares=None,
connection_class=Connection,
server_settings=None):
r"""A coroutine to establish a connection to a PostgreSQL server.
Expand Down Expand Up @@ -1607,6 +1614,10 @@ async def connect(dsn=None, *,
PostgreSQL documentation for
a `list of supported options <server settings>`_.

:param middlewares:
An optional list of middleware functions. Refer to documentation
on create_pool.

:param Connection connection_class:
Class of the returned connection object. Must be a subclass of
:class:`~asyncpg.connection.Connection`.
Expand Down Expand Up @@ -1672,6 +1683,7 @@ async def connect(dsn=None, *,
ssl=ssl, database=database,
server_settings=server_settings,
command_timeout=command_timeout,
middlewares=middlewares,
statement_cache_size=statement_cache_size,
max_cached_statement_lifetime=max_cached_statement_lifetime,
max_cacheable_statement_size=max_cacheable_statement_size)
Expand Down
50 changes: 48 additions & 2 deletions asyncpg/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class Pool:
"""

__slots__ = (
'_queue', '_loop', '_minsize', '_maxsize',
'_queue', '_loop', '_minsize', '_maxsize', '_middlewares',
'_init', '_connect_args', '_connect_kwargs',
'_working_addr', '_working_config', '_working_params',
'_holders', '_initialized', '_initializing', '_closing',
Expand All @@ -320,6 +320,7 @@ def __init__(self, *connect_args,
max_inactive_connection_lifetime,
setup,
init,
middlewares,
loop,
connection_class,
**connect_kwargs):
Expand Down Expand Up @@ -377,6 +378,7 @@ def __init__(self, *connect_args,
self._closed = False
self._generation = 0
self._init = init
self._middlewares = middlewares
self._connect_args = connect_args
self._connect_kwargs = connect_kwargs

Expand Down Expand Up @@ -469,6 +471,7 @@ async def _get_new_connection(self):
*self._connect_args,
loop=self._loop,
connection_class=self._connection_class,
middlewares=self._middlewares,
**self._connect_kwargs)

self._working_addr = con._addr
Expand All @@ -483,6 +486,7 @@ async def _get_new_connection(self):
addr=self._working_addr,
timeout=self._working_params.connect_timeout,
config=self._working_config,
middlewares=self._middlewares,
params=self._working_params,
connection_class=self._connection_class)

Expand Down Expand Up @@ -784,13 +788,37 @@ def __await__(self):
return self.pool._acquire(self.timeout).__await__()


def middleware(f):
"""Decorator for adding a middleware

Can be used like such

.. code-block:: python

@pool.middleware
async def my_middleware(query, args, limit,
timeout, return_status, *, handler, conn):
print('do something before')
result, stmt = await handler(query, args, limit,
timeout, return_status)
print('do something after')
return result, stmt

my_pool = await pool.create_pool(middlewares=[my_middleware])
"""
async def middleware_factory(connection, handler):
return functools.partial(f, connection=connection, handler=handler)
return middleware_factory


def create_pool(dsn=None, *,
min_size=10,
max_size=10,
max_queries=50000,
max_inactive_connection_lifetime=300.0,
setup=None,
init=None,
middlewares=None,
loop=None,
connection_class=connection.Connection,
**connect_kwargs):
Expand Down Expand Up @@ -866,6 +894,23 @@ def create_pool(dsn=None, *,
or :meth:`Connection.set_type_codec() <\
asyncpg.connection.Connection.set_type_codec>`.

:param middlewares:
A list of middleware functions to be middleware just
before a connection excecutes a statement.
Syntax of a middleware is as follows:

.. code-block:: python

async def middleware_factory(connection, handler):
async def middleware(query, args, limit,
timeout, return_status):
print('do something before')
result, stmt = await handler(query, args, limit,
timeout, return_status)
print('do something after')
return result, stmt
return middleware

:param loop:
An asyncio event loop instance. If ``None``, the default
event loop will be used.
Expand Down Expand Up @@ -893,6 +938,7 @@ def create_pool(dsn=None, *,
dsn,
connection_class=connection_class,
min_size=min_size, max_size=max_size,
max_queries=max_queries, loop=loop, setup=setup, init=init,
max_queries=max_queries, loop=loop, setup=setup,
middlewares=middlewares, init=init,
max_inactive_connection_lifetime=max_inactive_connection_lifetime,
**connect_kwargs)
1 change: 1 addition & 0 deletions docs/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ If you want to build **asyncpg** from a Git checkout you will need:
* CPython header files. These can usually be obtained by installing
the relevant Python development package: **python3-dev** on Debian/Ubuntu,
**python3-devel** on RHEL/Fedora.
* Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`)

Once the above requirements are satisfied, run the following command
in the root of the source checkout:
Expand Down
42 changes: 42 additions & 0 deletions tests/test_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,48 @@ async def worker():
tasks = [worker() for _ in range(n)]
await asyncio.gather(*tasks)

async def test_pool_with_middleware(self):
called = False

async def my_middleware_factory(connection, handler):
nhumrich marked this conversation as resolved.
Show resolved Hide resolved
async def middleware(query, args, limit, timeout, return_status):
nonlocal called
called = True
return await handler(query, args, limit,
timeout, return_status)
return middleware

pool = await self.create_pool(database='postgres',
min_size=1, max_size=1,
middlewares=[my_middleware_factory])

con = await pool.acquire(timeout=5)
await con.fetchval('SELECT 1')
assert called

pool.terminate()
del con

async def test_pool_with_middleware_decorator(self):
called = False

@pg_pool.middleware
async def my_middleware(query, args, limit, timeout, return_status,
*, connection, handler):
nonlocal called
called = True
return await handler(query, args, limit,
timeout, return_status)

pool = await self.create_pool(database='postgres', min_size=1,
max_size=1, middlewares=[my_middleware])
con = await pool.acquire(timeout=5)
await con.fetchval('SELECT 1')
assert called

pool.terminate()
del con

async def test_pool_03(self):
pool = await self.create_pool(database='postgres',
min_size=1, max_size=1)
Expand Down