Skip to content

Commit

Permalink
Add AnyIOConnection class
Browse files Browse the repository at this point in the history
We add parametrized fixtures for async connection classes:
- 'asyncconnection_class' return either AsyncConnection or
  AnyIOConnection class;
- 'any_aconn' uses asyncconnection_class and will thus build a
  connection of requested type.
We also add an 'asyncio_backend' fixture for tests that are implemented
using asyncio.

Accordingly, async connections tests are now run with asyncio and trio
backends as we alias 'aconn' fixture to 'any_aconn' in
test_connection_async.py.

The global 'pytestmark = pytest.mark.asyncio' is no longer needed as
most test functions uses a fixture that depends on 'anyio_backend' which
detects async functions. Only test_connect_(bad,)args() functions still
need to be marked.
  • Loading branch information
dlax committed Nov 26, 2021
1 parent 73793ed commit af2a34d
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 37 deletions.
11 changes: 9 additions & 2 deletions docs/api/connections.rst
Original file line number Diff line number Diff line change
Expand Up @@ -265,8 +265,8 @@ The `!Connection` class
.. automethod:: fileno


The `!AsyncConnection` class
----------------------------
The `!AsyncConnection` classes
------------------------------

.. autoclass:: AsyncConnection()

Expand Down Expand Up @@ -330,6 +330,13 @@ The `!AsyncConnection` class
.. automethod:: set_deferrable


.. autoclass:: AnyIOConnection()

This is class is similar to `AsyncConnection` but uses anyio_ as an
asynchronous library instead of `asyncio`.

.. _anyio: https://anyio.readthedocs.io/

Connection support objects
--------------------------

Expand Down
7 changes: 7 additions & 0 deletions psycopg/psycopg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,11 @@
from .server_cursor import AsyncServerCursor, ServerCursor
from .connection_async import AsyncConnection

try:
from .connection_async import AnyIOConnection
except ImportError:
pass

from . import dbapi20
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
from .dbapi20 import Binary, Date, DateFromTicks, Time, TimeFromTicks
Expand Down Expand Up @@ -100,3 +105,5 @@
"ROWID",
"STRING",
]
if "AnyIOConnection" in globals():
__all__.append("AnyIOConnection")
28 changes: 28 additions & 0 deletions psycopg/psycopg/connection_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,31 @@ def _no_set_async(self, attribute: str) -> None:
f"'the {attribute!r} property is read-only on async connections:"
f" please use 'await .set_{attribute}()' instead."
)


try:
import anyio
except ImportError:
pass
else:

class AnyIOConnection(AsyncConnection[Row]):
"""
Asynchronous wrapper for a connection to the database using AnyIO
asynchronous library.
"""

__module__ = "psycopg"

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.lock = anyio.Lock() # type: ignore[assignment]

async def wait(self, gen: PQGen[RV]) -> RV:
return await waiting.wait_anyio(gen, self.pgconn.socket)

@classmethod
async def _wait_conn(
cls, gen: PQGenConn[RV], timeout: Optional[int]
) -> RV:
return await waiting.wait_conn_anyio(gen, timeout)
8 changes: 8 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,11 @@ def event_loop(request):
)
def anyio_backend(request):
return request.param


@pytest.fixture
def asyncio_backend(anyio_backend):
"""Skip if the async backend is not 'asyncio'."""
backend, _ = anyio_backend
if backend != "asyncio":
pytest.skip("only applicable for 'asyncio' backend")
26 changes: 26 additions & 0 deletions tests/fix_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,32 @@ async def aconn(dsn, request):
await conn.close()


@pytest.fixture
def asyncconnection_class(anyio_backend):
from psycopg import AsyncConnection, AnyIOConnection

if isinstance(anyio_backend, tuple):
backend = anyio_backend[0]
else:
assert isinstance(anyio_backend, str)
backend = anyio_backend
return {"asyncio": AsyncConnection, "trio": AnyIOConnection}[backend]


@pytest.fixture
async def any_aconn(dsn, request, asyncconnection_class):
"""Return an `AsyncConnection` or an `AnyIOConnection` connected to the
``--test-dsn`` database.
"""
conn = await asyncconnection_class.connect(dsn)
msg = check_connection_version(conn.info.server_version, request.function)
if msg:
await conn.close()
pytest.skip(msg)
yield conn
await conn.close()


