From eaad9729e5b7824c0c3a004db5377d90867beaf9 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Thu, 18 Nov 2021 16:08:17 +0100 Subject: [PATCH] Add AnyIOCopy class (WIP) WIP: tests fail We add a _copycls attribute on AsyncConnection base class and AnyIOConnection so that the copy() context manager on cursor selects to proper Copy class (either AsyncCopy or AnyIOCopy). In copy tests, like in async connections ones, we alias the "aconn" fixture to "any_aconn", so as to run tests with both asyncio and trio backends. --- psycopg/psycopg/connection_async.py | 6 +++ psycopg/psycopg/copy.py | 68 ++++++++++++++++++++++++++++- psycopg/psycopg/cursor_async.py | 2 +- tests/test_copy_async.py | 23 +++++++--- 4 files changed, 92 insertions(+), 7 deletions(-) diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 8f538e247..fbd0fea68 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -22,6 +22,7 @@ from .conninfo import make_conninfo, conninfo_to_dict from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify +from .copy import AsyncCopy from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -45,6 +46,8 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] + _copycls = AsyncCopy + def __init__( self, pgconn: "PGconn", @@ -341,6 +344,7 @@ def _no_set_async(self, attribute: str) -> None: except ImportError: pass else: + from .copy import AnyIOCopy class AnyIOConnection(AsyncConnection[Row]): """ @@ -350,6 +354,8 @@ class AnyIOConnection(AsyncConnection[Row]): __module__ = "psycopg" + _copycls = AnyIOCopy + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = anyio.Lock() # type: ignore[assignment] diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 189f70d75..2ba9a3523 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -13,6 +13,7 @@ from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import TypeVar from . import pq from . import errors as e @@ -291,6 +292,9 @@ def _write_end(self) -> None: self._worker = None # break the loop +_C = TypeVar("_C", bound="AsyncCopy") + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -303,7 +307,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"): ) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: _C) -> _C: self._enter() return self @@ -385,6 +389,68 @@ async def _write_end(self) -> None: self._worker = None # break reference loops if any +try: + import anyio +except ImportError: + pass +else: + + class AnyIOCopy(AsyncCopy): + """Manage an asynchronous :sql:`COPY` operation using AnyIO + asynchronous library. + """ + + __module__ = "psycopg" + + def __init__(self, cursor: "AsyncCursor[Any]"): + super(AsyncCopy, self).__init__(cursor) + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream( + max_buffer_size=self.QUEUE_SIZE + ) + + async def __aenter__(self) -> "AnyIOCopy": + await self._task_group.__aenter__() + self._task_group.start_soon(self.worker) + await self._send_stream.__aenter__() + return await super().__aenter__() + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + await self._send_stream.__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def worker(self) -> None: + """Push data to the server when available from the receiving stream. + + Terminate reading when the stream receives a None. + """ + async with self._receive_stream: + async for data in self._receive_stream: + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) + + async def _write(self, data: bytes) -> None: + if not data: + return + + await self._send_stream.send(data) + + async def _write_end(self) -> None: + data = self.formatter.end() + await self._write(data) + await self._send_stream.send(b"") + + class Formatter(ABC): """ A class which understand a copy format (text, binary). diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index f9d7109db..fcc520595 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -159,5 +159,5 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement)) - async with AsyncCopy(self) as copy: + async with self._conn._copycls(self) as copy: yield copy diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c136c0283..b95226362 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -3,6 +3,7 @@ import hashlib from io import BytesIO, StringIO from itertools import cycle +from typing import Any import pytest @@ -10,6 +11,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e +from psycopg.copy import AsyncCopy from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat @@ -21,7 +23,10 @@ from .test_copy import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio + +@pytest.fixture +def aconn(any_aconn): + return any_aconn @pytest.mark.parametrize("format", Format) @@ -518,6 +523,9 @@ async def test_str(aconn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_worker_life(aconn, format, buffer): + if aconn._copycls != AsyncCopy: + pytest.skip("only applicable for AsyncCopy") + cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy( @@ -539,13 +547,16 @@ async def test_worker_life(aconn, format, buffer): [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) -async def test_copy_to_leaks(dsn, faker, fmt, set_types, method, retries): +async def test_copy_to_leaks( + dsn, asyncconnection_class, faker, fmt, set_types, method, retries +): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + conn: psycopg.AsyncConnection[Any] + async with await asyncconnection_class.connect(dsn) as conn: async with conn.cursor(binary=fmt) as cur: await cur.execute(faker.drop_stmt) await cur.execute(faker.create_stmt) @@ -598,13 +609,15 @@ async def work(): "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) -async def test_copy_from_leaks(dsn, faker, fmt, set_types, retries): +async def test_copy_from_leaks( + dsn, asyncconnection_class, faker, fmt, set_types, retries +): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + async with await asyncconnection_class.connect(dsn) as conn: async with conn.cursor(binary=fmt) as cur: await cur.execute(faker.drop_stmt) await cur.execute(faker.create_stmt)