From c12c877eed0fc4f756c6f7dcaea17e5aaaf4007c Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Thu, 18 Nov 2021 16:08:17 +0100 Subject: [PATCH] Add AnyIOCopy class This class overrides some methods of AsyncCopy. It uses a task group to hold the worker and communicate is done through memory object streams. We also need an event to wait for the worker to complete. We add a _copycls attribute on AsyncConnection base class and AnyIOConnection so that the copy() context manager on cursor selects to proper Copy class (either AsyncCopy or AnyIOCopy). In copy tests, like in async connections ones, we alias the "aconn" fixture to "any_aconn", so as to run tests with both asyncio and trio backends. In test_context_active_rollback_no_clobber[trio], we need to properly close the rows() async generator in order to avoid a (trio) ResourceWarning. Tests test_copy_from_leaks and test_copy_to_leaks are only valid for asyncio backend because tenacity's AsyncRetrying is implemented using asyncio. We ignore an invalid error (raised by a deprecation warning during test) about usage of the 'loop' parameter in asyncio API that is due to Python bug as mentioned in comment. --- psycopg/psycopg/connection_async.py | 7 +++ psycopg/psycopg/copy.py | 80 ++++++++++++++++++++++++++++- psycopg/psycopg/cursor_async.py | 2 +- pyproject.toml | 2 + tests/test_copy_async.py | 11 +++- 5 files changed, 99 insertions(+), 3 deletions(-) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 9ea7efedc..83badee01 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -23,6 +23,7 @@ from .conninfo import make_conninfo, conninfo_to_dict from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify +from .copy import AsyncCopy from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -46,6 +47,8 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] + _copycls = AsyncCopy + def __init__( self, pgconn: "PGconn", @@ -388,6 +391,8 @@ async def tpc_recover(self) -> List[Xid]: else: import sniffio + from .copy import AnyIOCopy + class AnyIOConnection(AsyncConnection[Row]): """ Asynchronous wrapper for a connection to the database using AnyIO @@ -396,6 +401,8 @@ class AnyIOConnection(AsyncConnection[Row]): __module__ = "psycopg" + _copycls = AnyIOCopy + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = anyio.Lock() # type: ignore[assignment] diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index df0f23732..afabb1ac1 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -13,6 +13,7 @@ from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import TypeVar from . import pq from . import errors as e @@ -313,6 +314,9 @@ def _write_end(self) -> None: self._worker = None # break the loop +_C = TypeVar("_C", bound="AsyncCopy") + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -325,7 +329,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"): ) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: _C) -> _C: self._enter() return self @@ -406,6 +410,80 @@ async def _write_end(self) -> None: self._worker = None # break reference loops if any +try: + import anyio + import anyio.abc +except ImportError: + pass +else: + + class AnyIOCopy(AsyncCopy): + """Manage an asynchronous :sql:`COPY` operation using AnyIO + asynchronous library. + """ + + __module__ = "psycopg" + + def __init__(self, cursor: "AsyncCursor[Any]"): + super(AsyncCopy, self).__init__(cursor) + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream( + max_buffer_size=self.QUEUE_SIZE, item_type=bytes + ) + self._worker_done: Optional[anyio.Event] = None + + async def __aenter__(self) -> "AnyIOCopy": + await self._task_group.__aenter__() + self = await super().__aenter__() + self._worker_done = None + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def worker( + self, + *, + task_status: anyio.abc.TaskStatus = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Push data to the server when available from the receiving stream.""" + done = anyio.Event() + async with self._receive_stream: + task_status.started(done) + async for data in self._receive_stream: + await self.connection.wait(copy_to(self._pgconn, data)) + done.set() + + async def _write(self, data: bytes) -> None: + if not data: + return + + if not self._worker_done: + ev = await self._task_group.start(self.worker) + assert isinstance(ev, anyio.Event) + self._worker_done = ev + + await self._send_stream.send(data) + + async def _write_end(self) -> None: + data = self.formatter.end() + await self._write(data) + + if self._worker_done: + self._send_stream.close() + await self._worker_done.wait() + self._worker_done = None + + class Formatter(ABC): """ A class which understand a copy format (text, binary). diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index 7777d7445..744c41c5c 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -168,5 +168,5 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement)) - async with AsyncCopy(self) as copy: + async with self._conn._copycls(self) as copy: yield copy diff --git a/pyproject.toml b/pyproject.toml index b05c86604..0a20408ef 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -10,6 +10,8 @@ filterwarnings = [ "error", # The zoneinfo warning ignore is only required on Python 3.6 "ignore::DeprecationWarning:backports.zoneinfo._common", + # Workaround for Python 3.9.7 (see https://bugs.python.org/issue45097) + "ignore:The loop argument is deprecated since Python 3\\.8, and scheduled for removal in Python 3\\.10\\.:DeprecationWarning:asyncio" ] testpaths=[ "tests", diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index 717da3fa6..7d9be1a7c 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -10,6 +10,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e +from psycopg.copy import AsyncCopy from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat @@ -21,7 +22,10 @@ from .test_copy import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio + +@pytest.fixture +def aconn(any_aconn): + return any_aconn @pytest.mark.parametrize("format", Format) @@ -553,6 +557,9 @@ async def test_str(aconn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_worker_life(aconn, format, buffer): + if aconn._copycls != AsyncCopy: + pytest.skip("only applicable for AsyncCopy") + cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy( @@ -569,6 +576,7 @@ async def test_worker_life(aconn, format, buffer): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], @@ -629,6 +637,7 @@ async def work(): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],