diff --git a/psycopg/psycopg/connection_async.py b/psycopg/psycopg/connection_async.py index 86be6b00d..f2e6ccd0e 100644 --- a/psycopg/psycopg/connection_async.py +++ b/psycopg/psycopg/connection_async.py @@ -23,6 +23,7 @@ from .conninfo import make_conninfo, conninfo_to_dict from ._encodings import pgconn_encoding from .connection import BaseConnection, CursorRow, Notify +from .copy import AsyncCopy from .generators import notifies from .transaction import AsyncTransaction from .cursor_async import AsyncCursor @@ -46,6 +47,8 @@ class AsyncConnection(BaseConnection[Row]): server_cursor_factory: Type[AsyncServerCursor[Row]] row_factory: AsyncRowFactory[Row] + _copycls = AsyncCopy + def __init__( self, pgconn: "PGconn", @@ -384,6 +387,8 @@ async def tpc_recover(self) -> List[Xid]: else: import sniffio + from .copy import AnyIOCopy + class AnyIOConnection(AsyncConnection[Row]): """ Asynchronous wrapper for a connection to the database using AnyIO @@ -392,6 +397,8 @@ class AnyIOConnection(AsyncConnection[Row]): __module__ = "psycopg" + _copycls = AnyIOCopy + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.lock = anyio.Lock() # type: ignore[assignment] diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 189f70d75..21f7a9a36 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -13,6 +13,7 @@ from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +from typing import TypeVar from . import pq from . import errors as e @@ -291,6 +292,9 @@ def _write_end(self) -> None: self._worker = None # break the loop +_C = TypeVar("_C", bound="AsyncCopy") + + class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): """Manage an asynchronous :sql:`COPY` operation.""" @@ -303,7 +307,7 @@ def __init__(self, cursor: "AsyncCursor[Any]"): ) self._worker: Optional[asyncio.Future[None]] = None - async def __aenter__(self) -> "AsyncCopy": + async def __aenter__(self: _C) -> _C: self._enter() return self @@ -385,6 +389,80 @@ async def _write_end(self) -> None: self._worker = None # break reference loops if any +try: + import anyio + import anyio.abc +except ImportError: + pass +else: + + class AnyIOCopy(AsyncCopy): + """Manage an asynchronous :sql:`COPY` operation using AnyIO + asynchronous library. + """ + + __module__ = "psycopg" + + def __init__(self, cursor: "AsyncCursor[Any]"): + super(AsyncCopy, self).__init__(cursor) + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream( + max_buffer_size=self.QUEUE_SIZE, item_type=bytes + ) + self._worker_done: Optional[anyio.Event] = None + + async def __aenter__(self) -> "AnyIOCopy": + await self._task_group.__aenter__() + self = await super().__aenter__() + self._worker_done = None + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + await super().__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) + + async def worker( + self, + *, + task_status: anyio.abc.TaskStatus = anyio.TASK_STATUS_IGNORED, + ) -> None: + """Push data to the server when available from the receiving stream.""" + done = anyio.Event() + async with self._receive_stream: + task_status.started(done) + async for data in self._receive_stream: + await self.connection.wait(copy_to(self._pgconn, data)) + done.set() + + async def _write(self, data: bytes) -> None: + if not data: + return + + if not self._worker_done: + ev = await self._task_group.start(self.worker) + assert isinstance(ev, anyio.Event) + self._worker_done = ev + + await self._send_stream.send(data) + + async def _write_end(self) -> None: + data = self.formatter.end() + await self._write(data) + + if self._worker_done: + self._send_stream.close() + await self._worker_done.wait() + self._worker_done = None + + class Formatter(ABC): """ A class which understand a copy format (text, binary). diff --git a/psycopg/psycopg/cursor_async.py b/psycopg/psycopg/cursor_async.py index f9d7109db..fcc520595 100644 --- a/psycopg/psycopg/cursor_async.py +++ b/psycopg/psycopg/cursor_async.py @@ -159,5 +159,5 @@ async def copy(self, statement: Query) -> AsyncIterator[AsyncCopy]: async with self._conn.lock: await self._conn.wait(self._start_copy_gen(statement)) - async with AsyncCopy(self) as copy: + async with self._conn._copycls(self) as copy: yield copy diff --git a/tests/test_connection_async.py b/tests/test_connection_async.py index 9fad2b3ba..f099f97c1 100644 --- a/tests/test_connection_async.py +++ b/tests/test_connection_async.py @@ -205,9 +205,6 @@ async def test_context_inerror_rollback_no_clobber( async def test_context_active_rollback_no_clobber( dsn, asyncconnection_class, caplog ): - if asyncconnection_class != AsyncConnection: - pytest.xfail("anyio connection not implemented") - caplog.set_level(logging.WARNING, logger="psycopg") with pytest.raises(ZeroDivisionError): @@ -216,8 +213,12 @@ async def test_context_active_rollback_no_clobber( async with cur.copy( "copy (select generate_series(1, 10)) to stdout" ) as copy: - async for row in copy.rows(): - 1 / 0 + rows = copy.rows() + try: + async for row in rows: + 1 / 0 + finally: + await rows.aclose() assert len(caplog.records) == 1 rec = caplog.records[0] diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c136c0283..5f4477c3b 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -10,6 +10,7 @@ from psycopg import pq from psycopg import sql from psycopg import errors as e +from psycopg.copy import AsyncCopy from psycopg.pq import Format from psycopg.types import TypeInfo from psycopg.adapt import PyFormat @@ -21,7 +22,10 @@ from .test_copy import eur, sample_values, sample_records, sample_tabledef from .test_copy import py_to_raw -pytestmark = pytest.mark.asyncio + +@pytest.fixture +def aconn(any_aconn): + return any_aconn @pytest.mark.parametrize("format", Format) @@ -518,6 +522,9 @@ async def test_str(aconn): [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], ) async def test_worker_life(aconn, format, buffer): + if aconn._copycls != AsyncCopy: + pytest.skip("only applicable for AsyncCopy") + cur = aconn.cursor() await ensure_table(cur, sample_tabledef) async with cur.copy( @@ -534,6 +541,7 @@ async def test_worker_life(aconn, format, buffer): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)], @@ -594,6 +602,7 @@ async def work(): @pytest.mark.slow +@pytest.mark.asyncio @pytest.mark.parametrize( "fmt, set_types", [(Format.TEXT, True), (Format.TEXT, False), (Format.BINARY, True)],