Skip to content

Commit

Permalink
S01E08
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 3, 2024
1 parent 1c58a73 commit 3f26f76
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 58 deletions.
66 changes: 54 additions & 12 deletions databases/backends/asyncpg.py
Expand Up @@ -4,10 +4,11 @@
import asyncpg
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import compile_query, get_dialect
from databases.core import DatabaseURL
from databases.backends.dialects.psycopg import dialect as psycopg_dialect
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Expand All @@ -24,9 +25,20 @@ def __init__(
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = get_dialect()
self._dialect = self._get_dialect()
self._pool = None

def _get_dialect(self) -> Dialect:
dialect = psycopg_dialect(paramstyle="pyformat")
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
dialect._backslash_escapes = False
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
dialect._has_native_hstore = True
dialect.supports_native_decimal = True
return dialect

def _get_connection_kwargs(self) -> dict:
url_options = self._database_url.options

Expand Down Expand Up @@ -87,15 +99,15 @@ async def release(self) -> None:

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = compile_query(query, self._dialect)
query_str, args, result_columns = self._compile(query)
rows = await self._connection.fetch(query_str, *args)
dialect = self._dialect
column_maps = create_column_maps(result_columns)
return [Record(row, result_columns, dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = compile_query(query, self._dialect)
query_str, args, result_columns = self._compile(query)
row = await self._connection.fetchrow(query_str, *args)
if row is None:
return None
Expand Down Expand Up @@ -123,7 +135,7 @@ async def fetch_val(

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, _ = compile_query(query, self._dialect)
query_str, args, _ = self._compile(query)
return await self._connection.fetchval(query_str, *args)

async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
Expand All @@ -132,25 +144,55 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
# loop through multiple executes here, which should all end up
# using the same prepared statement.
for single_query in queries:
single_query, args, _ = compile_query(single_query, self._dialect)
single_query, args, _ = self._compile(single_query)
await self._connection.execute(single_query, *args)

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = compile_query(query, self._dialect)
query_str, args, result_columns = self._compile(query)
column_maps = create_column_maps(result_columns)
async for row in self._connection.cursor(query_str, *args):
yield Record(row, result_columns, self._dialect, column_maps)

def transaction(self) -> TransactionBackend:
return AsyncpgTransaction(connection=self)

@property
def raw_connection(self) -> asyncpg.connection.Connection:
assert self._connection is not None, "Connection is not acquired"
return self._connection
def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping

processors = compiled._bind_processors
args = [
processors[key](val) if key in processors else val
for key, val in compiled_params
]
result_map = compiled._result_columns
else:
compiled_query = compiled.string
args = []
result_map = None

query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled_query, args, result_map

@property
def raw_connection(self) -> asyncpg.connection.Connection:
assert self._connection is not None, "Connection is not acquired"
return self._connection


class AsyncpgTransaction(TransactionBackend):
Expand Down
43 changes: 1 addition & 42 deletions databases/backends/dialects/psycopg.py
Expand Up @@ -10,9 +10,6 @@
from sqlalchemy import types, util
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
from sqlalchemy.engine import processors
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import Float, Numeric


Expand Down Expand Up @@ -46,42 +43,4 @@ class PGDialect_psycopg(PGDialect):
execution_ctx_cls = PGExecutionContext_psycopg


def get_dialect() -> Dialect:
dialect = PGDialect_psycopg(paramstyle="pyformat")
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
dialect._backslash_escapes = False
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
dialect._has_native_hstore = True
dialect.supports_native_decimal = True
return dialect


def compile_query(
query: ClauseElement, dialect: Dialect
) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=dialect, compile_kwargs={"render_postcompile": True}
)

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping

processors = compiled._bind_processors
args = [
processors[key](val) if key in processors else val
for key, val in compiled_params
]
result_map = compiled._result_columns
else:
compiled_query = compiled.string
args = []
result_map = None

return compiled_query, args, result_map
dialect = PGDialect_psycopg
5 changes: 2 additions & 3 deletions databases/backends/psycopg.py
Expand Up @@ -39,9 +39,8 @@ async def connect(self) -> None:
if self._pool is not None:
return

self._pool = psycopg_pool.AsyncConnectionPool(
self._database_url._url, open=False, **self._options
)
url = self._database_url._url.replace("postgresql+psycopg", "postgresql")
self._pool = psycopg_pool.AsyncConnectionPool(url, open=False, **self._options)

# TODO: Add configurable timeouts
await self._pool.open()
Expand Down
2 changes: 1 addition & 1 deletion databases/core.py
Expand Up @@ -44,10 +44,10 @@
class Database:
SUPPORTED_BACKENDS = {
"postgres": "databases.backends.asyncpg:AsyncpgBackend",
"postgresql": "databases.backends.asyncpg:AsyncpgBackend",
"postgresql+aiopg": "databases.backends.aiopg:AiopgBackend",
"postgresql+asyncpg": "databases.backends.asyncpg:AsyncpgBackend",
"postgresql+psycopg": "databases.backends.psycopg:PsycopgBackend",
"postgresql": "databases.backends.psycopg:PsycopgBackend",
"mysql": "databases.backends.mysql:MySQLBackend",
"mysql+aiomysql": "databases.backends.asyncmy:MySQLBackend",
"mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend",
Expand Down

0 comments on commit 3f26f76

Please sign in to comment.