@pytest.fixture(scope="session")
def svcconn(dsn):
"""
Expand Down
83 changes: 48 additions & 35 deletions tests/test_connection_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import sys
import weakref
from typing import Type, Any

import psycopg
from psycopg import AsyncConnection, Notify
Expand All @@ -17,34 +18,37 @@
from .test_connection import tx_params, tx_values_map, conninfo_params_timeout
from .test_adapt import make_bin_dumper, make_dumper

pytestmark = pytest.mark.asyncio

@pytest.fixture
def aconn(any_aconn):
return any_aconn

async def test_connect(dsn):
conn = await AsyncConnection.connect(dsn)

async def test_connect(dsn, asyncconnection_class):
conn = await asyncconnection_class.connect(dsn)
assert not conn.closed
assert conn.pgconn.status == conn.ConnStatus.OK
await conn.close()


async def test_connect_bad():
async def test_connect_bad(asyncconnection_class):
with pytest.raises(psycopg.OperationalError):
await AsyncConnection.connect("dbname=nosuchdb")
await asyncconnection_class.connect("dbname=nosuchdb")


async def test_connect_str_subclass(dsn):
async def test_connect_str_subclass(dsn, asyncconnection_class):
class MyString(str):
pass

conn = await AsyncConnection.connect(MyString(dsn))
conn = await asyncconnection_class.connect(MyString(dsn))
assert not conn.closed
assert conn.pgconn.status == conn.ConnStatus.OK
await conn.close()


@pytest.mark.slow
@pytest.mark.timing
async def test_connect_timeout():
async def test_connect_timeout(asyncio_backend):
s = socket.socket(socket.AF_INET)
s.bind(("", 0))
port = s.getsockname()[1]
Expand Down Expand Up @@ -99,30 +103,30 @@ async def test_broken(aconn):
assert aconn.broken


async def test_connection_warn_close(dsn, recwarn):
conn = await AsyncConnection.connect(dsn)
async def test_connection_warn_close(dsn, asyncconnection_class, recwarn):
conn = await asyncconnection_class.connect(dsn)
await conn.close()
del conn
assert not recwarn, [str(w.message) for w in recwarn.list]

conn = await AsyncConnection.connect(dsn)
conn = await asyncconnection_class.connect(dsn)
del conn
assert "IDLE" in str(recwarn.pop(ResourceWarning).message)

conn = await AsyncConnection.connect(dsn)
conn = await asyncconnection_class.connect(dsn)
await conn.execute("select 1")
del conn
assert "INTRANS" in str(recwarn.pop(ResourceWarning).message)

conn = await AsyncConnection.connect(dsn)
conn = await asyncconnection_class.connect(dsn)
try:
await conn.execute("select wat")
except Exception:
pass
del conn
assert "INERROR" in str(recwarn.pop(ResourceWarning).message)

async with await AsyncConnection.connect(dsn) as conn:
async with await asyncconnection_class.connect(dsn) as conn:
pass
del conn
assert not recwarn, [str(w.message) for w in recwarn.list]
Expand All @@ -137,7 +141,7 @@ async def test_context_commit(aconn, dsn):
assert aconn.closed
assert not aconn.broken

async with await psycopg.AsyncConnection.connect(dsn) as aconn:
async with await aconn.__class__.connect(dsn) as aconn:
async with aconn.cursor() as cur:
await cur.execute("select * from textctx")
assert await cur.fetchall() == []
Expand All @@ -157,7 +161,7 @@ async def test_context_rollback(aconn, dsn):
assert aconn.closed
assert not aconn.broken

async with await psycopg.AsyncConnection.connect(dsn) as aconn:
async with await aconn.__class__.connect(dsn) as aconn:
async with aconn.cursor() as cur:
with pytest.raises(UndefinedTable):
await cur.execute("select * from textctx")
Expand All @@ -169,9 +173,11 @@ async def test_context_close(aconn):
await aconn.close()


async def test_context_rollback_no_clobber(conn, dsn, caplog):
async def test_context_rollback_no_clobber(
conn, dsn, asyncconnection_class, caplog
):
with pytest.raises(ZeroDivisionError):
async with await psycopg.AsyncConnection.connect(dsn) as conn2:
async with await asyncconnection_class.connect(dsn) as conn2:
await conn2.execute("select 1")
conn.execute(
"select pg_terminate_backend(%s::int)",
Expand All @@ -186,8 +192,8 @@ async def test_context_rollback_no_clobber(conn, dsn, caplog):


@pytest.mark.slow
async def test_weakref(dsn):
conn = await psycopg.AsyncConnection.connect(dsn)
async def test_weakref(dsn, asyncconnection_class):
conn = await asyncconnection_class.connect(dsn)
w = weakref.ref(conn)
await conn.close()
del conn
Expand Down Expand Up @@ -284,8 +290,8 @@ async def test_autocommit(aconn):
assert aconn.autocommit is True


async def test_autocommit_connect(dsn):
aconn = await psycopg.AsyncConnection.connect(dsn, autocommit=True)
async def test_autocommit_connect(dsn, asyncconnection_class):
aconn = await asyncconnection_class.connect(dsn, autocommit=True)
assert aconn.autocommit
await aconn.close()

Expand Down Expand Up @@ -333,6 +339,7 @@ async def test_autocommit_unknown(aconn):
(("host=foo",), {"user": None}, "host=foo"),
],
)
@pytest.mark.asyncio
async def test_connect_args(monkeypatch, pgconn, args, kwargs, want):
the_conninfo: str

Expand All @@ -356,14 +363,15 @@ def fake_connect(conninfo):
((), {"nosuchparam": 42}, psycopg.ProgrammingError),
],
)
@pytest.mark.asyncio
async def test_connect_badargs(monkeypatch, pgconn, args, kwargs, exctype):
def fake_connect(conninfo):
return pgconn
yield

monkeypatch.setattr(psycopg.connection, "connect", fake_connect)
with pytest.raises(exctype):
await psycopg.AsyncConnection.connect(*args, **kwargs)
await AsyncConnection.connect(*args, **kwargs)


async def test_broken_connection(aconn):
Expand Down Expand Up @@ -477,12 +485,14 @@ async def test_execute_binary(aconn):
assert cur.pgresult.fformat(0) == 1


async def test_row_factory(dsn):
defaultconn = await AsyncConnection.connect(dsn)
async def test_row_factory(
dsn: str, asyncconnection_class: Type[AsyncConnection[Any]]
) -> None:
defaultconn = await asyncconnection_class.connect(dsn)
assert defaultconn.row_factory is tuple_row # type: ignore[comparison-overlap]
await defaultconn.close()

conn = await AsyncConnection.connect(dsn, row_factory=my_row_factory)
conn = await asyncconnection_class.connect(dsn, row_factory=my_row_factory)
assert conn.row_factory is my_row_factory # type: ignore[comparison-overlap]

cur = await conn.execute("select 'a' as ve")
Expand All @@ -496,11 +506,10 @@ async def test_row_factory(dsn):
await cur2.execute("select 1, 1, 2")
assert await cur2.fetchall() == [(1, 1, 2)]

# TODO: maybe fix something to get rid of 'type: ignore' below.
conn.row_factory = tuple_row # type: ignore[assignment]
conn.row_factory = tuple_row
cur3 = await conn.execute("select 'vale'")
r = await cur3.fetchone()
assert r and r == ("vale",) # type: ignore[comparison-overlap]
assert r and r == ("vale",)
await conn.close()


Expand Down Expand Up @@ -649,19 +658,21 @@ async def test_set_transaction_param_strange(aconn):


@pytest.mark.parametrize("dsn, kwargs, exp", conninfo_params_timeout)
async def test_get_connection_params(dsn, kwargs, exp):
params = await AsyncConnection._get_connection_params(dsn, **kwargs)
async def test_get_connection_params(dsn, asyncconnection_class, kwargs, exp):
params = await asyncconnection_class._get_connection_params(dsn, **kwargs)
conninfo = make_conninfo(**params)
assert conninfo_to_dict(conninfo) == exp[0]
assert params["connect_timeout"] == exp[1]


async def test_connect_context_adapters(dsn):
async def test_connect_context_adapters(
dsn: str, asyncconnection_class: Type[AsyncConnection[Any]]
) -> None:
ctx = psycopg.adapt.AdaptersMap(psycopg.adapters)
ctx.register_dumper(str, make_bin_dumper("b"))
ctx.register_dumper(str, make_dumper("t"))

conn = await psycopg.AsyncConnection.connect(dsn, context=ctx)
conn = await asyncconnection_class.connect(dsn, context=ctx)

cur = await conn.execute("select %s", ["hello"])
assert (await cur.fetchone())[0] == "hellot" # type: ignore[index]
Expand All @@ -670,11 +681,13 @@ async def test_connect_context_adapters(dsn):
await conn.close()


async def test_connect_context_copy(dsn, aconn):
async def test_connect_context_copy(
dsn: str, aconn: AsyncConnection[Any]
) -> None:
aconn.adapters.register_dumper(str, make_bin_dumper("b"))
aconn.adapters.register_dumper(str, make_dumper("t"))

aconn2 = await psycopg.AsyncConnection.connect(dsn, context=aconn)
aconn2 = await aconn.__class__.connect(dsn, context=aconn)

cur = await aconn2.execute("select %s", ["hello"])
assert (await cur.fetchone())[0] == "hellot" # type: ignore[index]
Expand Down

0 comments on commit af2a34d

Please sign in to comment.