diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 8f538e247..fbd0fea68 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -22,6 +22,7 @@ from .conninfo import make_conninfo, conninfo_to_dict from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify +from .copy import AsyncCopy from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -45,6 +46,8 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] + _copycls = AsyncCopy + def __init__( self, pgconn: "PGconn", @@ -341,6 +344,7 @@ def _no_set_async(self, attribute: str) -> None: except ImportError: pass else: + from .copy import AnyIOCopy class AnyIOConnection(AsyncConnection[Row]): """ @@ -350,6 +354,8 @@ class AnyIOConnection(AsyncConnection[Row]): __module__ = "psycopg" + _copycls = AnyIOCopy + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = anyio.Lock() # type: ignore[assignment] diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 189f70d75..1c32009a0 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -13,6 +13,7 @@ from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import TypeVar from . import pq from . import errors as e @@ -291,6 +292,9 @@ def _write_end(self) -> None: self._worker = None # break the loop +_C = TypeVar("_C", bound="AsyncCopy") + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -303,7 +307,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"): ) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: _C) -> _C: self._enter() return self @@ -385,6 +389,68 @@ async def _write_end(self) -> None: self._worker = None # break reference loops if any +try: + import anyio +except ImportError: + pass +else: + + class AnyIOCopy(AsyncCopy): + """Manage an asynchronous :sql:`COPY` operation using AnyIO + asynchronous library. + """ + + __module__ = "psycopg" + + def __init__(self, cursor: "AsyncCursor[Any]"): + super(AsyncCopy, self).__init__(cursor) + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream( + max_buffer_size=self.QUEUE_SIZE, item_type=bytes + ) + + async def __aenter__(self) -> "AnyIOCopy": + self = await super().__aenter__() + await self._task_group.__aenter__() + self._event = anyio.Event() + self._task_group.start_soon(self.worker) + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def worker(self) -> None: + """Push data to the server when available from the receiving stream. + + Terminate reading when the stream receives a None. + """ + async with self._receive_stream: + async for data in self._receive_stream: + await self.connection.wait(copy_to(self._pgconn, data)) + self._event.set() + + async def _write(self, data: bytes) -> None: + if not data: + return + + await self._send_stream.send(data) + + async def _write_end(self) -> None: + data = self.formatter.end() + await self._write(data) + self._send_stream.close() + await self._event.wait() + + class Formatter(ABC): """ A class which understand a copy format (text, binary). diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index f9d7109db..fcc520595 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -159,5 +159,5 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement)) - async with AsyncCopy(self) as copy: + async with self._conn._copycls(self) as copy: yield copy diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c136c0283..b95226362 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -3,6 +3,7 @@ import hashlib from io import BytesIO, StringIO from itertools import cycle +from typing import Any import pytest @@ -10,6 +11,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e +from psycopg.copy import AsyncCopy from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat @@ -21,7 +23,10 @@ from .test_copy import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio + +@pytest.fixture +def aconn(any_aconn): + return any_aconn @pytest.mark.parametrize("format", Format) @@ -518,6 +523,9 @@ async def test_str(aconn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_worker_life(aconn, format, buffer): + if aconn._copycls != AsyncCopy: + pytest.skip("only applicable for AsyncCopy") + cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy( @@ -539,13 +547,16 @@ async def test_worker_life(aconn, format, buffer): [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) @pytest.mark.parametrize("method", ["read", "iter", "row", "rows"]) -async def test_copy_to_leaks(dsn, faker, fmt, set_types, method, retries): +async def test_copy_to_leaks( + dsn, asyncconnection_class, faker, fmt, set_types, method, retries +): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + conn: psycopg.AsyncConnection[Any] + async with await asyncconnection_class.connect(dsn) as conn: async with conn.cursor(binary=fmt) as cur: await cur.execute(faker.drop_stmt) await cur.execute(faker.create_stmt) @@ -598,13 +609,15 @@ async def work(): "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], ) -async def test_copy_from_leaks(dsn, faker, fmt, set_types, retries): +async def test_copy_from_leaks( + dsn, asyncconnection_class, faker, fmt, set_types, retries +): faker.format = PyFormat.from_pq(fmt) faker.choose_schema(ncols=20) faker.make_records(20) async def work(): - async with await psycopg.AsyncConnection.connect(dsn) as conn: + async with await asyncconnection_class.connect(dsn) as conn: async with conn.cursor(binary=fmt) as cur: await cur.execute(faker.drop_stmt) await cur.execute(faker.create_stmt)