Skip to content

Commit

Permalink
feat: add support for AnyIO async library
Browse files Browse the repository at this point in the history
This adds a new AnyIOConnection class to be used instead of
AsyncConnection in combination with the 'anyio' async library. The class
essentially uses an anyio.Lock instead of an asyncio.Lock (and same for
getaddrinfo()) and relies on AnyIO-specific waiting functions, also
introduced here. The same is done for crdb connection class, though with
more repetition due to typing issues mentioned in inline comments.

All anyio-related code lives in the _anyio sub-package. An 'anyio'
setuptools extra is defined to pull required dependencies.

AnyIOConnection is exposed on the psycopg namespace, a runtime
check is performed when instantiating possibly producing an informative
message about missing dependencies.

In tests, overall, the previous anyio_backend fixture is now
parametrized with both asyncio and trio backends and 'aconn_cls' returns
either AsyncConnection or AnyIOConnection depending on backend name.
Test dependencies now include 'anyio[trio]'.

In "waiting" tests, we define 'wait_{conn_,}_async' fixtures that will
pick either asyncio or anyio waiting functions depending on the value of
'anyio_backend' fixture.

Concurrency tests (e.g. test_concurrency_async.py or respective crdb
ones) are not run with the trio backend as then explicitly use asyncio
API. Porting them does not seem strictly needed, at least now. So they
get marked with asyncio_backend.

Finally, we ignore an invalid error (raised by a deprecation warning
during test) about usage of the 'loop' parameter in asyncio API that is
due to Python bug as mentioned in comment.
  • Loading branch information
dlax committed Feb 4, 2023
1 parent 88f8570 commit ba03da6
Show file tree
Hide file tree
Showing 23 changed files with 431 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Expand Up @@ -77,7 +77,7 @@ jobs:
- name: Include dnspython to the packages to install
if: ${{ matrix.ext == 'dns' }}
run: |
echo "DEPS=$DEPS dnspython" >> $GITHUB_ENV
echo "DEPS=$DEPS anyio[trio] dnspython" >> $GITHUB_ENV
echo "MARKERS=$MARKERS dns" >> $GITHUB_ENV
- name: Include shapely to the packages to install
Expand Down
15 changes: 13 additions & 2 deletions docs/api/connections.rst
Expand Up @@ -401,8 +401,8 @@ The `!Connection` class
.. _pg_prepared_xacts: https://www.postgresql.org/docs/current/static/view-pg-prepared-xacts.html


The `!AsyncConnection` class
----------------------------
The `!AsyncConnection` classes
------------------------------

.. autoclass:: AsyncConnection()

Expand Down Expand Up @@ -487,3 +487,14 @@ The `!AsyncConnection` class
.. automethod:: tpc_commit
.. automethod:: tpc_rollback
.. automethod:: tpc_recover


.. autoclass:: AnyIOConnection()

This is class is similar to `AsyncConnection` but uses anyio_ as an
asynchronous library instead of `asyncio`.

To use this class, run ``pip install "psycopg[anyio]"`` to install
required dependencies.

.. _anyio: https://anyio.readthedocs.io/
2 changes: 2 additions & 0 deletions psycopg/psycopg/__init__.py
Expand Up @@ -25,6 +25,7 @@
from .server_cursor import AsyncServerCursor, ServerCursor
from .client_cursor import AsyncClientCursor, ClientCursor
from .connection_async import AsyncConnection
from ._anyio.connection import AnyIOConnection

