Skip to content

Commit

Permalink
Allowing extra transaction options (#242)
Browse files Browse the repository at this point in the history
* Allowing extra transaction options

* Switching to Python 3.6-compatible asyncio primitives

* Using native SQLAlchemy engine for independent queries in tests

* Excluding postgresql+aiopg in parameterized transaction test

* Clarifying test skip comment in parameterized transaction test

* Adding missing type annotation

* Formatting with black

Co-authored-by: Phil Demetriou <inbox@philonas.net>
Co-authored-by: Vadim Markovtsev <vadim@athenian.co>
  • Loading branch information
3 people committed Sep 26, 2020
1 parent 2d0bc0f commit 932c5d1
Show file tree
Hide file tree
Showing 7 changed files with 70 additions and 11 deletions.
4 changes: 3 additions & 1 deletion databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def __init__(self, connection: AiopgConnection):
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
cursor = await self._connection._connection.cursor()
Expand Down
4 changes: 3 additions & 1 deletion databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,9 @@ def __init__(self, connection: MySQLConnection):
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
if self._is_root:
Expand Down
6 changes: 4 additions & 2 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,9 +286,11 @@ def __init__(self, connection: PostgresConnection):
None
) # type: typing.Optional[asyncpg.transaction.Transaction]

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._transaction = self._connection._connection.transaction()
self._transaction = self._connection._connection.transaction(**extra_options)
await self._transaction.start()

async def commit(self) -> None:
Expand Down
4 changes: 3 additions & 1 deletion databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,9 @@ def __init__(self, connection: SQLiteConnection):
self._is_root = False
self._savepoint_name = ""

async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
assert self._connection._connection is not None, "Connection is not acquired"
self._is_root = is_root
if self._is_root:
Expand Down
18 changes: 13 additions & 5 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,8 +184,10 @@ def connection(self) -> "Connection":
self._connection_context.set(connection)
return connection

def transaction(self, *, force_rollback: bool = False) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback)
def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
) -> "Transaction":
return Transaction(self.connection, force_rollback=force_rollback, **kwargs)

@contextlib.contextmanager
def force_rollback(self) -> typing.Iterator[None]:
Expand Down Expand Up @@ -276,11 +278,13 @@ async def iterate(
async for record in self._connection.iterate(built_query):
yield record

def transaction(self, *, force_rollback: bool = False) -> "Transaction":
def transaction(
self, *, force_rollback: bool = False, **kwargs: typing.Any
) -> "Transaction":
def connection_callable() -> Connection:
return self

return Transaction(connection_callable, force_rollback)
return Transaction(connection_callable, force_rollback, **kwargs)

@property
def raw_connection(self) -> typing.Any:
Expand All @@ -305,9 +309,11 @@ def __init__(
self,
connection_callable: typing.Callable[[], Connection],
force_rollback: bool,
**kwargs: typing.Any,
) -> None:
self._connection_callable = connection_callable
self._force_rollback = force_rollback
self._extra_options = kwargs

async def __aenter__(self) -> "Transaction":
"""
Expand Down Expand Up @@ -355,7 +361,9 @@ async def start(self) -> "Transaction":
async with self._connection._transaction_lock:
is_root = not self._connection._transaction_stack
await self._connection.__aenter__()
await self._transaction.start(is_root=is_root)
await self._transaction.start(
is_root=is_root, extra_options=self._extra_options
)
self._connection._transaction_stack.append(self)
return self

Expand Down
4 changes: 3 additions & 1 deletion databases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def raw_connection(self) -> typing.Any:


class TransactionBackend:
async def start(self, is_root: bool) -> None:
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
raise NotImplementedError() # pragma: no cover

async def commit(self) -> None:
Expand Down
41 changes: 41 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,47 @@ async def test_transaction_commit(database_url):
assert len(results) == 1


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_transaction_commit_serializable(database_url):
"""
Ensure that serializable transaction commit via extra parameters is supported.
"""

database_url = DatabaseURL(database_url)

if database_url.scheme != "postgresql":
pytest.skip("Test (currently) only supports asyncpg")

def insert_independently():
engine = sqlalchemy.create_engine(str(database_url))
conn = engine.connect()

query = notes.insert().values(text="example1", completed=True)
conn.execute(query)

def delete_independently():
engine = sqlalchemy.create_engine(str(database_url))
conn = engine.connect()

query = notes.delete()
conn.execute(query)

async with Database(database_url) as database:
async with database.transaction(force_rollback=True, isolation="serializable"):
query = notes.select()
results = await database.fetch_all(query=query)
assert len(results) == 0

insert_independently()

query = notes.select()
results = await database.fetch_all(query=query)
assert len(results) == 0

delete_independently()


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_transaction_rollback(database_url):
Expand Down

0 comments on commit 932c5d1

Please sign in to comment.