Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: incorrect concurrent usage of connection and transaction #546

Merged
merged 20 commits into from
Jul 25, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
90c33da
fix: incorrect concurrent usage of connection and transaction
zevisert Apr 7, 2023
bea6629
refactor: rename contextvar class attributes, add some explaination c…
zevisert Apr 10, 2023
c9e3464
fix: contextvar.get takes no keyword arguments
zevisert Apr 10, 2023
f3078aa
test: add concurrent task tests
zevisert Apr 11, 2023
75969d3
feat: use ContextVar[dict] to track connections and transactions per …
zevisert Apr 11, 2023
4cd7451
test: check multiple databases in the same task use independant conne…
zevisert Apr 11, 2023
e4c95a7
chore: changes for linting and typechecking
zevisert Apr 11, 2023
a38e135
chore: use typing.Tuple for lower python version compatibility
zevisert Apr 11, 2023
460f72e
docs: update comment on _connection_contextmap
zevisert Apr 11, 2023
2d4554d
Update `Connection` and `Transaction` to be robust to concurrent use
zanieb Apr 16, 2023
16403c3
Merge remote-tracking branch 'madkinsz/example/instance-safe' into fi…
zevisert Apr 17, 2023
8370299
chore: remove optional annotation on asyncio.Task
zevisert Apr 18, 2023
1d4896f
test: add new tests for upcoming contextvar inheritance/isolation and…
zevisert May 24, 2023
02a9acb
feat: reimplement concurrency system with contextvar and weakmap
zevisert May 24, 2023
0f93807
chore: apply corrections from linters
zevisert May 24, 2023
f091482
fix: quote WeakKeyDictionary typing for python<=3.7
zevisert May 24, 2023
6fb55a5
docs: add examples for async transaction context and nested transactions
zevisert May 25, 2023
6de4f60
Merge remote-tracking branch 'upstream/master' into fix-transaction-c…
zevisert May 25, 2023
b94f097
fix: remove connection inheritance, add more tests, update docs
zevisert May 26, 2023
0a9e9e5
Merge branch 'master' into fix-transaction-contextvar
zanieb Jul 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
79 changes: 48 additions & 31 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,14 @@
import functools
import logging
import typing
from contextvars import ContextVar
from types import TracebackType
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit

from sqlalchemy import text
from sqlalchemy.sql import ClauseElement

from databases.importer import import_from_string
from databases.interfaces import DatabaseBackend, Record
from databases.interfaces import DatabaseBackend, Record, TransactionBackend

