Skip to content

Commit

Permalink
Add AnyIOConnection class
Browse files Browse the repository at this point in the history
The AnyIOConnection class simply uses an anyio.Lock instead of an
asyncio.Lock and relies on anyio waiting functions introduced earlier.

There is no runtime dependency on anyio; this is left to the
responsibility of the user. Thus, the AnyIOConnection class name is
exported at package level only if available.

When checking for windows compatibility in connect(), we now guard event
loop lookup depending on the async library in use, since this is
asyncio-specific.

In tests, we add the following fixtures:
* 'asyncconnection_class' (parametrized) which returns either
  AsyncConnection or AnyIOConnection class;
* 'any_aconn' uses asyncconnection_class (thus also parametrized) and
  will thus build a connection of requested type.
* 'asyncio_backend' useful to mark some tests only applicable for the
  asyncio backend (typically because they use asyncio in test code).

Accordingly, async connections tests are now run with asyncio and trio
backends as we alias 'aconn' fixture to 'any_aconn' in
test_connection_async.py.

The global 'pytestmark = pytest.mark.asyncio' is no longer needed as
most test functions uses a fixture that depends on 'anyio_backend' which
detects async functions. Only test_connect_(bad,)args() functions still
need to be marked.
  • Loading branch information
dlax committed Dec 21, 2021
1 parent 0a679d0 commit 46322e1
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 40 deletions.
11 changes: 9 additions & 2 deletions docs/api/connections.rst
Expand Up @@ -381,8 +381,8 @@ The `!Connection` class
.. _pg_prepared_xacts: https://www.postgresql.org/docs/current/static/view-pg-prepared-xacts.html


The `!AsyncConnection` class
----------------------------
The `!AsyncConnection` classes
------------------------------

.. autoclass:: AsyncConnection()

Expand Down Expand Up @@ -451,6 +451,13 @@ The `!AsyncConnection` class
.. automethod:: tpc_recover


.. autoclass:: AnyIOConnection()

This is class is similar to `AsyncConnection` but uses anyio_ as an
asynchronous library instead of `asyncio`.

.. _anyio: https://anyio.readthedocs.io/

Connection support objects
--------------------------

Expand Down
9 changes: 9 additions & 0 deletions psycopg/psycopg/__init__.py
Expand Up @@ -24,6 +24,13 @@
from .server_cursor import AsyncServerCursor, ServerCursor
from .connection_async import AsyncConnection

try:
from .connection_async import AnyIOConnection # noqa: F401
except ImportError:
_anyio = False
else:
_anyio = True

from . import dbapi20
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
Expand Down Expand Up @@ -102,3 +109,5 @@
"ROWID",
"STRING",
]
if _anyio:
__all__.append("AnyIOConnection")
39 changes: 38 additions & 1 deletion psycopg/psycopg/connection_async.py
Expand Up @@ -57,6 +57,10 @@ def __init__(
self.cursor_factory = AsyncCursor
self.server_cursor_factory = AsyncServerCursor

@classmethod
def _async_library(cls) -> str:
return "asyncio"

@overload
@classmethod
async def connect(
Expand Down Expand Up @@ -93,7 +97,7 @@ async def connect(
**kwargs: Any,
) -> "AsyncConnection[Any]":

if sys.platform == "win32":
if sys.platform == "win32" and cls._async_library() == "asyncio":
loop = get_running_loop()
if isinstance(loop, asyncio.ProactorEventLoop):
raise e.InterfaceError(
Expand Down Expand Up @@ -371,3 +375,36 @@ async def tpc_recover(self) -> List[Xid]:
await self.rollback()

return res


try:
import anyio
except ImportError:
pass
else:
import sniffio

class AnyIOConnection(AsyncConnection[Row]):
"""
Asynchronous wrapper for a connection to the database using AnyIO
asynchronous library.
"""

__module__ = "psycopg"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.lock = anyio.Lock() # type: ignore[assignment]

@classmethod
def _async_library(cls) -> str:
return sniffio.current_async_library()

async def wait(self, gen: PQGen[RV]) -> RV:
return await waiting.wait_anyio(gen, self.pgconn.socket)

@classmethod
async def _wait_conn(
cls, gen: PQGenConn[RV], timeout: Optional[int]
) -> RV:
return await waiting.wait_conn_anyio(gen, timeout)
1 change: 1 addition & 0 deletions psycopg/setup.py
Expand Up @@ -64,6 +64,7 @@
"sphinx-autobuild >= 2021.3.14",
"sphinx-autodoc-typehints ~= 1.12.0",
# to document optional modules
"anyio",
"dnspython ~= 2.1.0",
"shapely ~= 1.7.0",
],
Expand Down
8 changes: 8 additions & 0 deletions tests/conftest.py
Expand Up @@ -116,3 +116,11 @@ def event_loop(request):
)
def anyio_backend(request):
return request.param


@pytest.fixture
def asyncio_backend(anyio_backend):
"""Skip if the async backend is not 'asyncio'."""
backend, _ = anyio_backend
if backend != "asyncio":
pytest.skip("only applicable for 'asyncio' backend")
26 changes: 26 additions & 0 deletions tests/fix_db.py
Expand Up @@ -151,6 +151,32 @@ async def aconn(dsn, request, tracefile):
await conn.close()


@pytest.fixture
def asyncconnection_class(anyio_backend):
from psycopg import AsyncConnection, AnyIOConnection

if isinstance(anyio_backend, tuple):
backend = anyio_backend[0]
else:
assert isinstance(anyio_backend, str)
backend = anyio_backend
return {"asyncio": AsyncConnection, "trio": AnyIOConnection}[backend]


@pytest.fixture
async def any_aconn(dsn, request, asyncconnection_class):
"""Return an `AsyncConnection` or an `AnyIOConnection` connected to the
``--test-dsn`` database.
"""
conn = await asyncconnection_class.connect(dsn)
msg = check_connection_version(conn.info.server_version, request.function)
if msg:
await conn.close()
pytest.skip(msg)
yield conn
await conn.close()


@pytest.fixture(scope="session")
def svcconn(dsn):
"""
Expand Down

0 comments on commit 46322e1

Please sign in to comment.