Skip to content

Commit

Permalink
S01E01
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 2, 2024
1 parent 8cbcccb commit a86edfa
Show file tree
Hide file tree
Showing 3 changed files with 70 additions and 56 deletions.
45 changes: 7 additions & 38 deletions databases/backends/asyncpg.py
Expand Up @@ -4,11 +4,10 @@
import asyncpg
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import get_dialect
from databases.core import LOG_EXTRA, DatabaseURL
from databases.backends.dialects.psycopg import compile_query, get_dialect
from databases.core import DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Expand Down Expand Up @@ -88,15 +87,15 @@ async def release(self) -> None:

async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
query_str, args, result_columns = compile_query(query, self._dialect)
rows = await self._connection.fetch(query_str, *args)
dialect = self._dialect
column_maps = create_column_maps(result_columns)
return [Record(row, result_columns, dialect, column_maps) for row in rows]

async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
query_str, args, result_columns = compile_query(query, self._dialect)
row = await self._connection.fetchrow(query_str, *args)
if row is None:
return None
Expand Down Expand Up @@ -124,7 +123,7 @@ async def fetch_val(

async def execute(self, query: ClauseElement) -> typing.Any:
assert self._connection is not None, "Connection is not acquired"
query_str, args, _ = self._compile(query)
query_str, args, _ = compile_query(query, self._dialect)
return await self._connection.fetchval(query_str, *args)

async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
Expand All @@ -133,51 +132,21 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
# loop through multiple executes here, which should all end up
# using the same prepared statement.
for single_query in queries:
single_query, args, _ = self._compile(single_query)
single_query, args, _ = compile_query(single_query, self._dialect)
await self._connection.execute(single_query, *args)

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
query_str, args, result_columns = compile_query(query, self._dialect)
column_maps = create_column_maps(result_columns)
async for row in self._connection.cursor(query_str, *args):
yield Record(row, result_columns, self._dialect, column_maps)

def transaction(self) -> TransactionBackend:
return AsyncpgTransaction(connection=self)

def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
dialect=self._dialect, compile_kwargs={"render_postcompile": True}
)

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping

processors = compiled._bind_processors
args = [
processors[key](val) if key in processors else val
for key, val in compiled_params
]
result_map = compiled._result_columns
else:
compiled_query = compiled.string
args = []
result_map = None

query_message = compiled_query.replace(" \n", " ").replace("\n", " ")
logger.debug(
"Query: %s Args: %s", query_message, repr(tuple(args)), extra=LOG_EXTRA
)
return compiled_query, args, result_map

@property
def raw_connection(self) -> asyncpg.connection.Connection:
assert self._connection is not None, "Connection is not acquired"
Expand Down
30 changes: 29 additions & 1 deletion databases/backends/dialects/psycopg.py
Expand Up @@ -10,6 +10,9 @@
from sqlalchemy import types, util
from sqlalchemy.dialects.postgresql.base import PGDialect, PGExecutionContext
from sqlalchemy.engine import processors
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.ddl import DDLElement
from sqlalchemy.types import Float, Numeric


Expand Down Expand Up @@ -43,7 +46,7 @@ class PGDialect_psycopg(PGDialect):
execution_ctx_cls = PGExecutionContext_psycopg


def get_dialect() -> PGDialect_psycopg:
def get_dialect() -> Dialect:
dialect = PGDialect_psycopg(paramstyle="pyformat")
dialect.implicit_returning = True
dialect.supports_native_enum = True
Expand All @@ -53,3 +56,28 @@ def get_dialect() -> PGDialect_psycopg:
dialect._has_native_hstore = True
dialect.supports_native_decimal = True
return dialect


def compile_query(query: ClauseElement, dialect: Dialect) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(dialect=dialect, compile_kwargs={"render_postcompile": True})

if not isinstance(query, DDLElement):
compiled_params = sorted(compiled.params.items())

mapping = {
key: "$" + str(i) for i, (key, _) in enumerate(compiled_params, start=1)
}
compiled_query = compiled.string % mapping

processors = compiled._bind_processors
args = [
processors[key](val) if key in processors else val
for key, val in compiled_params
]
result_map = compiled._result_columns
else:
compiled_query = compiled.string
args = []
result_map = None

return compiled_query, args, result_map
51 changes: 34 additions & 17 deletions databases/backends/psycopg.py
@@ -1,49 +1,68 @@
import typing
from collections.abc import Sequence

import psycopg
import psycopg_pool
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.sql import ClauseElement

from databases.backends.dialects.psycopg import get_dialect
from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import compile_query, get_dialect
from databases.core import DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
Record as RecordInterface,
TransactionBackend,
)


class PsycopgBackend(DatabaseBackend):
_database_url: DatabaseURL
_options: typing.Dict[str, typing.Any]
_dialect: Dialect
_pool: typing.Optional[psycopg_pool.AsyncConnectionPool]

def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
self,
database_url: typing.Union[DatabaseURL, str],
**options: typing.Dict[str, typing.Any],
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = get_dialect()
self._pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None
self._pool = None

async def connect(self) -> None:
if self._pool is not None:
return

self._pool = psycopg_pool.AsyncConnectionPool(
self._database_url.url, open=False, **self._options)

# TODO: Add configurable timeouts
await self._pool.open()

async def disconnect(self) -> None:
if self._pool is None:
return

# TODO: Add configurable timeouts
await self._pool.close()
self._pool = None

def connection(self) -> "PsycopgConnection":
return PsycopgConnection(self)
return PsycopgConnection(self, self._dialect)


class PsycopgConnection(ConnectionBackend):
def __init__(self, database: PsycopgBackend) -> None:
_database: PsycopgBackend
_dialect: Dialect
_connection: typing.Optional[psycopg.AsyncConnection]

def __init__(self, database: PsycopgBackend, dialect: Dialect) -> None:
self._database = database
self._dialect = dialect
self._connection = None

async def acquire(self) -> None:
if self._connection is not None:
Expand All @@ -62,10 +81,17 @@ async def release(self) -> None:
await self._database._pool.putconn(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List["Record"]:
async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:
if self._connection is None:
raise RuntimeError("Connection is not acquired")

query_str, args, result_columns = compile_query(query, self._dialect)
rows = await self._connection.fetch(query_str, *args)
column_maps = create_column_maps(result_columns)
return [Record(row, result_columns, self._dialect, column_maps) for row in rows]
raise NotImplementedError() # pragma: no cover

async def fetch_one(self, query: ClauseElement) -> typing.Optional["Record"]:
async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterface]:
raise NotImplementedError() # pragma: no cover

async def fetch_val(
Expand Down Expand Up @@ -107,12 +133,3 @@ async def commit(self) -> None:

async def rollback(self) -> None:
raise NotImplementedError() # pragma: no cover


class Record(Sequence):
@property
def _mapping(self) -> typing.Mapping:
raise NotImplementedError() # pragma: no cover

def __getitem__(self, key: typing.Any) -> typing.Any:
raise NotImplementedError() # pragma: no cover

0 comments on commit a86edfa

Please sign in to comment.