Skip to content

Commit

Permalink
Rename PostgresBackend to AsyncpgBackend
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 2, 2024
1 parent 37c450c commit 8cbcccb
Show file tree
Hide file tree
Showing 5 changed files with 154 additions and 36 deletions.
33 changes: 10 additions & 23 deletions databases/backends/postgres.py → databases/backends/asyncpg.py
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.sql.ddl import DDLElement

from databases.backends.common.records import Record, create_column_maps
from databases.backends.dialects.psycopg import dialect as psycopg_dialect
from databases.backends.dialects.psycopg import get_dialect
from databases.core import LOG_EXTRA, DatabaseURL
from databases.interfaces import (
ConnectionBackend,
Expand All @@ -19,28 +19,15 @@
logger = logging.getLogger("databases")


class PostgresBackend(DatabaseBackend):
class AsyncpgBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = self._get_dialect()
self._dialect = get_dialect()
self._pool = None

def _get_dialect(self) -> Dialect:
dialect = psycopg_dialect(paramstyle="pyformat")

dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
dialect._backslash_escapes = False
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
dialect._has_native_hstore = True
dialect.supports_native_decimal = True

return dialect

def _get_connection_kwargs(self) -> dict:
url_options = self._database_url.options

Expand Down Expand Up @@ -78,12 +65,12 @@ async def disconnect(self) -> None:
await self._pool.close()
self._pool = None

def connection(self) -> "PostgresConnection":
return PostgresConnection(self, self._dialect)
def connection(self) -> "AsyncpgConnection":
return AsyncpgConnection(self, self._dialect)


class PostgresConnection(ConnectionBackend):
def __init__(self, database: PostgresBackend, dialect: Dialect):
class AsyncpgConnection(ConnectionBackend):
def __init__(self, database: AsyncpgBackend, dialect: Dialect):
self._database = database
self._dialect = dialect
self._connection: typing.Optional[asyncpg.connection.Connection] = None
Expand Down Expand Up @@ -159,7 +146,7 @@ async def iterate(
yield Record(row, result_columns, self._dialect, column_maps)

def transaction(self) -> TransactionBackend:
return PostgresTransaction(connection=self)
return AsyncpgTransaction(connection=self)

def _compile(self, query: ClauseElement) -> typing.Tuple[str, list, tuple]:
compiled = query.compile(
Expand Down Expand Up @@ -197,8 +184,8 @@ def raw_connection(self) -> asyncpg.connection.Connection:
return self._connection


class PostgresTransaction(TransactionBackend):
def __init__(self, connection: PostgresConnection):
class AsyncpgTransaction(TransactionBackend):
def __init__(self, connection: AsyncpgConnection):
self._connection = connection
self._transaction: typing.Optional[asyncpg.transaction.Transaction] = None

Expand Down
11 changes: 10 additions & 1 deletion databases/backends/dialects/psycopg.py
Expand Up @@ -43,4 +43,13 @@ class PGDialect_psycopg(PGDialect):
execution_ctx_cls = PGExecutionContext_psycopg


dialect = PGDialect_psycopg
def get_dialect() -> PGDialect_psycopg:
dialect = PGDialect_psycopg(paramstyle="pyformat")
dialect.implicit_returning = True
dialect.supports_native_enum = True
dialect.supports_smallserial = True # 9.2+
dialect._backslash_escapes = False
dialect.supports_sane_multi_rowcount = True # psycopg 2.0.9+
dialect._has_native_hstore = True
dialect.supports_native_decimal = True
return dialect
118 changes: 118 additions & 0 deletions databases/backends/psycopg.py
@@ -0,0 +1,118 @@
import typing
from collections.abc import Sequence

import psycopg_pool
from sqlalchemy.sql import ClauseElement

from databases.backends.dialects.psycopg import get_dialect
from databases.core import DatabaseURL
from databases.interfaces import (
ConnectionBackend,
DatabaseBackend,
TransactionBackend,
)


class PsycopgBackend(DatabaseBackend):
def __init__(
self, database_url: typing.Union[DatabaseURL, str], **options: typing.Any
) -> None:
self._database_url = DatabaseURL(database_url)
self._options = options
self._dialect = get_dialect()
self._pool: typing.Optional[psycopg_pool.AsyncConnectionPool] = None

async def connect(self) -> None:
if self._pool is not None:
return

self._pool = psycopg_pool.AsyncConnectionPool(
self._database_url.url, open=False, **self._options)
await self._pool.open()

async def disconnect(self) -> None:
if self._pool is None:
return

await self._pool.close()
self._pool = None

def connection(self) -> "PsycopgConnection":
return PsycopgConnection(self)


class PsycopgConnection(ConnectionBackend):
def __init__(self, database: PsycopgBackend) -> None:
self._database = database

async def acquire(self) -> None:
if self._connection is not None:
return

if self._database._pool is None:
raise RuntimeError("PsycopgBackend is not running")

# TODO: Add configurable timeouts
self._connection = await self._database._pool.getconn()

async def release(self) -> None:
if self._connection is None:
return

await self._database._pool.putconn(self._connection)
self._connection = None

async def fetch_all(self, query: ClauseElement) -> typing.List["Record"]:
raise NotImplementedError() # pragma: no cover

async def fetch_one(self, query: ClauseElement) -> typing.Optional["Record"]:
raise NotImplementedError() # pragma: no cover

async def fetch_val(
self, query: ClauseElement, column: typing.Any = 0
) -> typing.Any:
row = await self.fetch_one(query)
return None if row is None else row[column]

async def execute(self, query: ClauseElement) -> typing.Any:
raise NotImplementedError() # pragma: no cover

async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
raise NotImplementedError() # pragma: no cover

async def iterate(
self, query: ClauseElement
) -> typing.AsyncGenerator[typing.Mapping, None]:
raise NotImplementedError() # pragma: no cover
# mypy needs async iterators to contain a `yield`
# https://github.com/python/mypy/issues/5385#issuecomment-407281656
yield True # pragma: no cover

def transaction(self) -> "TransactionBackend":
raise NotImplementedError() # pragma: no cover

@property
def raw_connection(self) -> typing.Any:
raise NotImplementedError() # pragma: no cover


class PsycopgTransaction(TransactionBackend):
async def start(
self, is_root: bool, extra_options: typing.Dict[typing.Any, typing.Any]
) -> None:
raise NotImplementedError() # pragma: no cover

async def commit(self) -> None:
raise NotImplementedError() # pragma: no cover

async def rollback(self) -> None:
raise NotImplementedError() # pragma: no cover


class Record(Sequence):
@property
def _mapping(self) -> typing.Mapping:
raise NotImplementedError() # pragma: no cover

def __getitem__(self, key: typing.Any) -> typing.Any:
raise NotImplementedError() # pragma: no cover
8 changes: 6 additions & 2 deletions databases/core.py
Expand Up @@ -43,12 +43,16 @@

class Database:
SUPPORTED_BACKENDS = {
"postgresql": "databases.backends.postgres:PostgresBackend",
"postgres": "databases.backends.asyncpg:AsyncpgBackend",
"postgresql": "databases.backends.asyncpg:AsyncpgBackend",
"postgresql+aiopg": "databases.backends.aiopg:AiopgBackend",
"postgres": "databases.backends.postgres:PostgresBackend",
"postgresql+asyncpg": "databases.backends.asyncpg:AsyncpgBackend",
"postgresql+psycopg": "databases.backends.psycopg:PsycopgBackend",
"mysql": "databases.backends.mysql:MySQLBackend",
"mysql+aiomysql": "databases.backends.asyncmy:MySQLBackend",
"mysql+asyncmy": "databases.backends.asyncmy:AsyncMyBackend",
"sqlite": "databases.backends.sqlite:SQLiteBackend",
"sqlite+aiosqlite": "databases.backends.sqlite:SQLiteBackend",
}

_connection_map: "weakref.WeakKeyDictionary[asyncio.Task, 'Connection']"
Expand Down
20 changes: 10 additions & 10 deletions tests/test_connection_options.py
Expand Up @@ -6,7 +6,7 @@
import pytest

from databases.backends.aiopg import AiopgBackend
from databases.backends.postgres import PostgresBackend
from databases.backends.asyncpg import AsyncpgBackend
from databases.core import DatabaseURL
from tests.test_databases import DATABASE_URLS, async_adapter

Expand All @@ -19,7 +19,7 @@


def test_postgres_pool_size():
backend = PostgresBackend("postgres://localhost/database?min_size=1&max_size=20")
backend = AsyncpgBackend("postgres://localhost/database?min_size=1&max_size=20")
kwargs = backend._get_connection_kwargs()
assert kwargs == {"min_size": 1, "max_size": 20}

Expand All @@ -29,43 +29,43 @@ async def test_postgres_pool_size_connect():
for url in DATABASE_URLS:
if DatabaseURL(url).dialect != "postgresql":
continue
backend = PostgresBackend(url + "?min_size=1&max_size=20")
backend = AsyncpgBackend(url + "?min_size=1&max_size=20")
await backend.connect()
await backend.disconnect()


def test_postgres_explicit_pool_size():
backend = PostgresBackend("postgres://localhost/database", min_size=1, max_size=20)
backend = AsyncpgBackend("postgres://localhost/database", min_size=1, max_size=20)
kwargs = backend._get_connection_kwargs()
assert kwargs == {"min_size": 1, "max_size": 20}


def test_postgres_ssl():
backend = PostgresBackend("postgres://localhost/database?ssl=true")
backend = AsyncpgBackend("postgres://localhost/database?ssl=true")
kwargs = backend._get_connection_kwargs()
assert kwargs == {"ssl": True}


def test_postgres_ssl_verify_full():
backend = PostgresBackend("postgres://localhost/database?ssl=verify-full")
backend = AsyncpgBackend("postgres://localhost/database?ssl=verify-full")
kwargs = backend._get_connection_kwargs()
assert kwargs == {"ssl": "verify-full"}


def test_postgres_explicit_ssl():
backend = PostgresBackend("postgres://localhost/database", ssl=True)
backend = AsyncpgBackend("postgres://localhost/database", ssl=True)
kwargs = backend._get_connection_kwargs()
assert kwargs == {"ssl": True}


def test_postgres_explicit_ssl_verify_full():
backend = PostgresBackend("postgres://localhost/database", ssl="verify-full")
backend = AsyncpgBackend("postgres://localhost/database", ssl="verify-full")
kwargs = backend._get_connection_kwargs()
assert kwargs == {"ssl": "verify-full"}


def test_postgres_no_extra_options():
backend = PostgresBackend("postgres://localhost/database")
backend = AsyncpgBackend("postgres://localhost/database")
kwargs = backend._get_connection_kwargs()
assert kwargs == {}

Expand All @@ -74,7 +74,7 @@ def test_postgres_password_as_callable():
def gen_password():
return "Foo"

backend = PostgresBackend(
backend = AsyncpgBackend(
"postgres://:password@localhost/database", password=gen_password
)
kwargs = backend._get_connection_kwargs()
Expand Down

0 comments on commit 8cbcccb

Please sign in to comment.