Skip to content

Commit

Permalink
fix concurrency issues of parallel transactions (encode#327)
Browse files Browse the repository at this point in the history
  • Loading branch information
goteguru committed Apr 27, 2021
1 parent 22c1631 commit 0af516c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 5 deletions.
24 changes: 20 additions & 4 deletions databases/core.py
Expand Up @@ -15,8 +15,10 @@

if sys.version_info >= (3, 7): # pragma: no cover
from contextvars import ContextVar
import contextvars as contextvars
else: # pragma: no cover
from aiocontextvars import ContextVar
import aiocontextvars as contextvars

try: # pragma: no cover
import click
Expand Down Expand Up @@ -173,21 +175,34 @@ async def iterate(
async for record in connection.iterate(query, values):
yield record

def _new_connection(self) -> "Connection":
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection

def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

try:
return self._connection_context.get()
except LookupError:
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection
return self._new_connection()

def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
try:
connection = self._connection_context.get()
if not connection._transaction_stack:
newcontext = contextvars.copy_context()
get_conn = lambda: newcontext.run(self._new_connection)
else:
get_conn = self.connection
except LookupError:
get_conn = self.connection

return Transaction(get_conn, force_rollback=force_rollback, **kwargs)

@contextlib.contextmanager
def force_rollback(self) -> typing.Iterator[None]:
Expand Down Expand Up @@ -357,6 +372,7 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
async def start(self) -> "Transaction":
self._connection = self._connection_callable()
self._transaction = self._connection._connection.transaction()
logger.warning(self._connection)

async with self._connection._transaction_lock:
is_root = not self._connection._transaction_stack
Expand Down
19 changes: 18 additions & 1 deletion tests/test_databases.py
Expand Up @@ -800,7 +800,7 @@ async def test_queries_with_expose_backend_connection(database_url):
"""
async with Database(database_url) as database:
async with database.connection() as connection:
async with database.transaction(force_rollback=True):
async with connection.transaction(force_rollback=True):
# Get the raw connection
raw_connection = connection.raw_connection

Expand Down Expand Up @@ -996,3 +996,20 @@ async def test_column_names(database_url, select_query):
assert sorted(results[0].keys()) == ["completed", "id", "text"]
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True

@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_parallel_transactions(database_url):
"""
Test parallel transaction execution.
"""
async def test_task(db):
async with db.transaction():
await db.fetch_one("SELECT 1")

async with Database(database_url) as database:
await database.fetch_one("SELECT 1")

tasks = [test_task(database) for i in range(4)]
await asyncio.gather(*tasks)

0 comments on commit 0af516c

Please sign in to comment.