try: # pragma: no cover
import click
Expand Down Expand Up @@ -63,8 +62,8 @@ def __init__(
assert issubclass(backend_cls, DatabaseBackend)
self._backend = backend_cls(self.url, **self.options)

# Connections are stored as task-local state.
self._connection_context: ContextVar = ContextVar("connection_context")
# Connections are stored per asyncio task
self._connections: typing.Dict[asyncio.Task, Connection] = {}
zevisert marked this conversation as resolved.
Show resolved Hide resolved

# When `force_rollback=True` is used, we use a single global
# connection, within a transaction that always rolls back.
Expand Down Expand Up @@ -113,7 +112,10 @@ async def disconnect(self) -> None:
self._global_transaction = None
self._global_connection = None
else:
self._connection_context = ContextVar("connection_context")
current_task = asyncio.current_task()
assert current_task is not None, "No currently running task"
if current_task in self._connections:
del self._connections[current_task]
zevisert marked this conversation as resolved.
Show resolved Hide resolved
zevisert marked this conversation as resolved.
Show resolved Hide resolved

await self._backend.disconnect()
logger.info(
Expand Down Expand Up @@ -187,12 +189,12 @@ def connection(self) -> "Connection":
if self._global_connection is not None:
return self._global_connection

try:
return self._connection_context.get()
except LookupError:
connection = Connection(self._backend)
self._connection_context.set(connection)
return connection
current_task = asyncio.current_task()
assert current_task is not None, "No currently running task"
if current_task not in self._connections:
self._connections[current_task] = Connection(self._backend)

return self._connections[current_task]

def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
Expand Down Expand Up @@ -345,6 +347,11 @@ def __init__(
self._force_rollback = force_rollback
self._extra_options = kwargs

# Transactions are stored per asyncio task
self._transactions: typing.Dict[
typing.Optional[asyncio.Task], TransactionBackend
] = {}

async def __aenter__(self) -> "Transaction":
"""
Called when entering `async with database.transaction()`
Expand Down Expand Up @@ -385,31 +392,41 @@ async def wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any:
return wrapper # type: ignore

async def start(self) -> "Transaction":
self._connection = self._connection_callable()
self._transaction = self._connection._connection.transaction()

async with self._connection._transaction_lock:
is_root = not self._connection._transaction_stack
await self._connection.__aenter__()
await self._transaction.start(
is_root=is_root, extra_options=self._extra_options
)
self._connection._transaction_stack.append(self)
connection = self._connection_callable()
current_task = asyncio.current_task()
assert current_task is not None, "No currently running task"
transaction = connection._connection.transaction()
self._transactions[current_task] = transaction
async with connection._transaction_lock:
is_root = not connection._transaction_stack
await connection.__aenter__()
await transaction.start(is_root=is_root, extra_options=self._extra_options)
connection._transaction_stack.append(self)
return self

async def commit(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
await self._transaction.commit()
await self._connection.__aexit__()
connection = self._connection_callable()
current_task = asyncio.current_task()
transaction = self._transactions.get(current_task, None)
assert transaction is not None, "Transaction not found in current task"
async with connection._transaction_lock:
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await transaction.commit()
await connection.__aexit__()
del self._transactions[current_task]

async def rollback(self) -> None:
async with self._connection._transaction_lock:
assert self._connection._transaction_stack[-1] is self
self._connection._transaction_stack.pop()
await self._transaction.rollback()
await self._connection.__aexit__()
connection = self._connection_callable()
current_task = asyncio.current_task()
transaction = self._transactions.get(current_task, None)
assert transaction is not None, "Transaction not found in current task"
async with connection._transaction_lock:
assert connection._transaction_stack[-1] is self
connection._transaction_stack.pop()
await transaction.rollback()
await connection.__aexit__()
del self._transactions[current_task]


class _EmptyNetloc(str):
Expand Down
84 changes: 71 additions & 13 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import datetime
import decimal
import functools
import itertools
import os
import re
from unittest.mock import MagicMock, patch
Expand Down Expand Up @@ -789,15 +790,16 @@ async def test_connect_and_disconnect(database_url):

@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_connection_context(database_url):
"""
Test connection contexts are task-local.
"""
async def test_connection_context_same_task(database_url):
async with Database(database_url) as database:
async with database.connection() as connection_1:
async with database.connection() as connection_2:
assert connection_1 is connection_2


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_connection_context_multiple_tasks(database_url):
async with Database(database_url) as database:
connection_1 = None
connection_2 = None
Expand All @@ -817,9 +819,8 @@ async def get_connection_2():
connection_2 = connection
await test_complete.wait()

loop = asyncio.get_event_loop()
task_1 = loop.create_task(get_connection_1())
task_2 = loop.create_task(get_connection_2())
task_1 = asyncio.create_task(get_connection_1())
task_2 = asyncio.create_task(get_connection_2())
while connection_1 is None or connection_2 is None:
await asyncio.sleep(0.000001)
assert connection_1 is not connection_2
Expand All @@ -828,6 +829,20 @@ async def get_connection_2():
await task_2


@pytest.mark.parametrize(
"database_url1,database_url2",
(
pytest.param(db1, db2, id=f"{db1} | {db2}")
for (db1, db2) in itertools.combinations(DATABASE_URLS, 2)
),
)
@async_adapter
async def test_connection_context_multiple_databases(database_url1, database_url2):
async with Database(database_url1) as database1:
async with Database(database_url2) as database2:
assert database1.connection() is not database2.connection()


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_connection_context_with_raw_connection(database_url):
Expand Down Expand Up @@ -961,16 +976,59 @@ async def test_database_url_interface(database_url):
@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_access_on_single_connection(database_url):
database_url = DatabaseURL(database_url)
if database_url.dialect != "postgresql":
pytest.skip("Test requires `pg_sleep()`")

async with Database(database_url, force_rollback=True) as database:

async def db_lookup():
await database.fetch_one("SELECT pg_sleep(1)")
await database.fetch_one("SELECT 1 AS value")

await asyncio.gather(
db_lookup(),
db_lookup(),
)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_transactions_on_single_connection(database_url: str):
async with Database(database_url) as database:

@database.transaction()
async def db_lookup():
await database.fetch_one(query="SELECT 1 AS value")

await asyncio.gather(
db_lookup(),
db_lookup(),
)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_tasks_on_single_connection(database_url: str):
async with Database(database_url) as database:

async def db_lookup():
await database.fetch_one(query="SELECT 1 AS value")

await asyncio.gather(
asyncio.create_task(db_lookup()),
asyncio.create_task(db_lookup()),
)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_concurrent_task_transactions_on_single_connection(database_url: str):
async with Database(database_url) as database:

@database.transaction()
async def db_lookup():
await database.fetch_one(query="SELECT 1 AS value")

await asyncio.gather(db_lookup(), db_lookup())
await asyncio.gather(
asyncio.create_task(db_lookup()),
asyncio.create_task(db_lookup()),
)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down