diff --git a/databases/core.py b/databases/core.py index 8f7df452..ff208da2 100644 --- a/databases/core.py +++ b/databases/core.py @@ -14,9 +14,9 @@ from databases.interfaces import DatabaseBackend, Record if sys.version_info >= (3, 7): # pragma: no cover - import contextvars as contextvars + from contextvars import ContextVar else: # pragma: no cover - import aiocontextvars as contextvars + from aiocontextvars import ContextVar try: # pragma: no cover import click @@ -69,9 +69,7 @@ def __init__( self._backend = backend_cls(self.url, **self.options) # Connections are stored as task-local state. - self._connection_context = contextvars.ContextVar( - "connection_context" - ) # type: contextvars.ContextVar + self._connection_context = ContextVar("connection_context") # type: ContextVar # When `force_rollback=True` is used, we use a single global # connection, within a transaction that always rolls back. @@ -120,7 +118,7 @@ async def disconnect(self) -> None: self._global_transaction = None self._global_connection = None else: - self._connection_context = contextvars.ContextVar("connection_context") + self._connection_context = ContextVar("connection_context") await self._backend.disconnect() logger.info( @@ -182,11 +180,6 @@ async def iterate( async for record in connection.iterate(query, values): yield record - def _new_connection(self) -> "Connection": - connection = Connection(self._backend) - self._connection_context.set(connection) - return connection - def connection(self) -> "Connection": if self._global_connection is not None: return self._global_connection @@ -194,23 +187,14 @@ def connection(self) -> "Connection": try: return self._connection_context.get() except LookupError: - return self._new_connection() + connection = Connection(self._backend) + self._connection_context.set(connection) + return connection def transaction( self, *, force_rollback: bool = False, **kwargs: typing.Any ) -> "Transaction": - try: - connection = self._connection_context.get() - is_root = not connection._transaction_stack - if is_root: - newcontext = contextvars.copy_context() - get_conn = lambda: newcontext.run(self._new_connection) - else: - get_conn = self.connection - except LookupError: - get_conn = self.connection - - return Transaction(get_conn, force_rollback=force_rollback, **kwargs) + return Transaction(self.connection, force_rollback=force_rollback, **kwargs) @contextlib.contextmanager def force_rollback(self) -> typing.Iterator[None]: diff --git a/tests/test_databases.py b/tests/test_databases.py index 7a0b84fd..e6313e94 100644 --- a/tests/test_databases.py +++ b/tests/test_databases.py @@ -1123,25 +1123,6 @@ async def test_column_names(database_url, select_query): assert results[0]["completed"] == True -@pytest.mark.parametrize("database_url", DATABASE_URLS) -@mysql_versions -@async_adapter -async def test_parallel_transactions(database_url): - """ - Test parallel transaction execution. - """ - - async def test_task(db): - async with db.transaction(): - await db.fetch_one("SELECT 1") - - async with Database(database_url) as database: - await database.fetch_one("SELECT 1") - - tasks = [test_task(database) for i in range(4)] - await asyncio.gather(*tasks) - - @pytest.mark.parametrize("database_url", DATABASE_URLS) @async_adapter async def test_posgres_interface(database_url):