From 0af516c55595a20f4f1590e86b9f6cd5cb589111 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?M=C3=A9sz=C3=A1ros=20Gergely?= Date: Tue, 27 Apr 2021 20:16:56 +0200 Subject: [PATCH] fix concurrency issues of parallel transactions (#327) --- databases/core.py | 24 ++++++++++++++++++++---- tests/test_databases.py | 19 ++++++++++++++++++- 2 files changed, 38 insertions(+), 5 deletions(-) diff --git a/databases/core.py b/databases/core.py index 2bab6735..85d4f98b 100644 --- a/databases/core.py +++ b/databases/core.py @@ -15,8 +15,10 @@ if sys.version_info >= (3, 7): # pragma: no cover from contextvars import ContextVar + import contextvars as contextvars else: # pragma: no cover from aiocontextvars import ContextVar + import aiocontextvars as contextvars try: # pragma: no cover import click @@ -173,6 +175,11 @@ async def iterate( async for record in connection.iterate(query, values): yield record + def _new_connection(self) -> "Connection": + connection = Connection(self._backend) + self._connection_context.set(connection) + return connection + def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection @@ -180,14 +187,22 @@ def connection(self) -> "Connection": try: return self._connection_context.get() except LookupError: - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + return self._new_connection() def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any ) -> "Transaction": - return Transaction(self.connection, force_rollback=force_rollback, **kwargs) + try: + connection = self._connection_context.get() + if not connection._transaction_stack: + newcontext = contextvars.copy_context() + get_conn = lambda: newcontext.run(self._new_connection) + else: + get_conn = self.connection + except LookupError: + get_conn = self.connection + + return Transaction(get_conn, force_rollback=force_rollback, **kwargs) @contextlib.contextmanager def force_rollback(self) -> typing.Iterator[None]: @@ -357,6 +372,7 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: async def start(self) -> "Transaction": self._connection = self._connection_callable() self._transaction = self._connection._connection.transaction() + logger.warning(self._connection) async with self._connection._transaction_lock: is_root = not self._connection._transaction_stack diff --git a/tests/test_databases.py b/tests/test_databases.py index c7317688..fe726fe5 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -800,7 +800,7 @@ async def test_queries_with_expose_backend_connection(database_url): """ async with Database(database_url) as database: async with database.connection() as connection: - async with database.transaction(force_rollback=True): + async with connection.transaction(force_rollback=True): # Get the raw connection raw_connection = connection.raw_connection @@ -996,3 +996,20 @@ async def test_column_names(database_url, select_query): assert sorted(results[0].keys()) == ["completed", "id", "text"] assert results[0]["text"] == "example1" assert results[0]["completed"] == True + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_parallel_transactions(database_url): + """ + Test parallel transaction execution. + """ + async def test_task(db): + async with db.transaction(): + await db.fetch_one("SELECT 1") + + async with Database(database_url) as database: + await database.fetch_one("SELECT 1") + + tasks = [test_task(database) for i in range(4)] + await asyncio.gather(*tasks) +