From 097b138dd361450b48a3a1f41b4759d5a9039469 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Thu, 18 Nov 2021 16:08:17 +0100 Subject: [PATCH] Add AnyIOCopy class This class overrides some methods of AsyncCopy. It uses a task group to hold the worker and communication is done through memory object streams. We also need an event to wait for the worker to complete. We add a _copycls attribute on AsyncConnection base class and AnyIOConnection so that the copy() context manager on cursor selects the proper Copy class (either AsyncCopy or AnyIOCopy). In copy tests, as in the async connection ones, we alias the "aconn" fixture to "any_aconn", so as to run tests with both asyncio and trio backends. In test_context_active_rollback_no_clobber[trio], we need to properly close the rows() async generator in order to avoid a (trio) ResourceWarning. Tests test_copy_from_leaks and test_copy_to_leaks are only valid for asyncio backend because tenacity's AsyncRetrying is implemented using asyncio. --- psycopg/psycopg/connection_async.py | 7 +++ psycopg/psycopg/copy.py | 80 ++++++++++++++++++++++++++++- psycopg/psycopg/cursor_async.py | 2 +- tests/test_connection_async.py | 11 ++-- tests/test_copy_async.py | 11 +++- 5 files changed, 103 insertions(+), 8 deletions(-) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 86be6b00d..f2e6ccd0e 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -23,6 +23,7 @@ from .conninfo import make_conninfo, conninfo_to_dict from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify +from .copy import AsyncCopy from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -46,6 +47,8 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] + _copycls = AsyncCopy + def __init__( self, pgconn: "PGconn", @@ 
-384,6 +387,8 @@ async def tpc_recover(self) -> List[Xid]: else: import sniffio + from .copy import AnyIOCopy + class AnyIOConnection(AsyncConnection[Row]): """ Asynchronous wrapper for a connection to the database using AnyIO @@ -392,6 +397,8 @@ class AnyIOConnection(AsyncConnection[Row]): __module__ = "psycopg" + _copycls = AnyIOCopy + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = anyio.Lock() # type: ignore[assignment] diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 189f70d75..21f7a9a36 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -13,6 +13,7 @@ from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import TypeVar from . import pq from . import errors as e @@ -291,6 +292,9 @@ def _write_end(self) -> None: self._worker = None # break the loop +_C = TypeVar("_C", bound="AsyncCopy") + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -303,7 +307,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"): ) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: _C) -> _C: self._enter() return self @@ -385,6 +389,80 @@ async def _write_end(self) -> None: self._worker = None # break reference loops if any +try: + import anyio + import anyio.abc +except ImportError: + pass +else: + + class AnyIOCopy(AsyncCopy): + """Manage an asynchronous :sql:`COPY` operation using AnyIO + asynchronous library. 
+ """ + + __module__ = "psycopg" + + def __init__(self, cursor: "AsyncCursor[Any]"): + super(AsyncCopy, self).__init__(cursor) + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream( + max_buffer_size=self.QUEUE_SIZE, item_type=bytes + ) + self._worker_done: Optional[anyio.Event] = None + + async def __aenter__(self) -> "AnyIOCopy": + await self._task_group.__aenter__() + self = await super().__aenter__() + self._worker_done = None + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def worker( + self, + *, + task_status: anyio.abc.TaskStatus = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Push data to the server when available from the receiving stream.""" + done = anyio.Event() + async with self._receive_stream: + task_status.started(done) + async for data in self._receive_stream: + await self.connection.wait(copy_to(self._pgconn, data)) + done.set() + + async def _write(self, data: bytes) -> None: + if not data: + return + + if not self._worker_done: + ev = await self._task_group.start(self.worker) + assert isinstance(ev, anyio.Event) + self._worker_done = ev + + await self._send_stream.send(data) + + async def _write_end(self) -> None: + data = self.formatter.end() + await self._write(data) + + if self._worker_done: + self._send_stream.close() + await self._worker_done.wait() + self._worker_done = None + + class Formatter(ABC): """ A class which understand a copy format (text, binary). 
diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index f9d7109db..fcc520595 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -159,5 +159,5 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement)) - async with AsyncCopy(self) as copy: + async with self._conn._copycls(self) as copy: yield copy diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 9fad2b3ba..f099f97c1 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -205,9 +205,6 @@ async def test_context_inerror_rollback_no_clobber( async def test_context_active_rollback_no_clobber( dsn, asyncconnection_class, caplog ): - if asyncconnection_class != AsyncConnection: - pytest.xfail("anyio connection not implemented") - caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): @@ -216,8 +213,12 @@ async def test_context_active_rollback_no_clobber( async with cur.copy( "copy (select generate_series(1, 10)) to stdout" ) as copy: - async for row in copy.rows(): - 1 / 0 + rows = copy.rows() + try: + async for row in rows: + 1 / 0 + finally: + await rows.aclose() assert len(caplog.records) == 1 rec = caplog.records[0] diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c136c0283..5f4477c3b 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -10,6 +10,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e +from psycopg.copy import AsyncCopy from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat @@ -21,7 +22,10 @@ from .test_copy import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio + +@pytest.fixture +def aconn(any_aconn): + return any_aconn @pytest.mark.parametrize("format", Format) @@ 
-518,6 +522,9 @@ async def test_str(aconn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_worker_life(aconn, format, buffer): + if aconn._copycls != AsyncCopy: + pytest.skip("only applicable for AsyncCopy") + cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy( @@ -534,6 +541,7 @@ async def test_worker_life(aconn, format, buffer): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], @@ -594,6 +602,7 @@ async def work(): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],