Skip to content

Commit

Permalink
switch to sqlalchemy 1.4 (#299)
Browse files Browse the repository at this point in the history
* switch to sqlalchemy 1.4

* fix deprecation warning and add tests
  • Loading branch information
PrettyWood committed Aug 26, 2021
1 parent b9f35e5 commit 89efe60
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 30 deletions.
35 changes: 27 additions & 8 deletions databases/backends/aiopg.py
Expand Up @@ -7,11 +7,11 @@
import aiopg
from aiopg.sa.engine import APGCompiler_psycopg2
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.engine.result import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.core import DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -119,9 +119,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
try:
await cursor.execute(query, args)
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
finally:
Expand All @@ -136,8 +142,14 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
row = await cursor.fetchone()
if row is None:
return None
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
metadata = CursorResultMetaData(context, cursor.description)
return Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
cursor.close()

Expand Down Expand Up @@ -169,9 +181,15 @@ async def iterate(
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
cursor.close()

Expand All @@ -196,6 +214,7 @@ def _compile(
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)
else:
args = {}
Expand Down
35 changes: 27 additions & 8 deletions databases/backends/mysql.py
Expand Up @@ -5,11 +5,11 @@

import aiomysql
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.engine.result import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -107,9 +107,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
try:
await cursor.execute(query, args)
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
finally:
Expand All @@ -124,8 +130,14 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
row = await cursor.fetchone()
if row is None:
return None
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
metadata = CursorResultMetaData(context, cursor.description)
return Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
await cursor.close()

Expand Down Expand Up @@ -159,9 +171,15 @@ async def iterate(
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
await cursor.close()

Expand All @@ -186,6 +204,7 @@ def _compile(
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)
else:
args = {}
Expand Down
23 changes: 22 additions & 1 deletion databases/backends/postgres.py
Expand Up @@ -104,8 +104,29 @@ def __init__(
self._dialect = dialect
self._column_map, self._column_map_int, self._column_map_full = column_maps

@property
def _mapping(self) -> asyncpg.Record:
return self._row

def keys(self) -> typing.KeysView:
import warnings

warnings.warn(
"The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, "
"use `Row._mapping.keys()` instead.",
DeprecationWarning,
)
return self._mapping.keys()

def values(self) -> typing.ValuesView:
return self._row.values()
import warnings

warnings.warn(
"The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, "
"use `Row._mapping.values()` instead.",
DeprecationWarning,
)
return self._mapping.values()

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0: # raw query
Expand Down
35 changes: 27 additions & 8 deletions databases/backends/sqlite.py
Expand Up @@ -4,11 +4,11 @@

import aiosqlite
from sqlalchemy.dialects.sqlite import pysqlite
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.engine.result import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -92,9 +92,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:

async with self._connection.execute(query, args) as cursor:
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]

Expand All @@ -106,8 +112,14 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
row = await cursor.fetchone()
if row is None:
return None
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
metadata = CursorResultMetaData(context, cursor.description)
return Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
Expand All @@ -129,9 +141,15 @@ async def iterate(
assert self._connection is not None, "Connection is not acquired"
query, args, context = self._compile(query)
async with self._connection.execute(query, args) as cursor:
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)

def transaction(self) -> TransactionBackend:
return SQLiteTransaction(self)
Expand All @@ -158,6 +176,7 @@ def _compile(
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)

query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -48,7 +48,7 @@ def get_packages(package):
packages=get_packages("databases"),
package_data={"databases": ["py.typed"]},
data_files=[("", ["LICENSE.md"])],
install_requires=['sqlalchemy<1.4', 'aiocontextvars;python_version<"3.7"'],
install_requires=['sqlalchemy>=1.4,<1.5', 'aiocontextvars;python_version<"3.7"'],
extras_require={
"postgresql": ["asyncpg"],
"mysql": ["aiomysql"],
Expand Down
55 changes: 51 additions & 4 deletions tests/test_databases.py
Expand Up @@ -3,6 +3,7 @@
import decimal
import functools
import os
import re

import pytest
import sqlalchemy
Expand Down Expand Up @@ -336,8 +337,8 @@ async def test_result_values_allow_duplicate_names(database_url):
query = "SELECT 1 AS id, 2 AS id"
row = await database.fetch_one(query=query)

assert list(row.keys()) == ["id", "id"]
assert list(row.values()) == [1, 2]
assert list(row._mapping.keys()) == ["id", "id"]
assert list(row._mapping.values()) == [1, 2]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -981,7 +982,7 @@ async def test_iterate_outside_transaction_with_temp_table(database_url):
@async_adapter
async def test_column_names(database_url, select_query):
"""
Test that column names are exposed correctly through `.keys()` on each row.
Test that column names are exposed correctly through `._mapping.keys()` on each row.
"""
async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
Expand All @@ -993,7 +994,7 @@ async def test_column_names(database_url, select_query):
results = await database.fetch_all(query=select_query)
assert len(results) == 1

assert sorted(results[0].keys()) == ["completed", "id", "text"]
assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"]
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True

Expand All @@ -1014,3 +1015,49 @@ async def test_task(db):

tasks = [test_task(database) for i in range(4)]
await asyncio.gather(*tasks)


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@async_adapter
async def test_posgres_interface(database_url):
"""
Since SQLAlchemy 1.4, `Row.values()` is removed and `Row.keys()` is deprecated.
Custom postgres interface mimics more or less this behaviour by deprecating those
two methods
"""
database_url = DatabaseURL(database_url)

if database_url.scheme != "postgresql":
pytest.skip("Test is only for postgresql")

async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
query = notes.insert()
values = {"text": "example1", "completed": True}
await database.execute(query, values)

query = notes.select()
result = await database.fetch_one(query=query)

with pytest.warns(
DeprecationWarning,
match=re.escape(
"The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, "
"use `Row._mapping.keys()` instead."
),
):
assert (
list(result.keys())
== [k for k in result]
== ["id", "text", "completed"]
)

with pytest.warns(
DeprecationWarning,
match=re.escape(
"The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, "
"use `Row._mapping.values()` instead."
),
):
# avoid checking `id` at index 0 since it may change depending on the launched tests
assert list(result.values())[1:] == ["example1", True]

0 comments on commit 89efe60

Please sign in to comment.