Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

allow iterate over custom num of records #473

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 5 additions & 1 deletion databases/backends/aiopg.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

n: Optional[int] = None

) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -195,6 +195,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe use if n <= 0 in case if negative n was passed

break
finally:
cursor.close()

Expand Down
6 changes: 5 additions & 1 deletion databases/backends/asyncmy.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -185,6 +185,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
break
finally:
await cursor.close()

Expand Down
6 changes: 5 additions & 1 deletion databases/backends/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await cursor.close()

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -185,6 +185,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
break
finally:
await cursor.close()

Expand Down
7 changes: 5 additions & 2 deletions databases/backends/postgres.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import logging
import typing
from collections.abc import Sequence

import asyncpg
from sqlalchemy.dialects.postgresql import pypostgresql
Expand Down Expand Up @@ -227,13 +226,17 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await self._connection.execute(single_query, *args)

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, result_columns = self._compile(query)
column_maps = self._create_column_maps(result_columns)
async for row in self._connection.cursor(query_str, *args):
yield Record(row, result_columns, self._dialect, column_maps)
if n is not None:
n -= 1
if n == 0:
break

def transaction(self) -> TransactionBackend:
return PostgresTransaction(connection=self)
Expand Down
6 changes: 5 additions & 1 deletion databases/backends/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
await self.execute(single_query)

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Any, None]:
assert self._connection is not None, "Connection is not acquired"
query_str, args, context = self._compile(query)
Expand All @@ -155,6 +155,10 @@ async def iterate(
Row._default_key_style,
row,
)
if n is not None:
n -= 1
if n == 0:
break

def transaction(self) -> TransactionBackend:
return SQLiteTransaction(self)
Expand Down
14 changes: 10 additions & 4 deletions databases/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,13 @@ async def execute_many(
return await connection.execute_many(query, values)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: dict = None,
n: int = None,
) -> typing.AsyncGenerator[typing.Mapping, None]:
async with self.connection() as connection:
async for record in connection.iterate(query, values):
async for record in connection.iterate(query, values, n):
yield record

def _new_connection(self) -> "Connection":
Expand Down Expand Up @@ -307,12 +310,15 @@ async def execute_many(
await self._connection.execute_many(queries)

async def iterate(
self, query: typing.Union[ClauseElement, str], values: dict = None
self,
query: typing.Union[ClauseElement, str],
values: dict = None,
n: int = None,
) -> typing.AsyncGenerator[typing.Any, None]:
built_query = self._build_query(query, values)
async with self.transaction():
async with self._query_lock:
async for record in self._connection.iterate(built_query):
async for record in self._connection.iterate(built_query, n):
yield record

def transaction(
Expand Down
2 changes: 1 addition & 1 deletion databases/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ async def execute_many(self, queries: typing.List[ClauseElement]) -> None:
raise NotImplementedError() # pragma: no cover

async def iterate(
self, query: ClauseElement
self, query: ClauseElement, n: int = None
) -> typing.AsyncGenerator[typing.Mapping, None]:
raise NotImplementedError() # pragma: no cover
# mypy needs async iterators to contain a `yield`
Expand Down
11 changes: 11 additions & 0 deletions tests/test_databases.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,17 @@ async def test_queries(database_url):
assert iterate_results[2]["text"] == "example3"
assert iterate_results[2]["completed"] == True

# iterate() with custom number of records
query = notes.select()
iterate_results = []
async for result in database.iterate(query=query, n=2):
iterate_results.append(result)
assert len(iterate_results) == 2
assert iterate_results[0]["text"] == "example1"
assert iterate_results[0]["completed"] == True
assert iterate_results[1]["text"] == "example2"
assert iterate_results[1]["completed"] == False


@pytest.mark.parametrize("database_url", DATABASE_URLS)
@mysql_versions
Expand Down