Skip to content

Commit

Permalink
Add AnyIOCopy class
Browse files Browse the repository at this point in the history
This class overrides some methods of AsyncCopy. It uses a task group to
hold the worker and communicate is done through memory object streams.
We also need an event to wait for the worker to complete.

We add a _copycls attribute on AsyncConnection base class and
AnyIOConnection so that the copy() context manager on cursor selects to
proper Copy class (either AsyncCopy or AnyIOCopy).

In copy tests, like in async connections ones, we alias the "aconn"
fixture to "any_aconn", so as to run tests with both asyncio and trio
backends.

In test_context_active_rollback_no_clobber[trio], we need to properly
close the rows() async generator in order to avoid a (trio)
ResourceWarning.

Tests test_copy_from_leaks and test_copy_to_leaks are only valid for
asyncio backend because tenacity's AsyncRetrying is implemented using
asyncio.
  • Loading branch information
dlax committed Dec 21, 2021
1 parent 46322e1 commit 597c1cf
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 8 deletions.
7 changes: 7 additions & 0 deletions psycopg/psycopg/connection_async.py
Expand Up @@ -23,6 +23,7 @@
from .conninfo import make_conninfo, conninfo_to_dict
from ._encodings import pgconn_encoding
from .connection import BaseConnection, CursorRow, Notify
from .copy import AsyncCopy
from .generators import notifies
from .transaction import AsyncTransaction
from .cursor_async import AsyncCursor
Expand All @@ -46,6 +47,8 @@ class AsyncConnection(BaseConnection[Row]):
server_cursor_factory: Type[AsyncServerCursor[Row]]
row_factory: AsyncRowFactory[Row]

_copycls = AsyncCopy

def __init__(
self,
pgconn: "PGconn",
Expand Down Expand Up @@ -384,6 +387,8 @@ async def tpc_recover(self) -> List[Xid]:
else:
import sniffio

from .copy import AnyIOCopy

class AnyIOConnection(AsyncConnection[Row]):
"""
Asynchronous wrapper for a connection to the database using AnyIO
Expand All @@ -392,6 +397,8 @@ class AnyIOConnection(AsyncConnection[Row]):

__module__ = "psycopg"

_copycls = AnyIOCopy

def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.lock = anyio.Lock() # type: ignore[assignment]
Expand Down
80 changes: 79 additions & 1 deletion psycopg/psycopg/copy.py
Expand Up @@ -13,6 +13,7 @@
from types import TracebackType
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple
from typing import TypeVar

from . import pq
from . import errors as e
Expand Down Expand Up @@ -291,6 +292,9 @@ def _write_end(self) -> None:
self._worker = None # break the loop


_C = TypeVar("_C", bound="AsyncCopy")


class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):
"""Manage an asynchronous :sql:`COPY` operation."""

Expand All @@ -303,7 +307,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"):
)
self._worker: Optional[asyncio.Future[None]] = None

async def __aenter__(self) -> "AsyncCopy":
async def __aenter__(self: _C) -> _C:
self._enter()
return self

Expand Down Expand Up @@ -385,6 +389,80 @@ async def _write_end(self) -> None:
self._worker = None # break reference loops if any


try:
import anyio
import anyio.abc
except ImportError:
pass
else:

class AnyIOCopy(AsyncCopy):
"""Manage an asynchronous :sql:`COPY` operation using AnyIO
asynchronous library.
"""

__module__ = "psycopg"

def __init__(self, cursor: "AsyncCursor[Any]"):
super(AsyncCopy, self).__init__(cursor)
self._task_group = anyio.create_task_group()
(
self._send_stream,
self._receive_stream,
) = anyio.create_memory_object_stream(
max_buffer_size=self.QUEUE_SIZE, item_type=bytes
)
self._worker_done: Optional[anyio.Event] = None

async def __aenter__(self) -> "AnyIOCopy":
await self._task_group.__aenter__()
self = await super().__aenter__()
self._worker_done = None
return self

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
) -> None:
await super().__aexit__(exc_type, exc_val, exc_tb)
await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

async def worker(
self,
*,
task_status: anyio.abc.TaskStatus = anyio.TASK_STATUS_IGNORED,
) -> None:
"""Push data to the server when available from the receiving stream."""
done = anyio.Event()
async with self._receive_stream:
task_status.started(done)
async for data in self._receive_stream:
await self.connection.wait(copy_to(self._pgconn, data))
done.set()

async def _write(self, data: bytes) -> None:
if not data:
return

if not self._worker_done:
ev = await self._task_group.start(self.worker)
assert isinstance(ev, anyio.Event)
self._worker_done = ev

await self._send_stream.send(data)

async def _write_end(self) -> None:
data = self.formatter.end()
await self._write(data)

if self._worker_done:
self._send_stream.close()
await self._worker_done.wait()
self._worker_done = None


class Formatter(ABC):
"""
A class which understand a copy format (text, binary).
Expand Down
2 changes: 1 addition & 1 deletion psycopg/psycopg/cursor_async.py
Expand Up @@ -159,5 +159,5 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]:
async with self._conn.lock:
await self._conn.wait(self._start_copy_gen(statement))

async with AsyncCopy(self) as copy:
async with self._conn._copycls(self) as copy:
yield copy
11 changes: 6 additions & 5 deletions tests/test_connection_async.py
Expand Up @@ -205,9 +205,6 @@ async def test_context_inerror_rollback_no_clobber(
async def test_context_active_rollback_no_clobber(
dsn, asyncconnection_class, caplog
):
if asyncconnection_class != AsyncConnection:
pytest.xfail("anyio connection not implemented")

caplog.set_level(logging.WARNING, logger="psycopg")

with pytest.raises(ZeroDivisionError):
Expand All @@ -216,8 +213,12 @@ async def test_context_active_rollback_no_clobber(
async with cur.copy(
"copy (select generate_series(1, 10)) to stdout"
) as copy:
async for row in copy.rows():
1 / 0
rows = copy.rows()
try:
async for row in rows:
1 / 0
finally:
await rows.aclose()

assert len(caplog.records) == 1
rec = caplog.records[0]
Expand Down
11 changes: 10 additions & 1 deletion tests/test_copy_async.py
Expand Up @@ -10,6 +10,7 @@
from psycopg import pq
from psycopg import sql
from psycopg import errors as e
from psycopg.copy import AsyncCopy
from psycopg.pq import Format
from psycopg.types import TypeInfo
from psycopg.adapt import PyFormat
Expand All @@ -21,7 +22,10 @@
from .test_copy import eur, sample_values, sample_records, sample_tabledef
from .test_copy import py_to_raw

pytestmark = pytest.mark.asyncio

@pytest.fixture
def aconn(any_aconn):
return any_aconn


@pytest.mark.parametrize("format", Format)
Expand Down Expand Up @@ -518,6 +522,9 @@ async def test_str(aconn):
[(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
async def test_worker_life(aconn, format, buffer):
if aconn._copycls != AsyncCopy:
pytest.skip("only applicable for AsyncCopy")

cur = aconn.cursor()
await ensure_table(cur, sample_tabledef)
async with cur.copy(
Expand All @@ -534,6 +541,7 @@ async def test_worker_life(aconn, format, buffer):


@pytest.mark.slow
@pytest.mark.asyncio
@pytest.mark.parametrize(
"fmt, set_types",
[(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
Expand Down Expand Up @@ -594,6 +602,7 @@ async def work():


@pytest.mark.slow
@pytest.mark.asyncio
@pytest.mark.parametrize(
"fmt, set_types",
[(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],
Expand Down

0 comments on commit 597c1cf

Please sign in to comment.