From 3a2c831d2af20c2e94988d2ea0f6b52e7ed10508 Mon Sep 17 00:00:00 2001 From: Denis Laxalde Date: Tue, 9 Nov 2021 14:03:54 +0100 Subject: [PATCH] Rewrite AsyncCopy with anyio We use memory object streams instead of a queue. A task group is created in AsyncCopy.__init__() and entered/exited when the context manager is used. Since we drop the _worker attribute, test_worker_life() is removed. --- psycopg/psycopg/copy.py | 39 +++++++++++++++++++-------------------- tests/test_copy_async.py | 20 -------------------- 2 files changed, 19 insertions(+), 40 deletions(-) diff --git a/psycopg/psycopg/copy.py b/psycopg/psycopg/copy.py index 572b9699d..cd9d030d0 100644 --- a/psycopg/psycopg/copy.py +++ b/psycopg/psycopg/copy.py @@ -7,19 +7,19 @@ import re import queue import struct -import asyncio import threading from abc import ABC, abstractmethod from types import TracebackType from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple +import anyio + from . import pq from . 
import errors as e from .pq import ExecStatus from .abc import ConnectionType, PQGen, Transformer from .adapt import PyFormat -from ._compat import create_task from ._cmodule import _psycopg from ._encodings import pgconn_encoding from .generators import copy_from, copy_to, copy_end @@ -300,13 +300,17 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]): def __init__(self, cursor: "AsyncCursor[Any]"): super().__init__(cursor) - self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue( - maxsize=self.QUEUE_SIZE - ) - self._worker: Optional[asyncio.Future[None]] = None + self._task_group = anyio.create_task_group() + ( + self._send_stream, + self._receive_stream, + ) = anyio.create_memory_object_stream(max_buffer_size=self.QUEUE_SIZE) async def __aenter__(self) -> "AsyncCopy": self._enter() + await self._task_group.__aenter__() + self._task_group.start_soon(self.worker) + await self._send_stream.__aenter__() return self async def __aexit__( @@ -316,6 +320,8 @@ async def __aexit__( exc_tb: Optional[TracebackType], ) -> None: await self.finish(exc_val) + await self._send_stream.__aexit__(exc_type, exc_val, exc_tb) + await self._task_group.__aexit__(exc_type, exc_val, exc_tb) async def __aiter__(self) -> AsyncIterator[memoryview]: while True: @@ -362,29 +368,22 @@ async def worker(self) -> None: The function is designed to be run in a separate thread. 
""" - while 1: - data = await self._queue.get() - if not data: - break - await self.connection.wait(copy_to(self._pgconn, data)) + async with self._receive_stream: + async for data in self._receive_stream: + if not data: + break + await self.connection.wait(copy_to(self._pgconn, data)) async def _write(self, data: bytes) -> None: if not data: return - if not self._worker: - self._worker = create_task(self.worker()) - - await self._queue.put(data) + await self._send_stream.send(data) async def _write_end(self) -> None: data = self.formatter.end() await self._write(data) - await self._queue.put(None) - - if self._worker: - await asyncio.gather(self._worker) - self._worker = None # break reference loops if any + await self._send_stream.send(None) class Formatter(ABC): diff --git a/tests/test_copy_async.py b/tests/test_copy_async.py index c136c0283..8d606c2ae 100644 --- a/tests/test_copy_async.py +++ b/tests/test_copy_async.py @@ -513,26 +513,6 @@ async def test_str(aconn): assert "[INTRANS]" in str(copy) -@pytest.mark.parametrize( - "format, buffer", - [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")], -) -async def test_worker_life(aconn, format, buffer): - cur = aconn.cursor() - await ensure_table(cur, sample_tabledef) - async with cur.copy( - f"copy copy_in from stdin (format {format.name})" - ) as copy: - assert not copy._worker - await copy.write(globals()[buffer]) - assert copy._worker - - assert not copy._worker - await cur.execute("select * from copy_in order by 1") - data = await cur.fetchall() - assert data == sample_records - - @pytest.mark.slow @pytest.mark.parametrize( "fmt, set_types",