From 388891653632f8521b6e9c9f536608ed327e4e98 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Thu, 18 Nov 2021 16:08:17 +0100 Subject: [PATCH] feat: add AnyIOCopy class This class overrides some methods of AsyncCopy. It uses a task group to hold the worker, and communication is done through memory object streams. We also need an event to wait for the worker to complete. We add a _copycls attribute on the AsyncConnection base class and on AnyIOConnection so that the copy() context manager on the cursor selects the proper Copy class (either AsyncCopy or AnyIOCopy). In the copy tests, as in the async connection tests, we alias the "aconn" fixture to "any_aconn", so as to run tests with both the asyncio and trio backends. In test_context_active_rollback_no_clobber[trio], we need to properly close the rows() async generator in order to avoid a (trio) ResourceWarning. Tests test_copy_from_leaks and test_copy_to_leaks are only valid for the asyncio backend because tenacity's AsyncRetrying is implemented using asyncio. We ignore an invalid error (raised by a deprecation warning during tests) about usage of the 'loop' parameter in the asyncio API, which is due to a Python bug as mentioned in the comment. 
--- psycopg/psycopg/connection_async.py | 7 +++ psycopg/psycopg/copy.py | 80 ++++++++++++++++++++++++++++- psycopg/psycopg/cursor_async.py | 2 +- pyproject.toml | 2 + tests/test_copy_async.py | 11 +++- 5 files changed, 99 insertions(+), 3 deletions(-) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index b9c5a59a0..498fb45fb 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -24,6 +24,7 @@ from ._pipeline import AsyncPipeline from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify +from .copy import AsyncCopy from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -48,6 +49,8 @@ class AsyncConnection(BaseConnection[Row]): row_factory: AsyncRowFactory[Row] _pipeline: Optional[AsyncPipeline] + _copycls = AsyncCopy + def __init__( self, pgconn: "PGconn", @@ -417,6 +420,8 @@ async def tpc_recover(self) -> List[Xid]: else: import sniffio + from .copy import AnyIOCopy + class AnyIOConnection(AsyncConnection[Row]): """ Asynchronous wrapper for a connection to the database using AnyIO @@ -425,6 +430,8 @@ class AnyIOConnection(AsyncConnection[Row]): __module__ = "psycopg" + _copycls = AnyIOCopy + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = anyio.Lock() # type: ignore[assignment] diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index abd7addae..ebed57e60 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -13,6 +13,7 @@ from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import TypeVar from . import pq from . 
import errors as e @@ -338,6 +339,9 @@ def _write_end(self) -> None: raise self._worker_error +_C = TypeVar("_C", bound="AsyncCopy") + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -348,7 +352,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"): self._queue: asyncio.Queue[bytes] = asyncio.Queue(maxsize=self.QUEUE_SIZE) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: _C) -> _C: self._enter() return self @@ -437,6 +441,80 @@ async def _write_end(self) -> None: self._worker = None # break reference loops if any +try: + import anyio + import anyio.abc +except ImportError: + pass +else: + + class AnyIOCopy(AsyncCopy): + """Manage an asynchronous :sql:`COPY` operation using AnyIO + asynchronous library. + """ + + __module__ = "psycopg" + + def __init__(self, cursor: "AsyncCursor[Any]"): + super(AsyncCopy, self).__init__(cursor) + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream( + max_buffer_size=self.QUEUE_SIZE, item_type=bytes + ) + self._worker_done: Optional[anyio.Event] = None + + async def __aenter__(self) -> "AnyIOCopy": + await self._task_group.__aenter__() + self = await super().__aenter__() + self._worker_done = None + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def worker( + self, + *, + task_status: anyio.abc.TaskStatus = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Push data to the server when available from the receiving stream.""" + done = anyio.Event() + async with self._receive_stream: + task_status.started(done) + async for data in self._receive_stream: + await 
self.connection.wait(copy_to(self._pgconn, data)) + done.set() + + async def _write(self, data: bytes) -> None: + if not data: + return + + if not self._worker_done: + ev = await self._task_group.start(self.worker) + assert isinstance(ev, anyio.Event) + self._worker_done = ev + + await self._send_stream.send(data) + + async def _write_end(self) -> None: + data = self.formatter.end() + await self._write(data) + + if self._worker_done: + self._send_stream.close() + await self._worker_done.wait() + self._worker_done = None + + class Formatter(ABC): """ A class which understand a copy format (text, binary). diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 8c4ac4d9e..7b3c144b1 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -184,7 +184,7 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement)) - async with AsyncCopy(self) as copy: + async with self._conn._copycls(self) as copy: yield copy async def _fetch_pipeline(self) -> None: diff --git a/pyproject.toml b/pyproject.toml index 72bb520dc..af8e29028 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,8 @@ build-backend = "setuptools.build_meta" asyncio_mode = "auto" filterwarnings = [ "error", + # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097) + "ignore:The loop argument is deprecated since Python 3\\.8, and scheduled for removal in Python 3\\.10\\.:DeprecationWarning:asyncio" ] testpaths=[ "tests", diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 0c0683da8..cb74cf911 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -11,6 +11,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e +from psycopg.copy import AsyncCopy from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat @@ -22,7 +23,10 @@ from .test_copy 
import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio + +@pytest.fixture +def aconn(any_aconn): + return any_aconn @pytest.mark.parametrize("format", Format) @@ -578,6 +582,9 @@ async def test_description(aconn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_worker_life(aconn, format, buffer): + if aconn._copycls != AsyncCopy: + pytest.skip("only applicable for AsyncCopy") + cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy(f"copy copy_in from stdin (format {format.name})") as copy: @@ -605,6 +612,7 @@ def copy_to_broken(pgconn, buffer): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], @@ -661,6 +669,7 @@ async def work(): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],