Rewrite AsyncCopy with anyio
We use memory object streams instead of a queue. A task group is created
at AsyncCopy.__init__() and entered/exited when the context manager is used.

Since we drop the _worker attribute, test_worker_life() is removed.
dlax committed Nov 17, 2021
1 parent 4b80d75 commit 3a2c831
Showing 2 changed files with 19 additions and 40 deletions.
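
For readers unfamiliar with anyio, the pattern adopted here is a small producer/consumer: a memory object stream buffers the data chunks, a worker task spawned in the task group drains it, and a falsy sentinel tells the worker to stop, just as AsyncCopy.worker() does in the diff below. A minimal, self-contained sketch of that pattern follows; the names consume, main and QUEUE_SIZE are illustrative, not psycopg's.

import anyio

QUEUE_SIZE = 1024  # stand-in for AsyncCopy.QUEUE_SIZE

async def consume(receive_stream) -> None:
    # Mirrors AsyncCopy.worker(): drain the stream until the sentinel arrives.
    async with receive_stream:
        async for chunk in receive_stream:
            if not chunk:
                break
            print(f"flushing {len(chunk)} bytes")  # stand-in for copy_to()

async def main() -> None:
    send_stream, receive_stream = anyio.create_memory_object_stream(
        max_buffer_size=QUEUE_SIZE
    )
    async with anyio.create_task_group() as tg:
        tg.start_soon(consume, receive_stream)
        async with send_stream:
            for chunk in (b"1\tfoo\n", b"2\tbar\n"):
                await send_stream.send(chunk)
            await send_stream.send(None)  # sentinel, as in _write_end()

anyio.run(main)
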
39 changes: 19 additions & 20 deletions psycopg/psycopg/copy.py
@@ -7,19 +7,19 @@
import re
import queue
import struct
import asyncio
import threading
from abc import ABC, abstractmethod
from types import TracebackType
from typing import TYPE_CHECKING, AsyncIterator, Iterator, Generic, Union
from typing import Any, Dict, List, Match, Optional, Sequence, Type, Tuple

import anyio

from . import pq
from . import errors as e
from .pq import ExecStatus
from .abc import ConnectionType, PQGen, Transformer
from .adapt import PyFormat
from ._compat import create_task
from ._cmodule import _psycopg
from ._encodings import pgconn_encoding
from .generators import copy_from, copy_to, copy_end
@@ -300,13 +300,17 @@ class AsyncCopy(BaseCopy["AsyncConnection[Any]"]):

    def __init__(self, cursor: "AsyncCursor[Any]"):
        super().__init__(cursor)
        self._queue: asyncio.Queue[Optional[bytes]] = asyncio.Queue(
            maxsize=self.QUEUE_SIZE
        )
        self._worker: Optional[asyncio.Future[None]] = None
        self._task_group = anyio.create_task_group()
        (
            self._send_stream,
            self._receive_stream,
        ) = anyio.create_memory_object_stream(max_buffer_size=self.QUEUE_SIZE)

    async def __aenter__(self) -> "AsyncCopy":
        self._enter()
        await self._task_group.__aenter__()
        self._task_group.start_soon(self.worker)
        await self._send_stream.__aenter__()
        return self

    async def __aexit__(
@@ -316,6 +320,8 @@ async def __aexit__(
        exc_tb: Optional[TracebackType],
    ) -> None:
        await self.finish(exc_val)
        await self._send_stream.__aexit__(exc_type, exc_val, exc_tb)
        await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    async def __aiter__(self) -> AsyncIterator[memoryview]:
        while True:
@@ -362,29 +368,22 @@ async def worker(self) -> None:
        The function is designed to be run in a separate thread.
        """
        while 1:
            data = await self._queue.get()
            if not data:
                break
            await self.connection.wait(copy_to(self._pgconn, data))
        async with self._receive_stream:
            async for data in self._receive_stream:
                if not data:
                    break
                await self.connection.wait(copy_to(self._pgconn, data))

    async def _write(self, data: bytes) -> None:
        if not data:
            return

        if not self._worker:
            self._worker = create_task(self.worker())

        await self._queue.put(data)
        await self._send_stream.send(data)

    async def _write_end(self) -> None:
        data = self.formatter.end()
        await self._write(data)
        await self._queue.put(None)

        if self._worker:
            await asyncio.gather(self._worker)
            self._worker = None  # break reference loops if any
        await self._send_stream.send(None)


class Formatter(ABC):
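
One detail worth noting in the diff above: AsyncCopy does not use the task group with a plain async with statement; instead it delegates its own __aenter__/__aexit__ to the task group's, so the worker task lives exactly as long as the copy block. A hedged sketch of that delegation technique, with made-up names (Host, _background) rather than psycopg's:

import anyio

class Host:
    """Own an anyio task group whose lifetime matches "async with Host()"."""

    def __init__(self) -> None:
        self._task_group = anyio.create_task_group()

    async def __aenter__(self) -> "Host":
        # Enter the task group manually, then spawn the background task.
        await self._task_group.__aenter__()
        self._task_group.start_soon(self._background)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Exiting the task group waits for _background() to finish
        # (or cancels it if an exception is propagating out of the block).
        await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    async def _background(self) -> None:
        await anyio.sleep(0.01)  # placeholder for real work

async def main() -> None:
    async with Host():
        pass  # the background task runs while this block is open

anyio.run(main)
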
20 changes: 0 additions & 20 deletions tests/test_copy_async.py
@@ -513,26 +513,6 @@ async def test_str(aconn):
assert "[INTRANS]" in str(copy)


@pytest.mark.parametrize(
    "format, buffer",
    [(Format.TEXT, "sample_text"), (Format.BINARY, "sample_binary")],
)
async def test_worker_life(aconn, format, buffer):
    cur = aconn.cursor()
    await ensure_table(cur, sample_tabledef)
    async with cur.copy(
        f"copy copy_in from stdin (format {format.name})"
    ) as copy:
        assert not copy._worker
        await copy.write(globals()[buffer])
        assert copy._worker

    assert not copy._worker
    await cur.execute("select * from copy_in order by 1")
    data = await cur.fetchall()
    assert data == sample_records


@pytest.mark.slow
@pytest.mark.parametrize(
    "fmt, set_types",