from . import dbapi20
from .dbapi20 import BINARY, DATETIME, NUMBER, ROWID, STRING
Expand Down Expand Up @@ -59,6 +60,7 @@
# this is the canonical place to obtain them and should be used by MyPy too,
# so that function signatures are consistent with the documentation.
__all__ = [
"AnyIOConnection",
"AsyncClientCursor",
"AsyncConnection",
"AsyncCopy",
Expand Down
Empty file.
80 changes: 80 additions & 0 deletions psycopg/psycopg/_anyio/connection.py
@@ -0,0 +1,80 @@
"""
psycopg async connection objects using AnyIO
"""

# Copyright (C) 2022 The Psycopg Team


from functools import lru_cache
from typing import Any, Optional, TYPE_CHECKING

from .. import errors as e
from ..abc import PQGen, PQGenConn, RV
from ..connection_async import AsyncConnection
from ..rows import Row

if TYPE_CHECKING:
import anyio
import sniffio
from . import waiting
else:
anyio = sniffio = waiting = None


@lru_cache()
def _import_anyio() -> None:
global anyio, sniffio, waiting
try:
import anyio
import sniffio
from . import waiting
except ImportError as e:
raise ImportError(
"anyio is not installed; run `pip install psycopg[anyio]`"
) from e


class AnyIOConnection(AsyncConnection[Row]):
"""
Asynchronous wrapper for a connection to the database using AnyIO
asynchronous library.
"""

__module__ = "psycopg"

def __init__(self, *args: Any, **kwargs: Any) -> None:
_import_anyio()
self._lockcls = anyio.Lock # type: ignore[assignment]
super().__init__(*args, **kwargs)

@staticmethod
def _async_library() -> str:
_import_anyio()
return sniffio.current_async_library()

@staticmethod
def _getaddrinfo() -> Any:
_import_anyio()
return anyio.getaddrinfo

async def wait(self, gen: PQGen[RV]) -> RV:
try:
return await waiting.wait(gen, self.pgconn.socket)
except KeyboardInterrupt:
# TODO: this doesn't seem to work as it does for sync connections
# see tests/test_concurrency_async.py::test_ctrl_c
# In the test, the code doesn't reach this branch.

# On Ctrl-C, try to cancel the query in the server, otherwise
# otherwise the connection will be stuck in ACTIVE state
c = self.pgconn.get_cancel()
c.cancel()
try:
await waiting.wait(gen, self.pgconn.socket)
except e.QueryCanceled:
pass # as expected
raise

@classmethod
async def _wait_conn(cls, gen: PQGenConn[RV], timeout: Optional[int]) -> RV:
return await waiting.wait_conn(gen, timeout)
140 changes: 140 additions & 0 deletions psycopg/psycopg/_anyio/waiting.py
@@ -0,0 +1,140 @@
"""
Async waiting functions using AnyIO.
"""

# Copyright (C) 2022 The Psycopg Team


import socket
from typing import Optional

import anyio

from .. import errors as e
from ..abc import PQGen, PQGenConn, RV
from ..waiting import Ready, Wait


def _fromfd(fileno: int) -> socket.socket:
# AnyIO's wait_socket_readable() and wait_socket_writable() functions work
# with socket object (despite the underlying async libraries -- asyncio and
# trio -- accept integer 'fileno' values):
# https://github.com/agronholm/anyio/issues/386
try:
return socket.fromfd(fileno, socket.AF_INET, socket.SOCK_STREAM)
except OSError as exc:
raise e.OperationalError(
f"failed to build a socket from connection file descriptor: {exc}"
)


async def wait(gen: PQGen[RV], fileno: int) -> RV:
"""
Coroutine waiting for a generator to complete.
:param gen: a generator performing database operations and yielding
`Ready` values when it would block.
:param fileno: the file descriptor to wait on.
:return: whatever *gen* returns on completion.
Behave like in `waiting.wait()`, but exposing an `anyio` interface.
"""
s: Wait
ready: Ready
sock = _fromfd(fileno)

async def readable(ev: anyio.Event) -> None:
await anyio.wait_socket_readable(sock)
nonlocal ready
ready |= Ready.R # type: ignore[assignment]
ev.set()

async def writable(ev: anyio.Event) -> None:
await anyio.wait_socket_writable(sock)
nonlocal ready
ready |= Ready.W # type: ignore[assignment]
ev.set()

try:
s = next(gen)
while True:
reader = s & Wait.R
writer = s & Wait.W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev = anyio.Event()
ready = 0 # type: ignore[assignment]
async with anyio.create_task_group() as tg:
if reader:
tg.start_soon(readable, ev)
if writer:
tg.start_soon(writable, ev)
await ev.wait()
tg.cancel_scope.cancel() # Move on upon first task done.

s = gen.send(ready)

except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv

finally:
sock.close()


async def wait_conn(gen: PQGenConn[RV], timeout: Optional[float] = None) -> RV:
"""
Coroutine waiting for a connection generator to complete.
:param gen: a generator performing database operations and yielding
(fd, `Ready`) pairs when it would block.
:param timeout: timeout (in seconds) to check for other interrupt, e.g.
to allow Ctrl-C. If zero or None, wait indefinitely.
:return: whatever *gen* returns on completion.
Behave like in `waiting.wait()`, but take the fileno to wait from the
generator itself, which might change during processing.
"""
s: Wait
ready: Ready

async def readable(sock: socket.socket, ev: anyio.Event) -> None:
await anyio.wait_socket_readable(sock)
nonlocal ready
ready = Ready.R
ev.set()

async def writable(sock: socket.socket, ev: anyio.Event) -> None:
await anyio.wait_socket_writable(sock)
nonlocal ready
ready = Ready.W
ev.set()

timeout = timeout or None
try:
fileno, s = next(gen)

while True:
reader = s & Wait.R
writer = s & Wait.W
if not reader and not writer:
raise e.InternalError(f"bad poll status: {s}")
ev = anyio.Event()
ready = 0 # type: ignore[assignment]
with _fromfd(fileno) as sock:
async with anyio.create_task_group() as tg:
if reader:
tg.start_soon(readable, sock, ev)
if writer:
tg.start_soon(writable, sock, ev)
with anyio.fail_after(timeout):
await ev.wait()

fileno, s = gen.send(ready)

except TimeoutError:
raise e.OperationalError("timeout expired")

except StopIteration as ex:
rv: RV = ex.args[0] if ex.args else None
return rv
17 changes: 13 additions & 4 deletions psycopg/psycopg/connection_async.py
Expand Up @@ -53,6 +53,7 @@ class AsyncConnection(BaseConnection[Row]):
row_factory: AsyncRowFactory[Row]
_pipeline: Optional[AsyncPipeline]
_Self = TypeVar("_Self", bound="AsyncConnection[Any]")
_lockcls = asyncio.Lock

def __init__(
self,
Expand All @@ -61,10 +62,14 @@ def __init__(
):
super().__init__(pgconn)
self.row_factory = row_factory
self.lock = asyncio.Lock()
self.lock = self._lockcls()
self.cursor_factory = AsyncCursor
self.server_cursor_factory = AsyncServerCursor

@staticmethod
def _async_library() -> str:
return "asyncio"

@overload
@classmethod
async def connect(
Expand Down Expand Up @@ -107,7 +112,7 @@ async def connect(
cursor_factory: Optional[Type[AsyncCursor[Row]]] = None,
**kwargs: Any,
) -> "AsyncConnection[Any]":
if sys.platform == "win32":
if sys.platform == "win32" and cls._async_library() == "asyncio":
loop = asyncio.get_running_loop()
if isinstance(loop, asyncio.ProactorEventLoop):
raise e.InterfaceError(
Expand Down Expand Up @@ -167,6 +172,11 @@ async def __aexit__(
if not getattr(self, "_pool", None):
await self.close()

@staticmethod
def _getaddrinfo() -> Any:
loop = asyncio.get_running_loop()
return loop.getaddrinfo

@classmethod
async def _get_connection_params(
cls, conninfo: str, **kwargs: Any
Expand All @@ -189,8 +199,7 @@ async def _get_connection_params(
params["connect_timeout"] = None

# Resolve host addresses in non-blocking way
loop = asyncio.get_running_loop()
params = await resolve_hostaddr_async(params, getaddrinfo=loop.getaddrinfo)
params = await resolve_hostaddr_async(params, getaddrinfo=cls._getaddrinfo())

return params

Expand Down
2 changes: 2 additions & 0 deletions psycopg/psycopg/crdb/__init__.py
Expand Up @@ -6,6 +6,7 @@

from . import _types
from .connection import CrdbConnection, AsyncCrdbConnection, CrdbConnectionInfo
from ._anyio import AnyIOCrdbConnection

adapters = _types.adapters # exposed by the package
connect = CrdbConnection.connect
Expand All @@ -14,6 +15,7 @@
_types.register_crdb_adapters(adapters)

__all__ = [
"AnyIOCrdbConnection",
"AsyncCrdbConnection",
"CrdbConnection",
"CrdbConnectionInfo",
Expand Down

0 comments on commit ba03da6

Please sign in to comment.