Skip to content

Commit

Permalink
fix: remove connection inheritance, add more tests, update docs
Browse files Browse the repository at this point in the history
Connections are once again stored as state on the Database instance,
keyed by the current asyncio.Task. Each task acquires it's own
connection, and a WeakKeyDictionary allows the connection to be
discarded if the owning task is garbage collected. TransactionBackends
are still stored as contextvars, and a connection must be explicitly
provided to descendant tasks if active transaction state is to be
inherited.
  • Loading branch information
zevisert committed May 26, 2023
1 parent 6de4f60 commit bc28059
Show file tree
Hide file tree
Showing 3 changed files with 203 additions and 90 deletions.
44 changes: 24 additions & 20 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,9 @@
logger = logging.getLogger("databases")


_ACTIVE_CONNECTIONS: ContextVar[
typing.Optional["weakref.WeakKeyDictionary['Database', 'Connection']"]
] = ContextVar("databases:open_connections", default=None)
_ACTIVE_TRANSACTIONS: ContextVar[
typing.Optional["weakref.WeakKeyDictionary['Transaction', 'TransactionBackend']"]
] = ContextVar("databases:open_transactions", default=None)
] = ContextVar("databases:active_transactions", default=None)


class Database:
Expand All @@ -54,6 +51,8 @@ class Database:
"sqlite": "databases.backends.sqlite:SQLiteBackend",
}

_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"

def __init__(
self,
url: typing.Union[str, "DatabaseURL"],
Expand All @@ -64,6 +63,7 @@ def __init__(
self.url = DatabaseURL(url)
self.options = options
self.is_connected = False
self._connection_map = weakref.WeakKeyDictionary()

self._force_rollback = force_rollback

Expand All @@ -78,28 +78,28 @@ def __init__(
self._global_transaction: typing.Optional[Transaction] = None

@property
def _connection(self) -> typing.Optional["Connection"]:
connections = _ACTIVE_CONNECTIONS.get()
if connections is None:
return None
def _current_task(self):
task = asyncio.current_task()
if not task:
raise RuntimeError("No currently active asyncio.Task found")
return task

return connections.get(self, None)
@property
def _connection(self) -> typing.Optional["Connection"]:
return self._connection_map.get(self._current_task)

@_connection.setter
def _connection(
self, connection: typing.Optional["Connection"]
) -> typing.Optional["Connection"]:
connections = _ACTIVE_CONNECTIONS.get()
if connections is None:
connections = weakref.WeakKeyDictionary()
_ACTIVE_CONNECTIONS.set(connections)
task = self._current_task

if connection is None:
connections.pop(self, None)
self._connection_map.pop(task, None)
else:
connections[self] = connection
self._connection_map[task] = connection

return connections.get(self, None)
return self._connection

async def connect(self) -> None:
"""
Expand All @@ -119,7 +119,7 @@ async def connect(self) -> None:
assert self._global_connection is None
assert self._global_transaction is None

self._global_connection = Connection(self._backend)
self._global_connection = Connection(self, self._backend)
self._global_transaction = self._global_connection.transaction(
force_rollback=True
)
Expand Down Expand Up @@ -218,7 +218,7 @@ def connection(self) -> "Connection":
return self._global_connection

if not self._connection:
self._connection = Connection(self._backend)
self._connection = Connection(self, self._backend)

return self._connection

Expand All @@ -243,7 +243,8 @@ def _get_backend(self) -> str:


class Connection:
def __init__(self, backend: DatabaseBackend) -> None:
def __init__(self, database: Database, backend: DatabaseBackend) -> None:
self._database = database
self._backend = backend

self._connection_lock = asyncio.Lock()
Expand Down Expand Up @@ -277,6 +278,7 @@ async def __aexit__(
self._connection_counter -= 1
if self._connection_counter == 0:
await self._connection.release()
self._database._connection = None

async def fetch_all(
self,
Expand Down Expand Up @@ -393,13 +395,15 @@ def _transaction(
transactions = _ACTIVE_TRANSACTIONS.get()
if transactions is None:
transactions = weakref.WeakKeyDictionary()
_ACTIVE_TRANSACTIONS.set(transactions)
else:
transactions = transactions.copy()

if transaction is None:
transactions.pop(self, None)
else:
transactions[self] = transaction

_ACTIVE_TRANSACTIONS.set(transactions)
return transactions.get(self, None)

async def __aenter__(self) -> "Transaction":
Expand Down
23 changes: 11 additions & 12 deletions docs/connections_and_transactions.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ that transparently handles the use of either transactions or savepoints.

## Connecting and disconnecting

You can control the database connect/disconnect, by using it as a async context manager.
You can control the database connection pool with an async context manager:

```python
async with Database(DATABASE_URL) as database:
...
```

Or by using explicit connection and disconnection:
Or by using the explicit `.connect()` and `.disconnect()` methods:

```python
database = Database(DATABASE_URL)
Expand All @@ -23,6 +23,8 @@ await database.connect()
await database.disconnect()
```

Connections within this connection pool are acquired for each new `asyncio.Task`.

If you're integrating against a web framework, then you'll probably want
to hook into framework startup or shutdown events. For example, with
[Starlette][starlette] you would use the following:
Expand Down Expand Up @@ -96,12 +98,13 @@ async def create_users(request):
...
```

Transaction state is stored in the context of the currently executing asynchronous task.
This state is _inherited_ by tasks that are started from within an active transaction:
Transaction state is tied to the connection used in the currently executing asynchronous task.
If you would like to influence an active transaction from another task, the connection must be
shared. This state is _inherited_ by tasks that are share the same connection:

```python
async def add_excitement(database: Database, id: int):
await database.execute(
async def add_excitement(connnection: databases.core.Connection, id: int):
await connection.execute(
"UPDATE notes SET text = CONCAT(text, '!!!') WHERE id = :id",
{"id": id}
)
Expand All @@ -113,17 +116,13 @@ async with Database(database_url) as database:
await database.execute(
"INSERT INTO notes(id, text) values (1, 'databases is cool')"
)
# ...but child tasks inherit transaction state!
await asyncio.create_task(add_excitement(database, id=1))
# ...but child tasks can use this connection now!
await asyncio.create_task(add_excitement(database.connection(), id=1))

await database.fetch_val("SELECT text FROM notes WHERE id=1")
# ^ returns: "databases is cool!!!"
```

!!! note
In python 3.11, you can opt-out of context propagation by providing a new context to
[`asyncio.create_task`](https://docs.python.org/3.11/library/asyncio-task.html#creating-tasks).

Nested transactions are fully supported, and are implemented using database savepoints:

```python
Expand Down

0 comments on commit bc28059

Please sign in to comment.