Skip to content

Commit

Permalink
Merge pull request #5 from encode/master
Browse files Browse the repository at this point in the history
fix concurrency of parallel transactions (encode#328)
  • Loading branch information
sthagen committed Aug 17, 2021
2 parents 8945740 + cc3246a commit 8e56577
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 11 deletions.
32 changes: 24 additions & 8 deletions databases/core.py
Expand Up @@ -5,7 +5,7 @@
import sys
import typing
from types import TracebackType
from urllib.parse import SplitResult, parse_qsl, urlsplit, unquote
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit

from sqlalchemy import text
from sqlalchemy.sql import ClauseElement
Expand All @@ -14,9 +14,9 @@
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend

if sys.version_info >= (3, 7): # pragma: no cover
from contextvars import ContextVar
import contextvars as contextvars
else: # pragma: no cover
from aiocontextvars import ContextVar
import aiocontextvars as contextvars

try: # pragma: no cover
import click
Expand Down Expand Up @@ -68,7 +68,9 @@ def __init__(
self._backend = backend_cls(self.url, **self.options)

# Connections are stored as task-local state.
self._connection_context = ContextVar("connection_context") # type: ContextVar
self._connection_context = contextvars.ContextVar(
"connection_context"
) # type: contextvars.ContextVar

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
Expand Down Expand Up @@ -173,21 +175,35 @@ async def iterate(
async for record in connection.iterate(query, values):
yield record

def _new_connection(self) -> "Connection":
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection

def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

try:
return self._connection_context.get()
except LookupError:
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection
return self._new_connection()

def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)
try:
connection = self._connection_context.get()
is_root = not connection._transaction_stack
if is_root:
newcontext = contextvars.copy_context()
get_conn = lambda: newcontext.run(self._new_connection)
else:
get_conn = self.connection
except LookupError:
get_conn = self.connection

return Transaction(get_conn, force_rollback=force_rollback, **kwargs)

@contextlib.contextmanager
def force_rollback(self) -> typing.Iterator[None]:
Expand Down
11 changes: 10 additions & 1 deletion docs/connections_and_transactions.md
Expand Up @@ -59,12 +59,21 @@ database = Database('postgresql://localhost/example', ssl=True, min_size=5, max_

## Transactions

Transactions are managed by async context blocks:
Transactions are managed by async context blocks.

A transaction can be acquired from the database connection pool:

```python
async with database.transaction():
...
```
It can also be acquired from a specific database connection:

```python
async with database.connection() as connection:
async with connection.transaction():
...
```

For a lower-level transaction API:

Expand Down
4 changes: 3 additions & 1 deletion tests/test_database_url.py
@@ -1,7 +1,9 @@
from databases import DatabaseURL
from urllib.parse import quote

import pytest

from databases import DatabaseURL


def test_database_url_repr():
u = DatabaseURL("postgresql://localhost/name")
Expand Down
20 changes: 19 additions & 1 deletion tests/test_databases.py
Expand Up @@ -800,7 +800,7 @@ async def test_queries_with_expose_backend_connection(database_url):
"""
async with Database(database_url) as database:
async with database.connection() as connection:
async with database.transaction(force_rollback=True):
async with connection.transaction(force_rollback=True):
# Get the raw connection
raw_connection = connection.raw_connection

Expand Down Expand Up @@ -996,3 +996,21 @@ async def test_column_names(database_url, select_query):
assert sorted(results[0].keys()) == ["completed", "id", "text"]
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_parallel_transactions(database_url):
"""
Test parallel transaction execution.
"""

async def test_task(db):
async with db.transaction():
await db.fetch_one("SELECT 1")

async with Database(database_url) as database:
await database.fetch_one("SELECT 1")

tasks = [test_task(database) for i in range(4)]
await asyncio.gather(*tasks)

0 comments on commit 8e56577

Please sign in to comment.