Skip to content

Commit

Permalink
S01E05
Browse files Browse the repository at this point in the history
  • Loading branch information
ansipunk committed Mar 3, 2024
1 parent d7ff8e8 commit ddd8aaa
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
1 change: 0 additions & 1 deletion databases/backends/common/records.py
Expand Up @@ -4,7 +4,6 @@
from sqlalchemy.engine.row import Row as SQLRow
from sqlalchemy.sql.compiler import _CompileLabel
from sqlalchemy.sql.schema import Column
from sqlalchemy.sql.sqltypes import JSON
from sqlalchemy.types import TypeEngine

from databases.interfaces import Record as RecordInterface
Expand Down
13 changes: 7 additions & 6 deletions databases/backends/psycopg.py
Expand Up @@ -2,6 +2,7 @@

import psycopg
import psycopg_pool
from psycopg.rows import namedtuple_row
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.dialects.postgresql.psycopg import PGDialect_psycopg
from sqlalchemy.sql import ClauseElement
Expand Down Expand Up @@ -58,12 +59,11 @@ def connection(self) -> "PsycopgConnection":
class PsycopgConnection(ConnectionBackend):
_database: PsycopgBackend
_dialect: Dialect
_connection: typing.Optional[psycopg.AsyncConnection]
_connection: typing.Optional[psycopg.AsyncConnection] = None

def __init__(self, database: PsycopgBackend, dialect: Dialect) -> None:
self._database = database
self._dialect = dialect
self._connection = None

async def acquire(self) -> None:
if self._connection is not None:
Expand All @@ -74,6 +74,7 @@ async def acquire(self) -> None:

# TODO: Add configurable timeouts
self._connection = await self._database._pool.getconn()
await self._connection.set_autocommit(True)

async def release(self) -> None:
if self._connection is None:
Expand All @@ -88,7 +89,7 @@ async def fetch_all(self, query: ClauseElement) -> typing.List[RecordInterface]:

query_str, args, result_columns = self._compile(query)

async with self._connection.cursor() as cursor:
async with self._connection.cursor(row_factory=namedtuple_row) as cursor:
await cursor.execute(query_str, args)
rows = await cursor.fetchall()

Expand All @@ -101,7 +102,7 @@ async def fetch_one(self, query: ClauseElement) -> typing.Optional[RecordInterfa

query_str, args, result_columns = self._compile(query)

async with self._connection.cursor() as cursor:
async with self._connection.cursor(row_factory=namedtuple_row) as cursor:
await cursor.execute(query_str, args)
row = await cursor.fetchone()

Expand All @@ -127,7 +128,7 @@ async def execute(self, query: ClauseElement) -> typing.Any:

query_str, args, _ = self._compile(query)

async with self._connection.cursor() as cursor:
async with self._connection.cursor(row_factory=namedtuple_row) as cursor:
await cursor.execute(query_str, args)

async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
Expand All @@ -144,7 +145,7 @@ async def iterate(
query_str, args, result_columns = self._compile(query)
column_maps = create_column_maps(result_columns)

async with self._connection.cursor() as cursor:
async with self._connection.cursor(row_factory=namedtuple_row) as cursor:
await cursor.execute(query_str, args)

while True:
Expand Down

0 comments on commit ddd8aaa

Please sign in to comment.