Skip to content

Commit

Permalink
build(deps): switch to sqlalchemy 1.4
Browse files Browse the repository at this point in the history
  • Loading branch information
PrettyWood committed May 13, 2021
1 parent 22c1631 commit 9d6e0c0
Show file tree
Hide file tree
Showing 9 changed files with 111 additions and 33 deletions.
35 changes: 27 additions & 8 deletions databases/backends/aiopg.py
Expand Up @@ -7,11 +7,11 @@
import aiopg
from aiopg.sa.engine import APGCompiler_psycopg2
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.engine.result import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.core import DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -119,9 +119,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
try:
await cursor.execute(query, args)
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
finally:
Expand All @@ -136,8 +142,14 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
row = await cursor.fetchone()
if row is None:
return None
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
metadata = CursorResultMetaData(context, cursor.description)
return Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
cursor.close()

Expand Down Expand Up @@ -169,9 +181,15 @@ async def iterate(
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
cursor.close()

Expand All @@ -196,6 +214,7 @@ def _compile(
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)
else:
args = {}
Expand Down
35 changes: 27 additions & 8 deletions databases/backends/mysql.py
Expand Up @@ -5,11 +5,11 @@

import aiomysql
from sqlalchemy.dialects.mysql import pymysql
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.engine.result import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -107,9 +107,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:
try:
await cursor.execute(query, args)
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]
finally:
Expand All @@ -124,8 +130,14 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
row = await cursor.fetchone()
if row is None:
return None
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
metadata = CursorResultMetaData(context, cursor.description)
return Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
await cursor.close()

Expand Down Expand Up @@ -159,9 +171,15 @@ async def iterate(
cursor = await self._connection.cursor()
try:
await cursor.execute(query, args)
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
finally:
await cursor.close()

Expand All @@ -186,6 +204,7 @@ def _compile(
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)
else:
args = {}
Expand Down
21 changes: 20 additions & 1 deletion databases/backends/postgres.py
Expand Up @@ -104,8 +104,27 @@ def __init__(
self._dialect = dialect
self._column_map, self._column_map_int, self._column_map_full = column_maps

@property
def _mapping(self) -> asyncpg.Record:
return self._row

def keys(self) -> typing.KeysView:
import warnings

warnings.warn(
"The `Row.keys()` method is deprecated to mimic SQLAlchemy behaviour, "
"use `Row._mapping.keys()` instead."
)
return self._mapping.keys()

def values(self) -> typing.ValuesView:
return self._row.values()
import warnings

warnings.warn(
"The `Row.values()` method is deprecated to mimic SQLAlchemy behaviour, "
"use `Row._mapping.values()` instead."
)
return self._mapping.values()

def __getitem__(self, key: typing.Any) -> typing.Any:
if len(self._column_map) == 0: # raw query
Expand Down
35 changes: 27 additions & 8 deletions databases/backends/sqlite.py
Expand Up @@ -4,11 +4,11 @@

import aiosqlite
from sqlalchemy.dialects.sqlite import pysqlite
from sqlalchemy.engine.cursor import CursorResultMetaData
from sqlalchemy.engine.interfaces import Dialect, ExecutionContext
from sqlalchemy.engine.result import ResultMetaData, RowProxy
from sqlalchemy.engine.result import Row
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import TypeEngine

from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import ConnectionBackend, DatabaseBackend, TransactionBackend
Expand Down Expand Up @@ -92,9 +92,15 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[typing.Mapping]:

async with self._connection.execute(query, args) as cursor:
rows = await cursor.fetchall()
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
return [
RowProxy(metadata, row, metadata._processors, metadata._keymap)
Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)
for row in rows
]

Expand All @@ -106,8 +112,14 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[typing.Mappin
row = await cursor.fetchone()
if row is None:
return None
metadata = ResultMetaData(context, cursor.description)
return RowProxy(metadata, row, metadata._processors, metadata._keymap)
metadata = CursorResultMetaData(context, cursor.description)
return Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
Expand All @@ -129,9 +141,15 @@ async def iterate(
assert self._connection is not None, "Connection is not acquired"
query, args, context = self._compile(query)
async with self._connection.execute(query, args) as cursor:
metadata = ResultMetaData(context, cursor.description)
metadata = CursorResultMetaData(context, cursor.description)
async for row in cursor:
yield RowProxy(metadata, row, metadata._processors, metadata._keymap)
yield Row(
metadata,
metadata._processors,
metadata._keymap,
Row._default_key_style,
row,
)

def transaction(self) -> TransactionBackend:
return SQLiteTransaction(self)
Expand All @@ -158,6 +176,7 @@ def _compile(
compiled._result_columns,
compiled._ordered_columns,
compiled._textual_ordered_columns,
compiled._loose_column_name_matching,
)

query_message = compiled.string.replace(" \n", " ").replace("\n", " ")
Expand Down
2 changes: 1 addition & 1 deletion databases/core.py
Expand Up @@ -5,7 +5,7 @@
import sys
import typing
from types import TracebackType
from urllib.parse import SplitResult, parse_qsl, urlsplit, unquote
from urllib.parse import SplitResult, parse_qsl, unquote, urlsplit

from sqlalchemy import text
from sqlalchemy.sql import ClauseElement
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Expand Up @@ -22,4 +22,4 @@ mypy
pytest
pytest-cov
starlette
requests
requests
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -48,7 +48,7 @@ def get_packages(package):
packages=get_packages("databases"),
package_data={"databases": ["py.typed"]},
data_files=[("", ["LICENSE.md"])],
install_requires=['sqlalchemy<1.4', 'aiocontextvars;python_version<"3.7"'],
install_requires=['sqlalchemy>=1.4,<1.5', 'aiocontextvars;python_version<"3.7"'],
extras_require={
"postgresql": ["asyncpg"],
"mysql": ["aiomysql"],
Expand Down
4 changes: 3 additions & 1 deletion tests/test_database_url.py
@@ -1,7 +1,9 @@
from databases import DatabaseURL
from urllib.parse import quote

import pytest

from databases import DatabaseURL


def test_database_url_repr():
u = DatabaseURL("postgresql://localhost/name")
Expand Down
8 changes: 4 additions & 4 deletions tests/test_databases.py
Expand Up @@ -336,8 +336,8 @@ async def test_result_values_allow_duplicate_names(database_url):
query = "SELECT 1 AS id, 2 AS id"
row = await database.fetch_one(query=query)

assert list(row.keys()) == ["id", "id"]
assert list(row.values()) == [1, 2]
assert list(row._mapping.keys()) == ["id", "id"]
assert list(row._mapping.values()) == [1, 2]


@pytest.mark.parametrize("database_url", DATABASE_URLS)
Expand Down Expand Up @@ -981,7 +981,7 @@ async def test_iterate_outside_transaction_with_temp_table(database_url):
@async_adapter
async def test_column_names(database_url, select_query):
"""
Test that column names are exposed correctly through `.keys()` on each row.
Test that column names are exposed correctly through `._mapping.keys()` on each row.
"""
async with Database(database_url) as database:
async with database.transaction(force_rollback=True):
Expand All @@ -993,6 +993,6 @@ async def test_column_names(database_url, select_query):
results = await database.fetch_all(query=select_query)
assert len(results) == 1

assert sorted(results[0].keys()) == ["completed", "id", "text"]
assert sorted(results[0]._mapping.keys()) == ["completed", "id", "text"]
assert results[0]["text"] == "example1"
assert results[0]["completed"] == True

0 comments on commit 9d6e0c0

Please sign in to comment.