diff --git a/databases/core.py b/databases/core.py index 2bab6735..cc3668c4 100644 --- a/databases/core.py +++ b/databases/core.py @@ -5,7 +5,7 @@ import sys import typing from types import TracebackType -from urllib.parse import SplitResult, parse_qsl, urlsplit, unquote +from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit from sqlalchemy import text from sqlalchemy.sql import ClauseElement @@ -14,9 +14,9 @@ from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend if sys.version_info >= (3, 7): # pragma: no cover - from contextvars import ContextVar + import contextvars as contextvars else: # pragma: no cover - from aiocontextvars import ContextVar + import aiocontextvars as contextvars try: # pragma: no cover import click @@ -68,7 +68,9 @@ def __init__( self._backend = backend_cls(self.url, **self.options) # Connections are stored as task-local state. - self._connection_context = ContextVar("connection_context") # type: ContextVar + self._connection_context = contextvars.ContextVar( + "connection_context" + ) # type: contextvars.ContextVar # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. @@ -173,6 +175,11 @@ async def iterate( async for record in connection.iterate(query, values): yield record + def _new_connection(self) -> "Connection": + connection = Connection(self._backend) + self._connection_context.set(connection) + return connection + def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection @@ -180,14 +187,23 @@ def connection(self) -> "Connection": try: return self._connection_context.get() except LookupError: - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection + return self._new_connection() def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any ) -> "Transaction": - return Transaction(self.connection, force_rollback=force_rollback, **kwargs) + try: + connection = self._connection_context.get() + is_root = not connection._transaction_stack + if is_root: + newcontext = contextvars.copy_context() + get_conn = lambda: newcontext.run(self._new_connection) + else: + get_conn = self.connection + except LookupError: + get_conn = self.connection + + return Transaction(get_conn, force_rollback=force_rollback, **kwargs) @contextlib.contextmanager def force_rollback(self) -> typing.Iterator[None]: diff --git a/docs/connections_and_transactions.md b/docs/connections_and_transactions.md index dc2ff8f5..03e8cc22 100644 --- a/docs/connections_and_transactions.md +++ b/docs/connections_and_transactions.md @@ -59,12 +59,21 @@ database = Database('postgresql://localhost/example', ssl=True, min_size=5, max_ ## Transactions -Transactions are managed by async context blocks: +Transactions are managed by async context blocks. + +A transaction can be acquired from the database connection pool: ```python async with database.transaction(): ... ``` +It can also be acquired from a specific database connection: + +```python +async with database.connection() as connection: + async with connection.transaction(): + ... +``` For a lower-level transaction API: diff --git a/tests/test_database_url.py b/tests/test_database_url.py index 53e01be3..38c178aa 100644 --- a/tests/test_database_url.py +++ b/tests/test_database_url.py @@ -1,7 +1,9 @@ -from databases import DatabaseURL from urllib.parse import quote + import pytest +from databases import DatabaseURL + def test_database_url_repr(): u = DatabaseURL("postgresql://localhost/name") diff --git a/tests/test_databases.py b/tests/test_databases.py index c7317688..5aae152a 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -800,7 +800,7 @@ async def test_queries_with_expose_backend_connection(database_url): """ async with Database(database_url) as database: async with database.connection() as connection: - async with database.transaction(force_rollback=True): + async with connection.transaction(force_rollback=True): # Get the raw connection raw_connection = connection.raw_connection @@ -996,3 +996,21 @@ async def test_column_names(database_url, select_query): assert sorted(results[0].keys()) == ["completed", "id", "text"] assert results[0]["text"] == "example1" assert results[0]["completed"] == True + + +@pytest.mark.parametrize("database_url", DATABASE_URLS) +@async_adapter +async def test_parallel_transactions(database_url): + """ + Test parallel transaction execution. + """ + + async def test_task(db): + async with db.transaction(): + await db.fetch_one("SELECT 1") + + async with Database(database_url) as database: + await database.fetch_one("SELECT 1") + + tasks = [test_task(database) for i in range(4)] + await asyncio.gather(*tasks)